OpenDAS / FastFold · Commits

Commit 1efccb6c (unverified), authored Aug 30, 2022 by LuGY, committed by GitHub on Aug 30, 2022

add faster extraMSA and templateStack (#53)

* add faster extraMSA and templateStack
* add license

Parent: 369f3e70
Showing 5 changed files with 600 additions and 81 deletions (+600 -81)
    fastfold/model/fastnn/__init__.py     +4    -2
    fastfold/model/fastnn/blocks.py       +268  -0
    fastfold/model/fastnn/msa.py          +64   -1
    fastfold/model/fastnn/ops.py          +82   -0
    fastfold/utils/inject_fastnn.py       +182  -78
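The combined effect of these changes: the fast ExtraMSAStack and TemplatePairStackBlock are swapped into an existing OpenFold-style model the same way the fast Evoformer already is, through inject_fastnn. A minimal usage sketch, not part of this commit; the model constructor and feature batch below are hypothetical placeholders, and FastFold's distributed (colossalai) context is assumed to have been initialized already:

import torch

from fastfold.model.fastnn import set_chunk_size
from fastfold.utils.inject_fastnn import inject_fastnn

model = build_alphafold_model()    # hypothetical: any OpenFold-style AlphaFold model
batch = load_features()            # hypothetical: pre-processed feature dict

model = inject_fastnn(model)       # replaces Evoformer, extra-MSA and template-pair blocks
model = model.eval().cuda()

set_chunk_size(64)                 # optional: chunk long sequences to bound peak memory
with torch.no_grad():
    out = model(batch)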
fastfold/model/fastnn/__init__.py (view file @ 1efccb6c)
-from .msa import MSAStack
+from .msa import MSAStack, ExtraMSAStack
 from .ops import OutProductMean, set_chunk_size
 from .triangle import PairStack
 from .evoformer import Evoformer
+from .blocks import EvoformerBlock, ExtraMSABlock, TemplatePairStackBlock

-__all__ = ['MSAStack', 'OutProductMean', 'PairStack', 'Evoformer', 'set_chunk_size']
+__all__ = ['MSAStack', 'ExtraMSAStack', 'OutProductMean', 'PairStack', 'Evoformer', 'set_chunk_size',
+           'EvoformerBlock', 'ExtraMSABlock', 'TemplatePairStackBlock']
fastfold/model/fastnn/blocks.py (new file, 0 → 100644) (view file @ 1efccb6c)
# Copyright 2022 BioMap (Beijing) Intelligence Technology Limited
# Copyright 2022 HPC-AI Technology Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Tuple

import torch
import torch.nn as nn
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc

from fastfold.model.fastnn import MSAStack, OutProductMean, PairStack, ExtraMSAStack
from fastfold.model.fastnn.ops import Transition
from fastfold.model.fastnn.triangle import TriangleAttentionEndingNode, TriangleAttentionStartingNode, \
    TriangleMultiplicationIncoming, TriangleMultiplicationOutgoing
from fastfold.distributed.comm import gather, scatter
from fastfold.distributed.comm import col_to_row, row_to_col, scatter
from fastfold.distributed.comm_async import All_to_All_Async, All_to_All_Async_Opp


class EvoformerBlock(nn.Module):

    def __init__(self, c_m: int, c_z: int, first_block: bool, last_block: bool):
        super(EvoformerBlock, self).__init__()

        self.first_block = first_block
        self.last_block = last_block

        self.msa_stack = MSAStack(c_m, c_z, p_drop=0.15)
        self.communication = OutProductMean(n_feat=c_m, n_feat_out=c_z, n_feat_proj=32)
        self.pair_stack = PairStack(d_pair=c_z)

    def forward(
        self,
        m: torch.Tensor,
        z: torch.Tensor,
        msa_mask: torch.Tensor,
        pair_mask: torch.Tensor,
        chunk_size: Optional[int] = None,
        _mask_trans: bool = True,
    ) -> Tuple[torch.Tensor, torch.Tensor]:

        dap_size = gpc.get_world_size(ParallelMode.TENSOR)

        seq_length = pair_mask.size(-1)
        padding_size = (int(seq_length / dap_size) + 1) * dap_size - seq_length

        if self.first_block:
            m = m.unsqueeze(0)
            z = z.unsqueeze(0)

            m = torch.nn.functional.pad(m, (0, 0, 0, padding_size))
            z = torch.nn.functional.pad(z, (0, 0, 0, padding_size, 0, padding_size))

            m = scatter(m, dim=1)
            z = scatter(z, dim=1)

        msa_mask = msa_mask.unsqueeze(0)
        pair_mask = pair_mask.unsqueeze(0)
        msa_mask = torch.nn.functional.pad(msa_mask, (0, padding_size))
        pair_mask = torch.nn.functional.pad(pair_mask, (0, padding_size, 0, padding_size))

        m = self.msa_stack(m, z, msa_mask)
        z = z + self.communication(m, msa_mask)
        m, work = All_to_All_Async.apply(m, 1, 2)
        z = self.pair_stack(z, pair_mask)
        m = All_to_All_Async_Opp.apply(m, work, 1, 2)

        if self.last_block:
            m = m.squeeze(0)
            z = z.squeeze(0)

            m = gather(m, dim=0)
            z = gather(z, dim=0)

            m = m[:, :-padding_size, :]
            z = z[:-padding_size, :-padding_size, :]

        return m, z


class ExtraMSABlock(nn.Module):

    def __init__(self, c_m: int, c_z: int, first_block: bool, last_block: bool, is_multimer=False):
        super(ExtraMSABlock, self).__init__()

        self.first_block = first_block
        self.last_block = last_block

        self.msa_stack = ExtraMSAStack(c_m, c_z, p_drop=0.15)
        self.communication = OutProductMean(n_feat=c_m, n_feat_out=c_z, n_feat_proj=32)
        self.pair_stack = PairStack(d_pair=c_z)
        self.is_multimer = is_multimer

    def forward(
        self,
        m: torch.Tensor,
        z: torch.Tensor,
        msa_mask: torch.Tensor,
        pair_mask: torch.Tensor,
        chunk_size: Optional[int] = None,
        _mask_trans: bool = True,
    ) -> Tuple[torch.Tensor, torch.Tensor]:

        dap_size = gpc.get_world_size(ParallelMode.TENSOR)

        seq_cnt = msa_mask.size(-2)
        seq_len = pair_mask.size(-1)
        seq_cnt_padding_size = (int(seq_cnt / dap_size) + 1) * dap_size - seq_cnt
        seq_len_padding_size = (int(seq_len / dap_size) + 1) * dap_size - seq_len

        if self.first_block:
            m = m.unsqueeze(0)
            z = z.unsqueeze(0)

            m = torch.nn.functional.pad(
                m, (0, 0, 0, seq_len_padding_size, 0, seq_cnt_padding_size))
            z = torch.nn.functional.pad(
                z, (0, 0, 0, seq_len_padding_size, 0, seq_len_padding_size))

            m = scatter(m, dim=1) if not self.is_multimer else scatter(m, dim=2)
            z = scatter(z, dim=1)

        msa_mask = msa_mask.unsqueeze(0)
        pair_mask = pair_mask.unsqueeze(0)
        msa_mask = torch.nn.functional.pad(
            msa_mask, (0, seq_len_padding_size, 0, seq_cnt_padding_size))
        pair_mask = torch.nn.functional.pad(
            pair_mask, (0, seq_len_padding_size, 0, seq_len_padding_size))

        if not self.is_multimer:
            m = self.msa_stack(m, z, msa_mask)
            z = z + self.communication(m, msa_mask)
            m, work = All_to_All_Async.apply(m, 1, 2)
            z = self.pair_stack(z, pair_mask)
            m = All_to_All_Async_Opp.apply(m, work, 1, 2)
        else:
            z = z + self.communication(m, msa_mask)
            z_ori = z
            m, work = All_to_All_Async.apply(m, 1, 2)
            z = self.pair_stack(z, pair_mask)
            m = All_to_All_Async_Opp.apply(m, work, 1, 2)
            m = self.msa_stack(m, z_ori, msa_mask)

        if self.last_block:
            m = gather(m, dim=1) if not self.is_multimer else gather(m, dim=2)
            z = gather(z, dim=1)

            m = m[:, :-seq_cnt_padding_size, :-seq_len_padding_size, :]
            z = z[:, :-seq_len_padding_size, :-seq_len_padding_size, :]

            m = m.squeeze(0)
            z = z.squeeze(0)

        return m, z


class TemplatePairStackBlock(nn.Module):

    def __init__(
        self,
        c_t: int,
        c_hidden_tri_att: int,
        c_hidden_tri_mul: int,
        no_heads: int,
        pair_transition_n: int,
        dropout_rate: float,
        inf: float,
        first_block: bool,
        last_block: bool,
        **kwargs,
    ):
        super(TemplatePairStackBlock, self).__init__()

        self.first_block = first_block
        self.last_block = last_block

        self.c_t = c_t
        self.c_hidden_tri_att = c_hidden_tri_att
        self.c_hidden_tri_mul = c_hidden_tri_mul
        self.n_head = no_heads
        self.p_drop = dropout_rate
        self.hidden_c = int(c_t / self.n_head)

        self.TriangleMultiplicationOutgoing = TriangleMultiplicationOutgoing(
            self.c_t, p_drop=self.p_drop, c=self.c_hidden_tri_mul)
        self.TriangleMultiplicationIncoming = TriangleMultiplicationIncoming(
            self.c_t, p_drop=self.p_drop, c=self.c_hidden_tri_mul)
        self.TriangleAttentionStartingNode = TriangleAttentionStartingNode(
            self.c_t, p_drop=self.p_drop, c=self.c_hidden_tri_att, n_head=self.n_head)
        self.TriangleAttentionEndingNode = TriangleAttentionEndingNode(
            self.c_t, p_drop=self.p_drop, c=self.c_hidden_tri_att, n_head=self.n_head)
        self.PairTransition = Transition(d=self.c_t, n=pair_transition_n)

    def forward(
        self,
        z: torch.Tensor,
        mask: torch.Tensor,
        chunk_size: Optional[int] = None,
        _mask_trans: bool = True,
    ):
        dap_size = gpc.get_world_size(ParallelMode.TENSOR)

        seq_length = mask.size(-1)
        padding_size = (int(seq_length / dap_size) + 1) * dap_size - seq_length

        if self.first_block:
            z = torch.nn.functional.pad(z, (0, 0, 0, padding_size, 0, padding_size))
            z = scatter(z, dim=1)

        mask = torch.nn.functional.pad(mask, (0, padding_size, 0, padding_size))

        single_templates = [t.unsqueeze(-4) for t in torch.unbind(z, dim=-4)]
        single_templates_masks = [m.unsqueeze(-3) for m in torch.unbind(mask, dim=-3)]

        for i in range(len(single_templates)):
            single = single_templates[i]
            single_mask = single_templates_masks[i]

            single_mask_row = scatter(single_mask, dim=1)
            single_mask_col = scatter(single_mask, dim=2)

            single = self.TriangleMultiplicationOutgoing(single, single_mask_row)
            single = row_to_col(single)
            single = self.TriangleMultiplicationIncoming(single, single_mask_col)
            single = col_to_row(single)
            single = self.TriangleAttentionStartingNode(single, single_mask_row)
            single = row_to_col(single)
            single = self.TriangleAttentionEndingNode(single, single_mask_col)
            single = self.PairTransition(single)
            single = col_to_row(single)

            single_templates[i] = single

        z = torch.cat(single_templates, dim=-4)

        if self.last_block:
            z = gather(z, dim=1)
            z = z[:, :-padding_size, :-padding_size, :]

        return z
.
TENSOR
)
seq_cnt
=
msa_mask
.
size
(
-
2
)
seq_len
=
pair_mask
.
size
(
-
1
)
seq_cnt_padding_size
=
(
int
(
seq_cnt
/
dap_size
)
+
1
)
*
dap_size
-
seq_cnt
seq_len_padding_size
=
(
int
(
seq_len
/
dap_size
)
+
1
)
*
dap_size
-
seq_len
if
self
.
first_block
:
m
=
m
.
unsqueeze
(
0
)
z
=
z
.
unsqueeze
(
0
)
m
=
torch
.
nn
.
functional
.
pad
(
m
,
(
0
,
0
,
0
,
seq_len_padding_size
,
0
,
seq_cnt_padding_size
)
)
z
=
torch
.
nn
.
functional
.
pad
(
z
,
(
0
,
0
,
0
,
seq_len_padding_size
,
0
,
seq_len_padding_size
)
)
m
=
scatter
(
m
,
dim
=
1
)
if
not
self
.
is_multimer
else
scatter
(
m
,
dim
=
2
)
z
=
scatter
(
z
,
dim
=
1
)
msa_mask
=
msa_mask
.
unsqueeze
(
0
)
pair_mask
=
pair_mask
.
unsqueeze
(
0
)
msa_mask
=
torch
.
nn
.
functional
.
pad
(
msa_mask
,
(
0
,
seq_len_padding_size
,
0
,
seq_cnt_padding_size
)
)
pair_mask
=
torch
.
nn
.
functional
.
pad
(
pair_mask
,
(
0
,
seq_len_padding_size
,
0
,
seq_len_padding_size
)
)
if
not
self
.
is_multimer
:
m
=
self
.
msa_stack
(
m
,
z
,
msa_mask
)
z
=
z
+
self
.
communication
(
m
,
msa_mask
)
m
,
work
=
All_to_All_Async
.
apply
(
m
,
1
,
2
)
z
=
self
.
pair_stack
(
z
,
pair_mask
)
m
=
All_to_All_Async_Opp
.
apply
(
m
,
work
,
1
,
2
)
else
:
z
=
z
+
self
.
communication
(
m
,
msa_mask
)
z_ori
=
z
m
,
work
=
All_to_All_Async
.
apply
(
m
,
1
,
2
)
z
=
self
.
pair_stack
(
z
,
pair_mask
)
m
=
All_to_All_Async_Opp
.
apply
(
m
,
work
,
1
,
2
)
m
=
self
.
msa_stack
(
m
,
z_ori
,
msa_mask
)
if
self
.
last_block
:
m
=
gather
(
m
,
dim
=
1
)
if
not
self
.
is_multimer
else
gather
(
m
,
dim
=
2
)
z
=
gather
(
z
,
dim
=
1
)
m
=
m
[:,
:
-
seq_cnt_padding_size
,
:
-
seq_len_padding_size
,
:]
z
=
z
[:,
:
-
seq_len_padding_size
,
:
-
seq_len_padding_size
,
:]
m
=
m
.
squeeze
(
0
)
z
=
z
.
squeeze
(
0
)
return
m
,
z
class
TemplatePairStackBlock
(
nn
.
Module
):
def
__init__
(
self
,
c_t
:
int
,
c_hidden_tri_att
:
int
,
c_hidden_tri_mul
:
int
,
no_heads
:
int
,
pair_transition_n
:
int
,
dropout_rate
:
float
,
inf
:
float
,
first_block
:
bool
,
last_block
:
bool
,
**
kwargs
,
):
super
(
TemplatePairStackBlock
,
self
).
__init__
()
self
.
first_block
=
first_block
self
.
last_block
=
last_block
self
.
c_t
=
c_t
self
.
c_hidden_tri_att
=
c_hidden_tri_att
self
.
c_hidden_tri_mul
=
c_hidden_tri_mul
self
.
n_head
=
no_heads
self
.
p_drop
=
dropout_rate
self
.
hidden_c
=
int
(
c_t
/
self
.
n_head
)
self
.
TriangleMultiplicationOutgoing
=
TriangleMultiplicationOutgoing
(
self
.
c_t
,
p_drop
=
self
.
p_drop
,
c
=
self
.
c_hidden_tri_mul
)
self
.
TriangleMultiplicationIncoming
=
TriangleMultiplicationIncoming
(
self
.
c_t
,
p_drop
=
self
.
p_drop
,
c
=
self
.
c_hidden_tri_mul
)
self
.
TriangleAttentionStartingNode
=
TriangleAttentionStartingNode
(
self
.
c_t
,
p_drop
=
self
.
p_drop
,
c
=
self
.
c_hidden_tri_att
,
n_head
=
self
.
n_head
)
self
.
TriangleAttentionEndingNode
=
TriangleAttentionEndingNode
(
self
.
c_t
,
p_drop
=
self
.
p_drop
,
c
=
self
.
c_hidden_tri_att
,
n_head
=
self
.
n_head
)
self
.
PairTransition
=
Transition
(
d
=
self
.
c_t
,
n
=
pair_transition_n
)
def
forward
(
self
,
z
:
torch
.
Tensor
,
mask
:
torch
.
Tensor
,
chunk_size
:
Optional
[
int
]
=
None
,
_mask_trans
:
bool
=
True
,
):
dap_size
=
gpc
.
get_world_size
(
ParallelMode
.
TENSOR
)
seq_length
=
mask
.
size
(
-
1
)
padding_size
=
(
int
(
seq_length
/
dap_size
)
+
1
)
*
dap_size
-
seq_length
if
self
.
first_block
:
z
=
torch
.
nn
.
functional
.
pad
(
z
,
(
0
,
0
,
0
,
padding_size
,
0
,
padding_size
))
z
=
scatter
(
z
,
dim
=
1
)
mask
=
torch
.
nn
.
functional
.
pad
(
mask
,
(
0
,
padding_size
,
0
,
padding_size
))
single_templates
=
[
t
.
unsqueeze
(
-
4
)
for
t
in
torch
.
unbind
(
z
,
dim
=-
4
)]
single_templates_masks
=
[
m
.
unsqueeze
(
-
3
)
for
m
in
torch
.
unbind
(
mask
,
dim
=-
3
)]
for
i
in
range
(
len
(
single_templates
)):
single
=
single_templates
[
i
]
single_mask
=
single_templates_masks
[
i
]
single_mask_row
=
scatter
(
single_mask
,
dim
=
1
)
single_mask_col
=
scatter
(
single_mask
,
dim
=
2
)
single
=
self
.
TriangleMultiplicationOutgoing
(
single
,
single_mask_row
)
single
=
row_to_col
(
single
)
single
=
self
.
TriangleMultiplicationIncoming
(
single
,
single_mask_col
)
single
=
col_to_row
(
single
)
single
=
self
.
TriangleAttentionStartingNode
(
single
,
single_mask_row
)
single
=
row_to_col
(
single
)
single
=
self
.
TriangleAttentionEndingNode
(
single
,
single_mask_col
)
single
=
self
.
PairTransition
(
single
)
single
=
col_to_row
(
single
)
single_templates
[
i
]
=
single
z
=
torch
.
cat
(
single_templates
,
dim
=-
4
)
if
self
.
last_block
:
z
=
gather
(
z
,
dim
=
1
)
z
=
z
[:,
:
-
padding_size
,
:
-
padding_size
,
:]
return
z
\ No newline at end of file
fastfold/model/fastnn/msa.py
View file @
1efccb6c
# Copyright 2022 BioMap (Beijing) Intelligence Technology Limited
# Copyright 2022 HPC-AI Technology Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math

import torch

...
@@ -5,7 +19,7 @@ import torch.nn as nn
import torch.nn.functional as F

from fastfold.model.fastnn.kernel import LayerNorm
-from fastfold.model.fastnn.ops import Transition, SelfAttention
+from fastfold.model.fastnn.ops import Transition, SelfAttention, GlobalAttention
from fastfold.model.fastnn.kernel import bias_dropout_add
from fastfold.distributed import scatter, row_to_col
from fastfold.distributed.comm_async import gather_async

...
@@ -81,6 +95,31 @@ class MSAColumnAttention(nn.Module):
        return M_raw + M


class MSAColumnGlobalAttention(nn.Module):

    def __init__(self, d_node, c=8, n_head=8):
        super(MSAColumnGlobalAttention, self).__init__()
        self.d_node = d_node
        self.c = c
        self.n_head = n_head

        self.layernormM = LayerNorm(d_node)
        self.global_attention = GlobalAttention(qkv_dim=d_node, c=c, n_head=n_head, out_dim=d_node)

    def forward(self, M_raw, M_mask):
        M = M_raw.transpose(-2, -3)
        M = self.layernormM(M)

        M_mask = M_mask.transpose(-1, -2)

        M = self.global_attention(M, M_mask)
        M = M.transpose(-2, -3)

        return M_raw + M


class MSAStack(nn.Module):

    def __init__(self, d_node, d_pair, p_drop=0.15):

...
@@ -105,3 +144,27 @@ class MSAStack(nn.Module):
        node = self.MSATransition(node)

        return node


class ExtraMSAStack(nn.Module):

    def __init__(self, d_node, d_pair, p_drop=0.15):
        super(ExtraMSAStack, self).__init__()
        self.MSARowAttentionWithPairBias = MSARowAttentionWithPairBias(d_node=d_node,
                                                                       d_pair=d_pair,
                                                                       p_drop=p_drop,
                                                                       c=8)
        self.MSAColumnAttention = MSAColumnGlobalAttention(d_node=d_node, c=8)
        self.MSATransition = Transition(d=d_node)

    def forward(self, node, pair, node_mask):
        node_mask_row = scatter(node_mask, dim=1)
        node = self.MSARowAttentionWithPairBias(node, pair, node_mask_row)
        node = row_to_col(node)

        node_mask_col = scatter(node_mask, dim=2)
        node = self.MSAColumnAttention(node, node_mask_col)

        node = self.MSATransition(node)

        return node
fastfold/model/fastnn/ops.py (view file @ 1efccb6c)
# Copyright 2022 BioMap (Beijing) Intelligence Technology Limited
# Copyright 2022 HPC-AI Technology Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn
import torch.nn.functional as F

...
@@ -219,3 +233,71 @@ class SelfAttention(nn.Module):
        output = torch.cat(output, dim=1)

        return output


class GlobalAttention(nn.Module):
    """
    Multi-Head SelfAttention dealing with [batch_size1, batch_size2, len, dim] tensors
    """

    def __init__(self, qkv_dim, c, n_head, out_dim):
        super(GlobalAttention, self).__init__()
        self.qkv_dim = qkv_dim
        self.c = c
        self.n_head = n_head
        self.out_dim = out_dim
        self.scaling = self.c**(-0.5)

        self.eps = 1e-10
        self.inf = 1e9

        self.to_q = Linear(qkv_dim, c * self.n_head, use_bias=False)
        self.to_kv = Linear(qkv_dim, 2 * c, initializer="linear", use_bias=False)

        self.gating_bias = nn.parameter.Parameter(data=torch.ones((n_head * c,)))
        self.gating_linear = Linear(qkv_dim, n_head * c, initializer="zero", use_bias=False)

        self.o_linear = Linear(n_head * c, out_dim, initializer="zero")

    def forward(self, m, mask):
        para_dim = m.shape[1]
        chunk_size = CHUNK_SIZE
        if CHUNK_SIZE == None:
            chunk_size = para_dim

        output = []
        for ax in range(0, para_dim, chunk_size):
            m_part = m[:, ax:ax + chunk_size, :, :]
            mask_part = mask[:, ax:ax + chunk_size, :]

            q = torch.sum(m_part * mask_part.unsqueeze(-1), dim=-2) / (
                torch.sum(mask_part, dim=-1)[..., None] + self.eps)
            q = self.to_q(q)
            q = q.view(q.shape[:-1] + (self.n_head, -1))

            k, v = self.to_kv(m_part).chunk(2, dim=-1)

            logits = torch.matmul(q, k.transpose(-1, -2))
            weights = mask_softmax(logits, mask_part)

            weighted_avg = torch.matmul(weights, v)
            weighted_avg = rearrange(weighted_avg, "b1 b2 h d -> b1 b2 (h d)")

            gate_values = self.gating_linear(m_part)
            weighted_avg = bias_sigmod_ele(gate_values, self.gating_bias,
                                           weighted_avg.unsqueeze(-2))

            output.append(self.o_linear(weighted_avg))

        m = torch.cat(output, dim=1)

        return m
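For orientation, a simplified single-head sketch of what each chunk of the loop above computes; this is an illustration rather than FastFold code, and it drops the multi-head reshape, the gating path, and the fused mask_softmax / bias_sigmod_ele kernels:

import torch

def global_attention_sketch(m, mask, w_q, w_kv, eps=1e-10, inf=1e9):
    # m: [b1, b2, L, d] features, mask: [b1, b2, L] with 1 marking valid positions
    # w_q: [d, c] query projection, w_kv: [d, 2 * c] fused key/value projection
    q = (m * mask.unsqueeze(-1)).sum(-2) / (mask.sum(-1, keepdim=True) + eps)   # pooled query [b1, b2, d]
    q = q @ w_q                                                                 # [b1, b2, c]
    k, v = (m @ w_kv).chunk(2, dim=-1)                                          # each [b1, b2, L, c]
    logits = torch.einsum('bsc,bslc->bsl', q, k) - (1.0 - mask) * inf           # mask out padded positions
    weights = logits.softmax(dim=-1)
    return torch.einsum('bsl,bslc->bsc', weights, v)                            # [b1, b2, c]

The real module additionally chunks over the second batch axis (CHUNK_SIZE, settable via set_chunk_size) so that only a slice of the extra MSA is materialized at a time.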
fastfold/utils/inject_fastnn.py (view file @ 1efccb6c)
-from typing import Tuple, Optional
# Copyright 2022 BioMap (Beijing) Intelligence Technology Limited
# Copyright 2022 HPC-AI Technology Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn
-from colossalai.context.parallel_mode import ParallelMode
-from colossalai.core import global_context as gpc
-from fastfold.model.fastnn import MSAStack, OutProductMean, PairStack
-from fastfold.distributed.comm_async import All_to_All_Async, All_to_All_Async_Opp
-from fastfold.distributed.comm import gather, scatter
-class EvoformerBlock(nn.Module):
-    ... (the inline EvoformerBlock definition removed here is identical to the EvoformerBlock shown
-         above in fastfold/model/fastnn/blocks.py; it is now imported from there instead)
+from fastfold.model.fastnn import EvoformerBlock, ExtraMSABlock, TemplatePairStackBlock
def copy_layernorm(model_fast, model_ori):
    model_fast.weight.copy_(model_ori.weight)
    model_fast.bias.copy_(model_ori.bias)

...
@@ -85,6 +27,10 @@ def copy_linear(model_fast, model_ori):
    model_fast.bias.copy_(model_ori.bias)

+def copy_kv_linear(model_fast, ori_k, ori_v):
+    model_fast.weight.copy_(torch.cat((ori_k.weight, ori_v.weight), dim=0))

def copy_qkv_linear(model_fast, ori_q, ori_k, ori_v):
    model_fast.weight.copy_(torch.cat((ori_q.weight, ori_k.weight, ori_v.weight), dim=0))

...
@@ -131,7 +77,7 @@ def copy_triangle_att(model_fast, model_ori):
    model_fast.out_bias.copy_(model_ori.mha.linear_o.bias)

-def copy_para(block_fast, block_ori):
+def copy_evoformer_para(block_fast, block_ori):
    # msa_stack
    # MSARowAttentionWithPairBias
    copy_layernorm(block_fast.msa_stack.MSARowAttentionWithPairBias.layernormM,

...
@@ -179,7 +125,104 @@ def copy_para(block_fast, block_ori):
    copy_transition(block_fast.pair_stack.PairTransition, block_ori.core.pair_transition)


-def inject_fastnn(model):
def copy_global_attention(model_fast, model_ori):
    copy_linear(model_fast.to_q, model_ori.linear_q)
    copy_kv_linear(model_fast.to_kv, model_ori.linear_k, model_ori.linear_v)
    copy_linear(model_fast.gating_linear, model_ori.linear_g)
    copy_linear(model_fast.o_linear, model_ori.linear_o)
    try:
        model_fast.gating_bias.copy_(model_ori.linear_g.bias)
    except:
        print("no gating_bias need copy")


def copy_extra_msa_para(block_fast, block_ori):
    # msa_stack
    # MSARowAttentionWithPairBias
    copy_layernorm(
        block_fast.msa_stack.MSARowAttentionWithPairBias.layernormM,
        block_ori.msa_att_row.layer_norm_m,
    )
    copy_layernorm(
        block_fast.msa_stack.MSARowAttentionWithPairBias.layernormZ,
        block_ori.msa_att_row.layer_norm_z,
    )
    copy_attention(
        block_fast.msa_stack.MSARowAttentionWithPairBias.attention,
        block_ori.msa_att_row.mha,
    )
    block_fast.msa_stack.MSARowAttentionWithPairBias.linear_b_weights.copy_(
        block_ori.msa_att_row.linear_z.weight)
    block_fast.msa_stack.MSARowAttentionWithPairBias.out_bias.copy_(
        block_ori.msa_att_row.mha.linear_o.bias)

    # MSAColumnAttention
    copy_layernorm(
        block_fast.msa_stack.MSAColumnAttention.layernormM,
        block_ori.msa_att_col.layer_norm_m,
    )
    copy_global_attention(
        block_fast.msa_stack.MSAColumnAttention.global_attention,
        block_ori.msa_att_col.global_attention,
    )

    # MSATransition
    copy_transition(block_fast.msa_stack.MSATransition, block_ori.core.msa_transition)

    # communication
    comm_model = (
        block_ori.core.outer_product_mean
        # if not block_ori.is_multimer else block_ori.outer_product_mean
    )
    copy_layernorm(block_fast.communication.layernormM, comm_model.layer_norm)
    copy_linear(block_fast.communication.linear_a, comm_model.linear_1)
    copy_linear(block_fast.communication.linear_b, comm_model.linear_2)
    copy_linear(block_fast.communication.o_linear, comm_model.linear_out)

    # pair_stack
    # TriangleMultiplicationOutgoing
    copy_triangle(block_fast.pair_stack.TriangleMultiplicationOutgoing, block_ori.core.tri_mul_out)
    # TriangleMultiplicationIncoming
    copy_triangle(block_fast.pair_stack.TriangleMultiplicationIncoming, block_ori.core.tri_mul_in)
    # TriangleAttentionStartingNode
    copy_triangle_att(
        block_fast.pair_stack.TriangleAttentionStartingNode,
        block_ori.core.tri_att_start,
    )
    copy_triangle_att(block_fast.pair_stack.TriangleAttentionEndingNode, block_ori.core.tri_att_end)
    copy_transition(block_fast.pair_stack.PairTransition, block_ori.core.pair_transition)


def copy_template_pair_stack_para(block_fast, block_ori):
    # TriangleMultiplicationOutgoing
    copy_triangle(block_fast.TriangleMultiplicationOutgoing, block_ori.tri_mul_out)
    # TriangleMultiplicationIncoming
    copy_triangle(block_fast.TriangleMultiplicationIncoming, block_ori.tri_mul_in)
    # TriangleAttentionStartingNode
    copy_triangle_att(block_fast.TriangleAttentionStartingNode, block_ori.tri_att_start)
    copy_triangle_att(block_fast.TriangleAttentionEndingNode, block_ori.tri_att_end)
    copy_transition(block_fast.PairTransition, block_ori.pair_transition)


def inject_evoformer(model):
    with torch.no_grad():
        fastfold_blocks = nn.ModuleList()
        for block_id, ori_block in enumerate(model.evoformer.blocks):

...
@@ -188,13 +231,74 @@ def inject_fastnn(model):
            fastfold_block = EvoformerBlock(c_m=c_m,
                                            c_z=c_z,
                                            first_block=(block_id == 0),
                                            last_block=(block_id == len(model.evoformer.blocks) - 1))
-            copy_para(fastfold_block, ori_block)
+            copy_evoformer_para(fastfold_block, ori_block)
            fastfold_blocks.append(fastfold_block)

        model.evoformer.blocks = fastfold_blocks

    return model


def inject_extraMsaBlock(model):
    with torch.no_grad():
        new_model_blocks = nn.ModuleList()
        for block_id, ori_block in enumerate(model.extra_msa_stack.blocks):
            c_m = ori_block.msa_att_row.c_in
            c_z = ori_block.msa_att_row.c_z
            new_model_block = ExtraMSABlock(
                c_m=c_m,
                c_z=c_z,
                first_block=(block_id == 0),
                last_block=(block_id == len(model.extra_msa_stack.blocks) - 1),
            )
            copy_extra_msa_para(new_model_block, ori_block)
            if ori_block.training == False:
                new_model_block.eval()
            new_model_blocks.append(new_model_block)

        model.extra_msa_stack.blocks = new_model_blocks


def inject_templatePairBlock(model):
    with torch.no_grad():
        target_module = model.template_pair_stack.blocks
        fastfold_blocks = nn.ModuleList()
        for block_id, ori_block in enumerate(target_module):
            c_t = ori_block.c_t
            c_hidden_tri_att = ori_block.c_hidden_tri_att
            c_hidden_tri_mul = ori_block.c_hidden_tri_mul
            no_heads = ori_block.no_heads
            pair_transition_n = ori_block.pair_transition_n
            dropout_rate = ori_block.dropout_rate
            inf = ori_block.inf
            fastfold_block = TemplatePairStackBlock(
                c_t=c_t,
                c_hidden_tri_att=c_hidden_tri_att,
                c_hidden_tri_mul=c_hidden_tri_mul,
                no_heads=no_heads,
                pair_transition_n=pair_transition_n,
                dropout_rate=dropout_rate,
                inf=inf,
                first_block=(block_id == 0),
                last_block=(block_id == len(target_module) - 1),
            )
            copy_template_pair_stack_para(fastfold_block, ori_block)
            if ori_block.training == False:
                fastfold_block.eval()
            fastfold_blocks.append(fastfold_block)

        model.template_pair_stack.blocks = fastfold_blocks


def inject_fastnn(model):
    inject_evoformer(model)
    inject_extraMsaBlock(model)
    inject_templatePairBlock(model)
    return model
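A side note on copy_kv_linear above, checked standalone rather than taken from the commit: for nn.Linear, y = x @ W.T, so concatenating the original key and value weight matrices along dim=0 builds one fused projection whose output splits back into (k, v) with .chunk(2, dim=-1), which is exactly how GlobalAttention.to_kv is consumed in ops.py:

import torch

d, c = 16, 8
x = torch.randn(3, d)
w_k, w_v = torch.randn(c, d), torch.randn(c, d)    # stand-ins for linear_k / linear_v weights
fused = torch.cat((w_k, w_v), dim=0)               # what copy_kv_linear writes into to_kv.weight
k, v = (x @ fused.T).chunk(2, dim=-1)
assert torch.allclose(k, x @ w_k.T) and torch.allclose(v, x @ w_v.T)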