OpenDAS / FastFold · Commit 2a67dc33 (unverified)

Use inplace to save memory (#72)

Authored Sep 26, 2022 by oahzxl; committed by GitHub on Sep 26, 2022
Parent: efa8a9e4

Showing 11 changed files with 704 additions and 39 deletions (+704, -39):
- README.md (+21, -0)
- fastfold/model/fastnn/blocks.py (+172, -0)
- fastfold/model/fastnn/msa.py (+12, -0)
- fastfold/model/fastnn/ops.py (+248, -0)
- fastfold/model/fastnn/triangle.py (+18, -0)
- fastfold/model/hub/alphafold.py (+44, -17)
- fastfold/model/nn/embedders.py (+23, -15)
- fastfold/model/nn/evoformer.py (+100, -0)
- fastfold/model/nn/template.py (+59, -4)
- inference.py (+2, -0)
- inference.sh (+5, -3)
README.md (view file @ 2a67dc33)
...
...
@@ -143,6 +143,27 @@ python inference.py target.fasta data/pdb_mmcif/mmcif_files/ \
--enable_workflow
```
#### Inference with lower memory usage

AlphaFold's embedding representations take up a lot of memory as the sequence length increases. To reduce memory usage, add the parameters `--chunk_size [N]` and `--inplace` to the command line or to the shell script `./inference.sh`. The smaller you set N, the less memory is used, but inference becomes slower. With these options we can run inference on a sequence of length 7000 in fp32 on an 80 GB A100.
```shell
python inference.py target.fasta data/pdb_mmcif/mmcif_files/ \
    --output_dir ./ \
    --gpus 2 \
    --uniref90_database_path data/uniref90/uniref90.fasta \
    --mgnify_database_path data/mgnify/mgy_clusters_2018_12.fa \
    --pdb70_database_path data/pdb70/pdb70 \
    --uniclust30_database_path data/uniclust30/uniclust30_2018_08/uniclust30_2018_08 \
    --bfd_database_path data/bfd/bfd_metaclust_clu_complete_id30_c90_final_seq.sorted_opt \
    --jackhmmer_binary_path `which jackhmmer` \
    --hhblits_binary_path `which hhblits` \
    --hhsearch_binary_path `which hhsearch` \
    --kalign_binary_path `which kalign` \
    --chunk_size N \
    --inplace
```
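The same two options can also be set programmatically; below is a minimal, hypothetical sketch mirroring how inference.py (changed later in this commit) wires them into the model config. The import paths and the model name are assumptions, not taken from this commit.

```python
# Hypothetical sketch: enable chunking and in-place execution from Python,
# mirroring the config fields set in inference.py (config.globals.chunk_size
# and config.globals.inplace). Import paths are assumed.
from fastfold.config import model_config            # assumed location of model_config
from fastfold.model.hub.alphafold import AlphaFold  # path matches this commit's file layout

config = model_config("model_1")   # placeholder model name
config.globals.chunk_size = 4      # smaller N -> lower peak memory, slower inference
config.globals.inplace = True      # let the Evoformer / ExtraMSA stacks update m and z in place
model = AlphaFold(config)
```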
## Performance Benchmark
...
...
fastfold/model/fastnn/blocks.py (view file @ 2a67dc33)
...
...
@@ -99,6 +99,63 @@ class EvoformerBlock(nn.Module):
```python
        return m, z

    def inplace(self,
        m: torch.Tensor,
        z: torch.Tensor,
        msa_mask: torch.Tensor,
        pair_mask: torch.Tensor,
        chunk_size: Optional[int] = None,
        _mask_trans: bool = True,
    ) -> Tuple[torch.Tensor, torch.Tensor]:

        dap_size = gpc.get_world_size(ParallelMode.TENSOR)

        seq_length = pair_mask.size(-1)
        padding_size = (int(seq_length / dap_size) + 1) * dap_size - seq_length

        if self.first_block:
            m[0] = m[0].unsqueeze(0)
            z[0] = z[0].unsqueeze(0)

            m[0] = torch.nn.functional.pad(m[0], (0, 0, 0, padding_size))
            z[0] = torch.nn.functional.pad(z[0], (0, 0, 0, padding_size, 0, padding_size))

            m[0] = scatter(m[0], dim=1)
            z[0] = scatter(z[0], dim=1)

        msa_mask = msa_mask.unsqueeze(0)
        pair_mask = pair_mask.unsqueeze(0)
        msa_mask = torch.nn.functional.pad(msa_mask, (0, padding_size))
        pair_mask = torch.nn.functional.pad(pair_mask, (0, padding_size, 0, padding_size))

        if not self.is_multimer:
            m[0] = self.msa_stack(m[0], z[0], msa_mask)
            z = self.communication.inplace(m[0], msa_mask, z)
            m[0], work = All_to_All_Async.apply(m[0], 1, 2)
            z = self.pair_stack.inplace(z, pair_mask)
            m[0] = All_to_All_Async_Opp.apply(m[0], work, 1, 2)
        else:
            z = self.communication(m, msa_mask, z)
            z_ori = z
            m, work = All_to_All_Async.apply(m, 1, 2)
            z = self.pair_stack(z, pair_mask)
            m = All_to_All_Async_Opp.apply(m, work, 1, 2)
            m = self.msa_stack(m, z_ori, msa_mask)

        if self.last_block:
            m[0] = m[0].squeeze(0)
            z[0] = z[0].squeeze(0)

            m[0] = gather(m[0], dim=0)
            z[0] = gather(z[0], dim=0)

            m[0] = m[0][:, :-padding_size, :]
            z[0] = z[0][:-padding_size, :-padding_size, :]

        return m, z


class ExtraMSABlock(nn.Module):

    def __init__(
```
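The inplace() entry points above rely on a simple convention: the caller wraps m and z in one-element lists, so each block writes its result back into the shared slot (m[0], z[0]) instead of returning freshly allocated activations. A standalone toy sketch of that convention (not FastFold API):

```python
import torch

# Toy sketch of the list-wrapping convention used by the inplace() methods:
# passing [tensor] lets a block overwrite the single shared buffer, so only
# one copy of the large activation is alive at a time.
def toy_block_inplace(m: list, z: list):
    z[0].add_(1.0)                 # genuinely in-place update of the pair buffer
    m[0] = m[0] + z[0].mean()      # result written back into the shared slot
    return m, z

m, z = [torch.randn(4, 8)], [torch.randn(8, 8)]
m, z = toy_block_inplace(m, z)     # m[0] and z[0] now hold the updated tensors
```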
...
...
@@ -183,6 +240,74 @@ class ExtraMSABlock(nn.Module):
```python
        return m, z

    def inplace(self,
        m: torch.Tensor,
        z: torch.Tensor,
        msa_mask: torch.Tensor,
        pair_mask: torch.Tensor,
        chunk_size: Optional[int] = None,
        _mask_trans: bool = True,
    ) -> Tuple[torch.Tensor, torch.Tensor]:

        dap_size = gpc.get_world_size(ParallelMode.TENSOR)

        seq_cnt = msa_mask.size(-2)
        seq_len = pair_mask.size(-1)
        seq_cnt_padding_size = (int(seq_cnt / dap_size) + 1) * dap_size - seq_cnt
        seq_len_padding_size = (int(seq_len / dap_size) + 1) * dap_size - seq_len

        if self.first_block:
            m[0] = m[0].unsqueeze(0)
            z[0] = z[0].unsqueeze(0)

            m[0] = torch.nn.functional.pad(
                m[0], (0, 0, 0, seq_len_padding_size, 0, seq_cnt_padding_size))
            z[0] = torch.nn.functional.pad(
                z[0], (0, 0, 0, seq_len_padding_size, 0, seq_len_padding_size))

            m[0] = scatter(m[0], dim=1) if not self.is_multimer else scatter(m[0], dim=2)
            z[0] = scatter(z[0], dim=1)

        msa_mask = msa_mask.unsqueeze(0)
        pair_mask = pair_mask.unsqueeze(0)
        msa_mask = torch.nn.functional.pad(
            msa_mask, (0, seq_len_padding_size, 0, seq_cnt_padding_size))
        pair_mask = torch.nn.functional.pad(
            pair_mask, (0, seq_len_padding_size, 0, seq_len_padding_size))

        if not self.is_multimer:
            m = self.msa_stack.inplace(m, z, msa_mask)
            z = self.communication.inplace(m[0], msa_mask, z)
            m[0], work = All_to_All_Async.apply(m[0], 1, 2)
            z = self.pair_stack.inplace(z, pair_mask)
            m[0] = All_to_All_Async_Opp.apply(m[0], work, 1, 2)
        else:
            z = self.communication(m, msa_mask, z)
            z_ori = z
            m, work = All_to_All_Async.apply(m, 1, 2)
            z = self.pair_stack(z, pair_mask)
            m = All_to_All_Async_Opp.apply(m, work, 1, 2)
            m = self.msa_stack(m, z_ori, msa_mask)

        if self.last_block:
            m[0] = gather(m[0], dim=1) if not self.is_multimer else gather(m[0], dim=2)
            z[0] = gather(z[0], dim=1)

            m[0] = m[0][:, :-seq_cnt_padding_size, :-seq_len_padding_size, :]
            z[0] = z[0][:, :-seq_len_padding_size, :-seq_len_padding_size, :]

            m[0] = m[0].squeeze(0)
            z[0] = z[0].squeeze(0)

        return m, z


class TemplatePairStackBlock(nn.Module):

    def __init__(
```
...
...
@@ -269,3 +394,50 @@ class TemplatePairStackBlock(nn.Module):
```python
            z = z[:, :-padding_size, :-padding_size, :]

        return z

    def inplace(
        self,
        z: torch.Tensor,
        mask: torch.Tensor,
        chunk_size: Optional[int] = None,
        _mask_trans: bool = True,
    ):
        if isinstance(chunk_size, int) and 1 <= chunk_size <= 4:
            z[0] = z[0].cpu()

        dap_size = gpc.get_world_size(ParallelMode.TENSOR)

        seq_length = mask.size(-1)
        padding_size = (int(seq_length / dap_size) + 1) * dap_size - seq_length

        if self.first_block:
            z[0] = torch.nn.functional.pad(z[0], (0, 0, 0, padding_size, 0, padding_size))
            z[0] = scatter(z[0], dim=1)

        mask = torch.nn.functional.pad(mask, (0, padding_size, 0, padding_size))

        # single_templates = [t.unsqueeze(-4) for t in torch.unbind(z, dim=-4)]
        # single_templates_masks = [m.unsqueeze(-3) for m in torch.unbind(mask, dim=-3)]

        for i in range(z[0].shape[0]):
            single = z[0][i].unsqueeze(-4).to(mask.device)
            single_mask = mask[i].unsqueeze(-3)

            single_mask_row = scatter(single_mask, dim=1)
            single_mask_col = scatter(single_mask, dim=2)

            single = self.TriangleMultiplicationOutgoing(single, single_mask_row)
            single = row_to_col(single)
            single = self.TriangleMultiplicationIncoming(single, single_mask_col)
            single = col_to_row(single)
            single = self.TriangleAttentionStartingNode(single, single_mask_row)
            single = row_to_col(single)
            single = self.TriangleAttentionEndingNode(single, single_mask_col)
            single = self.PairTransition(single)
            single = col_to_row(single)
            z[0][i] = single.to(z[0].device)

        # z = torch.cat(single_templates, dim=-4)

        if self.last_block:
            z[0] = gather(z[0], dim=1)
            z[0] = z[0][:, :-padding_size, :-padding_size, :]

        return z
```
\ No newline at end of file
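When chunk_size is very small (1–4), TemplatePairStackBlock.inplace above additionally parks the stacked template tensor on the CPU and streams one template at a time through the GPU. A self-contained sketch of that offload pattern with assumed shapes (not FastFold API):

```python
import torch

# Standalone sketch: keep the big stacked template tensor in host memory and
# move one template slice to the GPU at a time, writing each result back.
def process_templates_offloaded(z_cpu, block_fn, device="cuda"):
    # z_cpu: [N_templ, N_res, N_res, C] kept on the CPU
    for i in range(z_cpu.shape[0]):
        single = z_cpu[i].unsqueeze(0).to(device)   # one template on the GPU
        single = block_fn(single)                   # e.g. triangle updates
        z_cpu[i] = single.squeeze(0).to("cpu")      # write back, free the GPU copy
    return z_cpu
```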
fastfold/model/fastnn/msa.py (view file @ 2a67dc33)
...
...
@@ -169,3 +169,15 @@ class ExtraMSAStack(nn.Module):
```python
        node = self.MSATransition(node)

        return node

    def inplace(self, node, pair, node_mask):
        node_mask_row = scatter(node_mask, dim=1)
        node = self.MSARowAttentionWithPairBias.inplace(node, pair, node_mask_row)
        node[0] = row_to_col(node[0])
        node_mask_col = scatter(node_mask, dim=2)
        node = self.MSAColumnAttention.inplace(node, node_mask_col)
        node = self.MSATransition.inplace(node)

        return node
```
\ No newline at end of file
fastfold/model/fastnn/ops.py (view file @ 2a67dc33)
...
...
@@ -29,6 +29,7 @@ from fastfold.distributed.comm_async import gather_async, gather_async_opp, get_
```python
CHUNK_SIZE = None
DEBUG = False


def set_chunk_size(chunk_size):
```
...
...
@@ -94,15 +95,34 @@ class ChunkTransition(nn.Module):
```python
        chunk_size = 48
        if CHUNK_SIZE == None:
            chunk_size = para_dim
        else:
            chunk_size = CHUNK_SIZE * 48

        out = torch.empty_like(src)
        for ax in range(0, para_dim, chunk_size):
            if DEBUG and ax > 10:
                break
            x = self.norm(src[:, ax:ax + chunk_size, :, :])
            x = self.linear2(F.relu(self.linear1(x)))
            out[:, ax:ax + chunk_size, :, :] = x
        out.add_(src)
        return out

    def inplace(self, src):
        para_dim = src[0].shape[1]
        if CHUNK_SIZE == None:
            chunk_size = para_dim
        else:
            chunk_size = CHUNK_SIZE * 48

        for ax in range(0, para_dim, chunk_size):
            if DEBUG and ax > 10:
                break
            x = self.norm(src[0][:, ax:ax + chunk_size, :, :])
            x = self.linear2(F.relu(self.linear1(x)))
            src[0][:, ax:ax + chunk_size, :, :] += x
        return src


class OutProductMean(nn.Module):
```
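ChunkTransition.inplace above avoids the torch.empty_like(src) scratch buffer used by the regular forward: each chunk's transition output is accumulated directly into the input with +=. A minimal sketch of that pattern with plain PyTorch modules:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Minimal sketch of the chunked, in-place transition: no full-size output
# buffer is allocated; each chunk's update is added straight into `src`.
def chunked_transition_inplace(src, norm, linear1, linear2, chunk_size):
    for ax in range(0, src.shape[1], chunk_size):
        x = norm(src[:, ax:ax + chunk_size])
        x = linear2(F.relu(linear1(x)))
        src[:, ax:ax + chunk_size] += x      # in-place residual add
    return src

# Example with toy dimensions
d = 16
src = torch.randn(1, 64, 32, d)
out = chunked_transition_inplace(src, nn.LayerNorm(d),
                                 nn.Linear(d, 4 * d), nn.Linear(4 * d, d),
                                 chunk_size=8)
```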
...
...
@@ -117,6 +137,7 @@ class OutProductMean(nn.Module):
```python
                               n_feat_out,
                               initializer='zero',
                               use_bias=True)

        self.n_feat_proj = n_feat_proj

    def forward(self, M, M_mask, Z_raw):
        M = self.layernormM(M)
```
...
...
@@ -148,6 +169,59 @@ class OutProductMean(nn.Module):
```python
        return Z_raw

    def inplace(self, M, M_mask, Z_raw):
        chunk_size = CHUNK_SIZE
        if len(M.shape) == 4:
            para_dim = M.shape[1]
            left_act = torch.empty((M.shape[0], M.shape[1], M.shape[2], self.n_feat_proj),
                                   dtype=M.dtype,
                                   device=M.device)
            right_act = torch.empty((M.shape[0], M.shape[1], M.shape[2], self.n_feat_proj),
                                    dtype=M.dtype,
                                    device=M.device)
            if CHUNK_SIZE == None:
                chunk_size = para_dim
            else:
                chunk_size = chunk_size * 32

            for ax in range(0, para_dim, chunk_size):
                m = self.layernormM(M[:, ax:ax + chunk_size, :, :])
                right_act[:, ax:ax + chunk_size, :, :] = self.linear_b(m)
                left_act[:, ax:ax + chunk_size, :, :] = self.linear_a(m)
        else:
            para_dim = M.shape[0]
            left_act = torch.empty((M.shape[0], M.shape[1], self.n_feat_proj),
                                   dtype=M.dtype,
                                   device=M.device)
            right_act = torch.empty((M.shape[0], M.shape[1], self.n_feat_proj),
                                    dtype=M.dtype,
                                    device=M.device)
            if CHUNK_SIZE == None:
                chunk_size = para_dim
            else:
                chunk_size = chunk_size * 32

            for ax in range(0, para_dim, chunk_size):
                m = self.layernormM(M[ax:ax + chunk_size, :, :])
                right_act[ax:ax + chunk_size, :, :] = self.linear_b(m)
                left_act[ax:ax + chunk_size, :, :] = self.linear_a(m)

        right_act_all, work = gather_async(right_act, dim=2)
        # right_act_all = gather(right_act, dim=2)

        M_mask = M_mask.unsqueeze(-1)
        M_mask_col = scatter(M_mask, dim=2)
        left_act = M_mask_col * left_act
        norm = torch.einsum('bsid,bsjd->bijd', M_mask_col, M_mask) + 1e-3

        right_act_all = gather_async_opp(right_act_all, work, dim=2)
        right_act_all = M_mask * right_act_all

        para_dim = left_act.shape[2]
        chunk_size = CHUNK_SIZE
        if CHUNK_SIZE == None:
            chunk_size = para_dim

        for ax in range(0, para_dim, chunk_size):
            left_act_part = left_act[:, :, ax:ax + chunk_size, :]
            O = torch.einsum('bsid,bsje->bijde', left_act_part, right_act_all)
            O = rearrange(O, 'b i j d e -> b i j (d e)')
            O = self.o_linear(O)
            norm0 = norm[:, ax:ax + chunk_size, :, :]
            Z_raw[0][:, ax:ax + chunk_size, :, :] += O / norm0

        return Z_raw


class Linear(nn.Linear):
    """
```
...
...
@@ -316,6 +390,8 @@ class AsyncChunkTriangleMultiplicationOutgoing(nn.Module):
```python
        output = torch.empty_like(Z_raw)
        for i in range(0, para_dim, chunk_size):
            if DEBUG and i > 10:
                break
            zi = Z_raw[:, i:i + chunk_size, :, :]
            zi = self.layernorm1(zi)
            gi = torch.sigmoid(self.left_right_gate(zi))
```
...
...
@@ -443,6 +519,8 @@ class AsyncChunkTriangleMultiplicationIncoming(nn.Module):
```python
        output = torch.empty_like(Z_raw)
        for i in range(0, para_dim, chunk_size):
            if DEBUG and i > 10:
                break
            zi = Z_raw[:, :, i:i + chunk_size, :]
            zi = self.layernorm1(zi)
            gi = torch.sigmoid(self.left_right_gate(zi))
```
...
...
@@ -577,6 +655,8 @@ class ChunkTriangleAttentionStartingNode(nn.Module):
```python
        output = torch.empty_like(Z_raw)
        dropout_mask = torch.ones_like(z[:, 0:1, :, :], device=z.device, dtype=z.dtype)
        for i in range(0, para_dim, chunk_size):
            if DEBUG and i > 10:
                break
            z_raw = Z_raw[:, i:i + chunk_size, :, :]
            z = self.layernorm1(z_raw)
            z_mask = Z_mask[:, i:i + chunk_size, :]
```
...
...
@@ -592,6 +672,52 @@ class ChunkTriangleAttentionStartingNode(nn.Module):
```python
        return output

    def inplace(self, Z_raw, Z_mask):
        if CHUNK_SIZE == None:
            Z = self.layernorm1(Z_raw)
            b = self.linear_b(Z)
            b, work = gather_async(b, dim=1)
            Z = self.attention(Z, Z_mask, (b, work))
            dropout_mask = torch.ones_like(Z[:, 0:1, :, :], device=Z.device, dtype=Z.dtype)
            return bias_dropout_add(Z,
                                    self.out_bias,
                                    dropout_mask,
                                    Z_raw,
                                    prob=self.p_drop,
                                    training=self.training)

        chunk_size = CHUNK_SIZE
        para_dim = Z_raw[0].shape[1]
        # z is big, but b is small. So we compute z in chunk to get b, and recompute z in chunk later instead of storing it
        b = torch.empty((Z_raw[0].shape[0], Z_raw[0].shape[1], Z_raw[0].shape[2], self.n_head),
                        device=Z_raw[0].device,
                        dtype=Z_raw[0].dtype)
        for i in range(0, para_dim, chunk_size):
            z = self.layernorm1(Z_raw[0][:, i:i + chunk_size, :, :])
            b[:, i:i + chunk_size, :, :] = self.linear_b(z)
        b, work = gather_async(b, dim=1)
        b = gather_async_opp(b, work, dim=1)
        b = rearrange(b, 'b q k h -> b h q k')

        # output = torch.empty_like(Z_raw)
        dropout_mask = torch.ones_like(z[:, 0:1, :, :], device=z.device, dtype=z.dtype)
        for i in range(0, para_dim, chunk_size):
            if DEBUG and i > 10:
                break
            z_raw = Z_raw[0][:, i:i + chunk_size, :, :]
            z = self.layernorm1(z_raw)
            z_mask = Z_mask[:, i:i + chunk_size, :]
            z = self.attention(z, z_mask, (b, -1))
            z = bias_dropout_add(z,
                                 self.out_bias,
                                 dropout_mask,
                                 z_raw,
                                 prob=self.p_drop,
                                 training=self.training)
            Z_raw[0][:, i:i + chunk_size, :, :] = z
        return Z_raw


class ChunkMSARowAttentionWithPairBias(nn.Module):
```
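The comment inside inplace() above summarizes the trick used by the chunked attention modules: the pair bias b is small, so it is computed chunk by chunk and kept, while the large normalized pair activation is recomputed in a second pass instead of being stored. A generic sketch of that two-pass trade-off (toy update function, not the FastFold attention kernel):

```python
import torch
import torch.nn as nn

# Generic two-pass sketch: keep only the small projection `b`; recompute the
# large normalized activation chunk by chunk in pass 2 and write results back.
def two_pass_bias_then_update(Z, layernorm, linear_b, update_fn, chunk_size):
    B, N, M, _ = Z.shape
    b = torch.empty(B, N, M, linear_b.out_features, device=Z.device, dtype=Z.dtype)
    for i in range(0, N, chunk_size):                 # pass 1: small bias only
        b[:, i:i + chunk_size] = linear_b(layernorm(Z[:, i:i + chunk_size]))
    for i in range(0, N, chunk_size):                 # pass 2: recompute, update, write back
        z = layernorm(Z[:, i:i + chunk_size])         # recomputed, never stored whole
        Z[:, i:i + chunk_size] = update_fn(z, b)
    return Z

# Example with toy shapes and a trivial update function
d, h = 16, 4
Z = torch.randn(1, 32, 32, d)
Z = two_pass_bias_then_update(Z, nn.LayerNorm(d), nn.Linear(d, h),
                              lambda z, b: z + b.mean(), chunk_size=8)
```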
...
...
@@ -648,6 +774,8 @@ class ChunkMSARowAttentionWithPairBias(nn.Module):
```python
        output = torch.empty_like(M_raw)
        dropout_mask = torch.ones_like(M_raw[:, 0:1, :, :], device=M_raw.device, dtype=M_raw.dtype)
        for i in range(0, para_dim_m, chunk_size):
            if DEBUG and i > 10:
                break
            m_raw = M_raw[:, i:i + chunk_size, :, :]
            m = self.layernormM(m_raw)
            m_mask = M_mask[:, i:i + chunk_size, :]
```
...
...
@@ -663,6 +791,51 @@ class ChunkMSARowAttentionWithPairBias(nn.Module):
```python
        return output

    def inplace(self, M_raw, Z, M_mask):
        if CHUNK_SIZE == None:
            ## Input projections
            M = self.layernormM(M_raw)
            Z = self.layernormZ(Z)
            b = F.linear(Z, self.linear_b_weights)
            b, work = gather_async(b, dim=1)
            # b = rearrange(b, 'b q k h -> b h q k')

            # padding_bias = (1e9 * (M_mask - 1.))[:, :, None, None, :]
            M = self.attention(M, M_mask, (b, work))

            dropout_mask = torch.ones_like(M[:, 0:1, :, :], device=M.device, dtype=M.dtype)
            return bias_dropout_add(M,
                                    self.out_bias,
                                    dropout_mask,
                                    M_raw,
                                    prob=self.p_drop,
                                    training=self.training)

        chunk_size = CHUNK_SIZE
        para_dim_z = Z[0].shape[1]
        para_dim_m = M_raw[0].shape[1]
        # z is big, but b is small. So we compute z in chunk to get b, and recompute z in chunk later instead of storing it
        b = torch.empty((Z[0].shape[0], Z[0].shape[1], Z[0].shape[2], self.n_head),
                        device=Z[0].device,
                        dtype=Z[0].dtype)
        for i in range(0, para_dim_z, chunk_size):
            z = self.layernormZ(Z[0][:, i:i + chunk_size, :, :])
            b[:, i:i + chunk_size, :, :] = F.linear(z, self.linear_b_weights)
        b, work = gather_async(b, dim=1)
        b = gather_async_opp(b, work, dim=1)
        b = rearrange(b, 'b q k h -> b h q k')

        dropout_mask = torch.ones_like(M_raw[0][:, 0:1, :, :],
                                       device=M_raw[0].device,
                                       dtype=M_raw[0].dtype)
        for i in range(0, para_dim_m, chunk_size):
            if DEBUG and i > 10:
                break
            m_raw = M_raw[0][:, i:i + chunk_size, :, :]
            m = self.layernormM(m_raw)
            m_mask = M_mask[:, i:i + chunk_size, :]
            m = self.attention(m, m_mask, (b, -1))
            m = bias_dropout_add(m,
                                 self.out_bias,
                                 dropout_mask,
                                 m_raw,
                                 prob=self.p_drop,
                                 training=self.training)
            M_raw[0][:, i:i + chunk_size, :, :] = m
        return M_raw


class ChunkTriangleAttentionEndingNode(nn.Module):

    def __init__(self, d_pair, p_drop, c=32, n_head=4):
```
...
...
@@ -716,6 +889,8 @@ class ChunkTriangleAttentionEndingNode(nn.Module):
```python
        output = torch.empty_like(Z_raw)
        dropout_mask = torch.ones_like(Z_raw[:, :, 0:1, :], device=z.device, dtype=z.dtype)
        for i in range(0, para_dim, chunk_size):
            if DEBUG and i > 10:
                break
            z_raw = Z_raw[:, :, i:i + chunk_size, :]
            z = self.layernorm1(z_raw.transpose(-2, -3))
            z_mask = Z_mask[:, :, i:i + chunk_size].transpose(-1, -2)
```
...
...
@@ -732,6 +907,57 @@ class ChunkTriangleAttentionEndingNode(nn.Module):
```python
        return output

    def inplace(self, Z_raw, Z_mask):
        if CHUNK_SIZE == None:
            Z = Z_raw.transpose(-2, -3)
            Z_mask = Z_mask.transpose(-1, -2)

            Z = self.layernorm1(Z)
            b = self.linear_b(Z)
            b, work = gather_async(b, dim=1)
            Z = self.attention(Z, Z_mask, (b, work))
            Z = Z.transpose(-2, -3)
            dropout_mask = torch.ones_like(Z[:, :, 0:1, :], device=Z.device, dtype=Z.dtype)
            return bias_dropout_add(Z,
                                    self.out_bias,
                                    dropout_mask,
                                    Z_raw,
                                    prob=self.p_drop,
                                    training=self.training)

        para_dim = Z_raw[0].shape[2]
        chunk_size = CHUNK_SIZE
        # z is big, but b is small. So we compute z in chunk to get b, and recompute z in chunk later instead of storing it
        b = torch.empty((Z_raw[0].shape[0], Z_raw[0].shape[2], Z_raw[0].shape[1], self.n_head),
                        device=Z_raw[0].device,
                        dtype=Z_raw[0].dtype)
        for i in range(0, para_dim, chunk_size):
            z = Z_raw[0][:, :, i:i + chunk_size, :].transpose(-2, -3)
            z = self.layernorm1(z)
            b[:, i:i + chunk_size, :, :] = self.linear_b(z)
        b, work = gather_async(b, dim=1)
        b = gather_async_opp(b, work, dim=1)
        b = rearrange(b, 'b q k h -> b h q k')

        dropout_mask = torch.ones_like(Z_raw[0][:, :, 0:1, :], device=z.device, dtype=z.dtype)
        for i in range(0, para_dim, chunk_size):
            if DEBUG and i > 10:
                break
            z_raw = Z_raw[0][:, :, i:i + chunk_size, :]
            z = self.layernorm1(z_raw.transpose(-2, -3))
            z_mask = Z_mask[:, :, i:i + chunk_size].transpose(-1, -2)
            z = self.attention(z, z_mask, (b, -1)).transpose(-2, -3)
            z = bias_dropout_add(z,
                                 self.out_bias,
                                 dropout_mask,
                                 z_raw,
                                 prob=self.p_drop,
                                 training=self.training)
            Z_raw[0][:, :, i:i + chunk_size, :] = z
        return Z_raw


class ChunkMSAColumnGlobalAttention(nn.Module):

    def __init__(self, d_node, c=8, n_head=8):
        super(ChunkMSAColumnGlobalAttention, self).__init__()
```
...
...
@@ -754,6 +980,8 @@ class ChunkMSAColumnGlobalAttention(nn.Module):
```python
        chunk_size = CHUNK_SIZE
        for i in range(0, para_dim, chunk_size):
            if DEBUG and i > 10:
                break
            m = M_raw[:, :, i:i + chunk_size, :].transpose(-2, -3)
            m = self.layernormM(m)
            m_mask = M_mask[:, :, i:i + chunk_size].transpose(-1, -2)
```
...
...
@@ -763,6 +991,26 @@ class ChunkMSAColumnGlobalAttention(nn.Module):
```python
        return M_raw

    def inplace(self, M_raw, M_mask):
        para_dim = M_raw[0].shape[2]
        if CHUNK_SIZE is None:
            chunk_size = para_dim
        else:
            chunk_size = CHUNK_SIZE

        for i in range(0, para_dim, chunk_size):
            if DEBUG and i > 10:
                break
            m = M_raw[0][:, :, i:i + chunk_size, :].transpose(-2, -3)
            m = self.layernormM(m)
            m_mask = M_mask[:, :, i:i + chunk_size].transpose(-1, -2)
            m = self.global_attention(m, m_mask)
            m = m.transpose(-2, -3)
            M_raw[0][:, :, i:i + chunk_size, :] += m

        return M_raw


class RecyclingEmbedder(nn.Module):
    """
```
...
...
fastfold/model/fastnn/triangle.py (view file @ 2a67dc33)
...
...
@@ -247,3 +247,21 @@ class PairStack(nn.Module):
```python
        pair = self.PairTransition(pair)
        pair = col_to_row(pair)

        return pair

    def inplace(self, pair, pair_mask):
        pair_mask_row = scatter(pair_mask, dim=1)
        pair_mask_col = scatter(pair_mask, dim=2)
        pair[0] = self.TriangleMultiplicationOutgoing(pair[0], pair_mask_row)
        pair[0] = row_to_col(pair[0])
        pair[0] = self.TriangleMultiplicationIncoming(pair[0], pair_mask_col)
        pair[0] = col_to_row(pair[0])
        pair = self.TriangleAttentionStartingNode.inplace(pair, pair_mask_row)
        pair[0] = row_to_col(pair[0])
        pair = self.TriangleAttentionEndingNode.inplace(pair, pair_mask_col)
        pair = self.PairTransition.inplace(pair)
        pair[0] = col_to_row(pair[0])

        return pair
```
\ No newline at end of file
fastfold/model/hub/alphafold.py (view file @ 2a67dc33)
...
...
@@ -281,7 +281,8 @@ class AlphaFold(nn.Module):
```diff
                 z,
                 pair_mask.to(dtype=z.dtype),
                 no_batch_dims,
-                self.globals.chunk_size
+                self.globals.chunk_size,
+                inplace=self.globals.inplace
             )

             if (
```
...
...
@@ -320,6 +321,7 @@ class AlphaFold(nn.Module):
```python
        extra_msa_feat = self.extra_msa_embedder(extra_msa_feat)

        # [*, N, N, C_z]
        if not self.globals.inplace or self.globals.is_multimer:
            z = self.extra_msa_stack(
                extra_msa_feat,
                z,
```
...
...
@@ -328,12 +330,24 @@ class AlphaFold(nn.Module):
```python
                pair_mask=pair_mask.to(dtype=z.dtype),
                _mask_trans=self.config._mask_trans,
            )
        else:
            extra_msa_feat = [extra_msa_feat]
            z = [z]
            z = self.extra_msa_stack.inplace(
                extra_msa_feat,
                z,
                msa_mask=feats["extra_msa_mask"].to(dtype=extra_msa_feat[0].dtype),
                chunk_size=self.globals.chunk_size,
                pair_mask=pair_mask.to(dtype=z[0].dtype),
                _mask_trans=self.config._mask_trans,
            )[0]

        del extra_msa_feat, extra_msa_fn

        # Run MSA + pair embeddings through the trunk of the network
        # m: [*, S, N, C_m]
        # z: [*, N, N, C_z]
        # s: [*, N, C_s]
        if not self.globals.inplace or self.globals.is_multimer:
            m, z, s = self.evoformer(
                m,
                z,
```
...
...
@@ -342,6 +356,19 @@ class AlphaFold(nn.Module):
```python
                chunk_size=self.globals.chunk_size,
                _mask_trans=self.config._mask_trans,
            )
        else:
            m = [m]
            z = [z]
            m, z, s = self.evoformer.inplace(
                m,
                z,
                msa_mask=msa_mask.to(dtype=m[0].dtype),
                pair_mask=pair_mask.to(dtype=z[0].dtype),
                chunk_size=self.globals.chunk_size,
                _mask_trans=self.config._mask_trans,
            )
            m = m[0]
            z = z[0]

        outputs["msa"] = m[..., :n_seq, :, :]
        outputs["pair"] = z
```
...
...
fastfold/model/nn/embedders.py (view file @ 2a67dc33)
...
...
@@ -162,7 +162,8 @@ class TemplateEmbedder(nn.Module):
```diff
         pair_mask,
         templ_dim,
         chunk_size,
-        _mask_trans=True
+        _mask_trans=True,
+        inplace=False
     ):
         # Embed the templates one at a time (with a poor man's vmap)
         template_embeds = []
```
...
...
@@ -205,14 +206,25 @@ class TemplateEmbedder(nn.Module):
```diff
             # single_template_embeds.update({"pair": t})
 
             template_embeds.append(single_template_embeds)
 
             # [*, S_t, N, N, C_z]
+            if inplace:
+                tt = [tt]
+                t[i] = self.template_pair_stack.inplace(
+                    tt,
+                    pair_mask.unsqueeze(-3).to(dtype=z.dtype),
+                    chunk_size=chunk_size,
+                    _mask_trans=_mask_trans,
+                )[0].to(t.device)
+            else:
                 t[i] = self.template_pair_stack(
                     tt,
                     pair_mask.unsqueeze(-3).to(dtype=z.dtype),
                     chunk_size=chunk_size,
                     _mask_trans=_mask_trans,
                 ).to(t.device)
 
-            del tt, single_template_embeds, single_template_feats
+            del tt, single_template_feats
 
         template_embeds = dict_multimap(
             partial(torch.cat, dim=templ_dim),
```
...
...
@@ -220,21 +232,17 @@ class TemplateEmbedder(nn.Module):
```diff
         )
 
         # [*, N, N, C_z]
-        t = self.template_pointwise_att(
-            t.to(z.device),
+        z = self.template_pointwise_att(
+            t,
             z,
             template_mask=batch["template_mask"].to(dtype=z.dtype),
             chunk_size=chunk_size * 256 if chunk_size is not None else chunk_size,
         )
 
-        t = t * (torch.sum(batch["template_mask"]) > 0)
-
         ret = {}
         if self.config.embed_angles:
             ret["template_single_embedding"] = template_embeds["angle"]
 
-        z += t
-
         return ret, z
```
...
...
fastfold/model/nn/evoformer.py (view file @ 2a67dc33)
...
...
@@ -533,6 +533,60 @@ class EvoformerStack(nn.Module):
```python
        return m, z, s

    def inplace(self,
        m: torch.Tensor,
        z: torch.Tensor,
        msa_mask: torch.Tensor,
        pair_mask: torch.Tensor,
        chunk_size: int,
        _mask_trans: bool = True,
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        """
        Args:
            m:
                [*, N_seq, N_res, C_m] MSA embedding
            z:
                [*, N_res, N_res, C_z] pair embedding
            msa_mask:
                [*, N_seq, N_res] MSA mask
            pair_mask:
                [*, N_res, N_res] pair mask
        Returns:
            m:
                [*, N_seq, N_res, C_m] MSA embedding
            z:
                [*, N_res, N_res, C_z] pair embedding
            s:
                [*, N_res, C_s] single embedding (or None if extra MSA stack)
        """
        blocks = [
            partial(
                b.inplace,
                msa_mask=msa_mask,
                pair_mask=pair_mask,
                chunk_size=chunk_size,
                _mask_trans=_mask_trans,
            )
            for b in self.blocks
        ]

        if(self.clear_cache_between_blocks):
            def block_with_cache_clear(block, *args):
                torch.cuda.empty_cache()
                return block(*args)

            blocks = [partial(block_with_cache_clear, b) for b in blocks]

        m, z = checkpoint_blocks(
            blocks,
            args=(m, z),
            blocks_per_ckpt=self.blocks_per_ckpt if self.training else None,
        )

        s = self.linear(m[0][..., 0, :, :])

        return m, z, s


class ExtraMSAStack(nn.Module):
    """
```
...
...
@@ -626,3 +680,49 @@ class ExtraMSAStack(nn.Module):
```python
                torch.cuda.empty_cache()

        return z

    def inplace(self,
        m: torch.Tensor,
        z: torch.Tensor,
        chunk_size: int,
        msa_mask: Optional[torch.Tensor] = None,
        pair_mask: Optional[torch.Tensor] = None,
        _mask_trans: bool = True,
    ) -> torch.Tensor:
        """
        Args:
            m:
                [*, N_extra, N_res, C_m] extra MSA embedding
            z:
                [*, N_res, N_res, C_z] pair embedding
            msa_mask:
                Optional [*, N_extra, N_res] MSA mask
            pair_mask:
                Optional [*, N_res, N_res] pair mask
        Returns:
            [*, N_res, N_res, C_z] pair update
        """
        #checkpoint_fn = get_checkpoint_fn()
        #blocks = [
        #    partial(b, msa_mask=msa_mask, pair_mask=pair_mask, chunk_size=chunk_size, _chunk_logits=None) for b in self.blocks
        #]

        #def dodo(b, *args):
        #    torch.cuda.empty_cache()
        #    return b(*args)

        #blocks = [partial(dodo, b) for b in blocks]

        #for b in blocks:
        #    if(torch.is_grad_enabled()):
        #        m, z = checkpoint_fn(b, *(m, z))
        #    else:
        #        m, z = b(m, z)

        for b in self.blocks:
            m, z = b.inplace(m, z, msa_mask, pair_mask, chunk_size=chunk_size)
            if(self.clear_cache_between_blocks):
                torch.cuda.empty_cache()

        return z
```
\ No newline at end of file
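Both stacks above can also call torch.cuda.empty_cache() between blocks (clear_cache_between_blocks), returning cached allocator memory so each block starts from a lower baseline. A small sketch of that loop:

```python
import torch

# Sketch of the optional cache-clearing behaviour: releasing cached CUDA
# blocks between layers trades some allocator reuse for a lower peak footprint.
def run_blocks(blocks, m, z, clear_cache_between_blocks=True):
    for block in blocks:
        m, z = block(m, z)
        if clear_cache_between_blocks and torch.cuda.is_available():
            torch.cuda.empty_cache()
    return m, z
```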
fastfold/model/nn/template.py (view file @ 2a67dc33)
...
...
@@ -122,11 +122,26 @@ class TemplatePointwiseAttention(nn.Module):
```diff
         # [*, N_res, N_res, 1, C_z]
         biases = [bias]
         if chunk_size is not None:
-            z = self._chunk(z, t, biases, chunk_size)
+            para_dim_t0 = t.shape[0]
+            para_dim_t1 = t.shape[1]
+            chunk_size_t = chunk_size * 4
+            mask = torch.sum(template_mask.to(z.device)) > 0
+            for ti in range(0, para_dim_t0, chunk_size_t):
+                t0 = t[ti:ti + chunk_size_t, :, :, :]
+                t0 = t0.to(z.device)
+                para_dim_t_part = t0.shape[0]
+                for i in range(0, para_dim_t_part, chunk_size):
+                    for j in range(0, para_dim_t1, chunk_size):
+                        z[i:i + chunk_size, j:j + chunk_size, :, :] += self.mha(
+                            q_x=z[i + ti:i + ti + chunk_size, j:j + chunk_size, :, :],
+                            kv_x=t0[i:i + chunk_size, j:j + chunk_size, :, :],
+                            biases=biases) * mask
         else:
-            z = self.mha(q_x=z, kv_x=t, biases=biases)
+            t = self.mha(q_x=z, kv_x=t, biases=biases)
 
-        # [*, N_res, N_res, C_z]
+            # [*, N_res, N_res, C_z]
+            t = t * (torch.sum(template_mask) > 0)
+            z = z + t
 
         z = z.squeeze(-2)
 
         return z
```
...
...
@@ -358,3 +373,43 @@ class TemplatePairStack(nn.Module):
```python
        for i in range(0, t.shape[0], chunk_size):
            t[i:i + chunk_size] = self.layer_norm(t[i:i + chunk_size])

        return t

    def inplace(
        self,
        t: torch.tensor,
        mask: torch.tensor,
        chunk_size: int,
        _mask_trans: bool = True,
    ):
        """
        Args:
            t:
                [*, N_templ, N_res, N_res, C_t] template embedding
            mask:
                [*, N_templ, N_res, N_res] mask
        Returns:
            [*, N_templ, N_res, N_res, C_t] template embedding update
        """
        if(mask.shape[-3] == 1):
            expand_idx = list(mask.shape)
            expand_idx[-3] = t[0].shape[-4]
            mask = mask.expand(*expand_idx)

        t, = checkpoint_blocks(
            blocks=[
                partial(
                    b.inplace,
                    mask=mask,
                    chunk_size=chunk_size,
                    _mask_trans=_mask_trans,
                )
                for b in self.blocks
            ],
            args=(t,),
            blocks_per_ckpt=self.blocks_per_ckpt if self.training else None,
        )

        if chunk_size is None:
            chunk_size = t[0].shape[0]
        for i in range(0, t[0].shape[0], chunk_size):
            t[0][i:i + chunk_size] = self.layer_norm(t[0][i:i + chunk_size].to(mask.device)).to(t[0].device)

        return t
```
inference.py (view file @ 2a67dc33)
...
...
@@ -101,6 +101,7 @@ def add_data_args(parser: argparse.ArgumentParser):
```python
    parser.add_argument('--release_dates_path', type=str, default=None)
    parser.add_argument('--chunk_size', type=int, default=None)
    parser.add_argument('--enable_workflow',
                        default=False,
                        action='store_true',
                        help='run inference with ray workflow or not')
    parser.add_argument('--inplace', default=False, action='store_true')


def inference_model(rank, world_size, result_q, batch, args):
```
...
...
@@ -113,6 +114,7 @@ def inference_model(rank, world_size, result_q, batch, args):
```python
    config = model_config(args.model_name)
    if args.chunk_size:
        config.globals.chunk_size = args.chunk_size
    config.globals.inplace = args.inplace
    model = AlphaFold(config)
    import_jax_weights_(model, args.param_path, version=args.model_name)
```
...
...
inference.sh (view file @ 2a67dc33)
```diff
-# add `--gpus [N]` to use N gpus for inference
-# add `--enable_workflow` to use parallel workflow for data processing
-# add `--use_precomputed_alignments [path_to_alignments]` to use precomputed msa
+# add '--gpus [N]' to use N gpus for inference
+# add '--enable_workflow' to use parallel workflow for data processing
+# add '--use_precomputed_alignments [path_to_alignments]' to use precomputed msa
+# add '--chunk_size [N]' to use chunk to reduce peak memory
+# add '--inplace' to use inplace to save memory
 python inference.py target.fasta data/pdb_mmcif/mmcif_files \
     --output_dir ./ \
```
...
...