OpenDAS / ColossalAI · Commits

Commit f7d8092c, authored Dec 29, 2022 by oahzxl

    align openfold

Parent: 5c4df01a
Showing 8 changed files with 36 additions and 769 deletions
autochunk_benchmark.py               +36   -5
evoformer_openfold/evoformer.py       +0   -59
evoformer_openfold/initializer.py     +0   -29
evoformer_openfold/kernel.py          +0   -19
evoformer_openfold/msa.py             +0   -95
evoformer_openfold/ops.py             +0   -176
evoformer_openfold/triangle.py        +0   -192
openfold/evoformer.py                 +0   -194
autochunk_benchmark.py  (view file @ f7d8092c)

@@ -9,20 +9,27 @@ from colossalai.fx.graph_module import ColoGraphModule
 from colossalai.fx.passes.meta_info_prop import MetaInfoProp
 from colossalai.fx.profiler import MetaTensor
 from evoformer.evoformer import evoformer_base
+from openfold.evoformer import EvoformerBlock
 
 
-def _benchmark_evoformer(model: torch.nn.Module, node, pair, title):
+def _benchmark_evoformer(model: torch.nn.Module, node, pair, title, chunk_size=None):
     torch.cuda.reset_peak_memory_stats()
     now_mem = torch.cuda.memory_allocated() / 1024**2
 
     loop = 16
     with torch.no_grad():
         for _ in range(loop // 4):
-            model(node, pair)
+            if chunk_size:
+                model(node, pair, chunk_size)
+            else:
+                model(node, pair)
         torch.cuda.synchronize()
         time1 = time.time()
         for _ in range(loop):
-            model(node, pair)
+            if chunk_size:
+                model(node, pair, chunk_size)
+            else:
+                model(node, pair)
         torch.cuda.synchronize()
         time2 = time.time()

@@ -64,6 +71,26 @@ def _build_autochunk(model, max_memory, node, pair):
     return gm
 
 
+def _build_openfold():
+    model = EvoformerBlock(
+        c_m=256,
+        c_z=128,
+        c_hidden_msa_att=32,
+        c_hidden_opm=32,
+        c_hidden_mul=128,
+        c_hidden_pair_att=32,
+        no_heads_msa=8,
+        no_heads_pair=4,
+        transition_n=4,
+        msa_dropout=0.15,
+        pair_dropout=0.15,
+        inf=1e4,
+        eps=1e-4,
+        is_multimer=False,
+    ).cuda()
+    return model
+
+
 def benchmark_evoformer():
     # init data and model
     msa_len = 300

@@ -74,10 +101,14 @@ def benchmark_evoformer():
     # build autochunk model
     max_memory = 3000    # MB
-    autochunk = _build_autochunk(model, max_memory, node, pair)
+    autochunk = _build_autochunk(evoformer_base().cuda(), max_memory, node, pair)
+
+    # build openfold
+    openfold = _build_openfold()
 
     # benchmark
-    _benchmark_evoformer(model, node, pair, "openfold")
+    _benchmark_evoformer(model, node, pair, "base")
+    _benchmark_evoformer(openfold, node, pair, "openfold", chunk_size=4)
     _benchmark_evoformer(autochunk, node, pair, "autochunk")
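The timing logic patched above follows the usual CUDA benchmarking pattern: reset peak-memory statistics, run a few warm-up passes, synchronize, then time a fixed number of forward passes. A minimal, self-contained sketch of that pattern with a toy model (illustrative only, not part of the commit):

import time

import torch

# toy stand-in for the evoformer models benchmarked above; requires a CUDA device
model = torch.nn.Linear(1024, 1024).cuda()
x = torch.randn(64, 1024, device="cuda")
loop = 16

torch.cuda.reset_peak_memory_stats()
base_mem = torch.cuda.memory_allocated() / 1024**2
with torch.no_grad():
    for _ in range(loop // 4):    # warm-up iterations
        model(x)
    torch.cuda.synchronize()
    time1 = time.time()
    for _ in range(loop):         # timed iterations
        model(x)
    torch.cuda.synchronize()
    time2 = time.time()
peak_mem = torch.cuda.max_memory_allocated() / 1024**2 - base_mem
print("%.3f ms/iter, peak extra memory %.2f MB" % ((time2 - time1) / loop * 1000, peak_mem))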
evoformer_openfold/evoformer.py  (deleted, 100644 → 0)

import torch
import torch.nn as nn

from .msa import MSAStack
from .ops import OutProductMean
from .triangle import PairStack


def print_memory(init_mem, text=None):
    now_mem = torch.cuda.memory_allocated() / 1024**2 - init_mem
    max_mem = torch.cuda.max_memory_allocated() / 1024**2 - init_mem
    print("%s now:%.2f max:%.2f" % ("" if text is None else text, now_mem, max_mem))
    torch.cuda.reset_peak_memory_stats()


class EvoformerBlock(nn.Module):

    def __init__(self, d_node, d_pair):
        super(EvoformerBlock, self).__init__()
        self.msa_stack = MSAStack(d_node, d_pair, p_drop=0.15)
        self.communication = OutProductMean(n_feat=d_node, n_feat_out=d_pair, n_feat_proj=32)
        self.pair_stack = PairStack(d_pair=d_pair)

    def forward(self, node, pair):
        node = self.msa_stack(node, pair)
        pair = pair + self.communication(node)
        pair = self.pair_stack(pair)
        return node, pair


class Evoformer(nn.Module):

    def __init__(self, d_node, d_pair):
        super(Evoformer, self).__init__()
        self.blocks = nn.ModuleList()
        for _ in range(1):
            self.blocks.append(EvoformerBlock(d_node, d_pair))

    def forward(self, node, pair):
        for b in self.blocks:
            node, pair = b(node, pair)
        return node, pair


def evoformer_tiny():
    return Evoformer(d_node=64, d_pair=32)


def evoformer_base():
    return Evoformer(d_node=256, d_pair=128)


def evoformer_large():
    return Evoformer(d_node=512, d_pair=256)


__all__ = ['Evoformer', 'evoformer_base', 'evoformer_large']
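A usage sketch for this deleted module (an assumption, not shown in the commit): the einsum in OutProductMean ('bsid,bsje->bijde') and the residual pair update imply node tensors of shape [batch, n_seq, n_res, d_node] and pair tensors of shape [batch, n_res, n_res, d_pair].

import torch

from evoformer_openfold.evoformer import evoformer_base    # importable at the parent commit

model = evoformer_base().cuda()
node = torch.randn(1, 16, 64, 256, device="cuda")    # [batch, n_seq, n_res, d_node=256]
pair = torch.randn(1, 64, 64, 128, device="cuda")    # [batch, n_res, n_res, d_pair=128]
with torch.no_grad():
    node, pair = model(node, pair)
print(node.shape, pair.shape)    # both shapes are preserved by the block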
evoformer_openfold/initializer.py  (deleted, 100755 → 0)

import math

import numpy as np
import torch.nn as nn


def glorot_uniform_af(x, gain=1.0):
    """
    Initialize tensors the same way as xavier_initializer in PyTorch, but with different
    dimension conventions:
        In PyTorch: [feature_out, feature_in, n_head, ...]
        In Jax:     [..., n_head, feature_in, feature_out]
    The original AlphaFold 2 code also uses the Jax-style initializer for tensors laid out as
    [feature_in, n_head, feature_out]; this function keeps that behaviour and initializes
    [feature_in, n_head, ..., feature_out] tensors.
    """
    fan_in, fan_out = x.shape[-2:]
    if len(x.shape) > 2:
        receptive_field_size = np.prod(x.shape[:-2])
        fan_in *= receptive_field_size
        fan_out *= receptive_field_size
    std = gain * math.sqrt(2.0 / float(fan_in + fan_out))
    dev = math.sqrt(3.0) * std    # calculate uniform bounds from standard deviation

    nn.init.uniform_(x, -dev, dev)

    return x
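A quick worked check of the bound this produces (a sketch, not part of the deleted file): for a 2-D weight of shape [256, 128] with gain=1.0, std = sqrt(2 / (256 + 128)) ≈ 0.0722 and the uniform bound is sqrt(3) · std ≈ 0.125.

import math

import torch

w = glorot_uniform_af(torch.empty(256, 128))            # fan_in=256, fan_out=128
bound = math.sqrt(3.0) * math.sqrt(2.0 / (256 + 128))   # ≈ 0.125
max_w = w.abs().max().item()
assert max_w <= bound + 1e-6
print("bound %.3f, observed max |w| %.3f" % (bound, max_w))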
evoformer_openfold/kernel.py  (deleted, 100644 → 0)

import torch
import torch.nn.functional as F


def bias_sigmod_ele(y, bias, z):
    return torch.sigmoid(y + bias) * z


def bias_dropout_add(x: torch.Tensor, bias: torch.Tensor, dropmask: torch.Tensor,
                     residual: torch.Tensor, prob: float) -> torch.Tensor:
    out = (x + bias) * F.dropout(dropmask, p=prob, training=False)
    out = residual + out
    return out


def bias_ele_dropout_residual(ab: torch.Tensor, b: torch.Tensor, g: torch.Tensor,
                              dropout_mask: torch.Tensor, Z_raw: torch.Tensor,
                              prob: float) -> torch.Tensor:
    return Z_raw + F.dropout(dropout_mask, p=prob, training=True) * (g * (ab + b))
\ No newline at end of file
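Note how the two fused helpers treat dropout differently: bias_dropout_add calls F.dropout with training=False, so the all-ones mask its callers pass in comes through unchanged and the helper reduces to a plain biased residual add, whereas bias_ele_dropout_residual uses training=True and therefore zeroes and rescales the mask. A small check of the first case (a sketch, assuming bias_dropout_add from the file above is in scope):

import torch

x = torch.randn(2, 4, 8, 16)
bias = torch.randn(16)
residual = torch.randn(2, 4, 8, 16)
mask = torch.ones_like(x[:, 0:1, :, :])    # the callers in msa.py / triangle.py build masks like this

out = bias_dropout_add(x, bias, mask, residual, prob=0.15)
# with training=False the mask is untouched, so this is just residual + x + bias
assert torch.allclose(out, residual + x + bias)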
evoformer_openfold/msa.py  (deleted, 100644 → 0)

import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from torch.nn import LayerNorm

from .kernel import bias_dropout_add
from .ops import SelfAttention, Transition


class MSARowAttentionWithPairBias(nn.Module):

    def __init__(self, d_node, d_pair, c=32, n_head=8, p_drop=0.15):
        super(MSARowAttentionWithPairBias, self).__init__()
        self.d_node = d_node
        self.d_pair = d_pair
        self.c = c
        self.n_head = n_head
        self.p_drop = p_drop

        self.layernormM = LayerNorm(d_node)
        self.layernormZ = LayerNorm(d_pair)

        _init_weights = torch.nn.init.normal_(torch.zeros([n_head, d_pair]), std=1.0 / math.sqrt(d_pair))
        self.linear_b_weights = nn.parameter.Parameter(data=_init_weights, requires_grad=True)

        self.attention = SelfAttention(qkv_dim=d_node,
                                       c=c,
                                       n_head=n_head,
                                       out_dim=d_node,
                                       gating=True,
                                       last_bias_fuse=True)

        self.out_bias = nn.parameter.Parameter(data=torch.zeros((d_node,)), requires_grad=True)

    def forward(self, M_raw, Z):
        ## Input projections
        M = self.layernormM(M_raw)
        Z = self.layernormZ(Z)
        b = F.linear(Z, self.linear_b_weights)
        b = b.permute(0, 3, 1, 2)
        # b = rearrange(b, 'b q k h -> b h q k')

        M = self.attention(M, b)
        dropout_mask = torch.ones_like(M[:, 0:1, :, :]).to(M.device).to(M.dtype)

        return bias_dropout_add(M, self.out_bias, dropout_mask, M_raw, prob=self.p_drop)


class MSAColumnAttention(nn.Module):

    def __init__(self, d_node, c=32, n_head=8):
        super(MSAColumnAttention, self).__init__()
        self.d_node = d_node
        self.c = c
        self.n_head = n_head

        self.layernormM = LayerNorm(d_node)
        self.attention = SelfAttention(qkv_dim=d_node, c=c, n_head=n_head, out_dim=d_node, gating=True)

    def forward(self, M_raw):
        M = M_raw.transpose(-2, -3)
        M = self.layernormM(M)
        M = self.attention(M)
        M = M.transpose(-2, -3)

        return M_raw + M


class MSAStack(nn.Module):

    def __init__(self, d_node, d_pair, p_drop=0.15):
        super(MSAStack, self).__init__()

        self.MSARowAttentionWithPairBias = MSARowAttentionWithPairBias(d_node=d_node, d_pair=d_pair, p_drop=p_drop)

        self.MSAColumnAttention = MSAColumnAttention(d_node=d_node)
        self.MSATransition = Transition(d=d_node)

    def forward(self, node, pair):
        node = self.MSARowAttentionWithPairBias(node, pair)
        node = self.MSAColumnAttention(node)
        node = self.MSATransition(node)

        return node
evoformer_openfold/ops.py  (deleted, 100755 → 0)

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from torch.nn import LayerNorm

from .initializer import glorot_uniform_af
from .kernel import bias_sigmod_ele


class DropoutRowwise(nn.Module):

    def __init__(self, p):
        super(DropoutRowwise, self).__init__()
        self.p = p
        self.dropout = nn.Dropout(p=p)

    def forward(self, x):
        dropout_mask = torch.ones_like(x[:, 0:1, :, :])
        dropout_mask = self.dropout(dropout_mask)
        return dropout_mask * x


class DropoutColumnwise(nn.Module):

    def __init__(self, p):
        super(DropoutColumnwise, self).__init__()
        self.p = p
        self.dropout = nn.Dropout(p=p)

    def forward(self, x):
        dropout_mask = torch.ones_like(x[:, :, 0:1, :])
        dropout_mask = self.dropout(dropout_mask)
        return dropout_mask * x


class Transition(nn.Module):

    def __init__(self, d, n=4):
        super(Transition, self).__init__()
        self.norm = LayerNorm(d)
        self.linear1 = Linear(d, n * d, initializer='relu')
        self.linear2 = Linear(n * d, d, initializer='zeros')

    def forward(self, src):
        x = self.norm(src)
        x = self.linear2(F.relu(self.linear1(x)))
        return src + x


class OutProductMean(nn.Module):

    def __init__(self, n_feat=64, n_feat_out=128, n_feat_proj=32):
        super(OutProductMean, self).__init__()

        self.layernormM = LayerNorm(n_feat)
        self.linear_a = Linear(n_feat, n_feat_proj)
        self.linear_b = Linear(n_feat, n_feat_proj)

        self.o_linear = Linear(n_feat_proj * n_feat_proj, n_feat_out, initializer='zero', use_bias=True)

    def forward(self, M):
        M = self.layernormM(M)
        left_act = self.linear_a(M)
        right_act = self.linear_b(M)

        O = torch.einsum('bsid,bsje->bijde', left_act, right_act).contiguous()
        # O = rearrange(O, 'b i j d e -> b i j (d e)')
        O = O.reshape(O.shape[0], O.shape[1], O.shape[2], -1)
        Z = self.o_linear(O)

        return Z


class Linear(nn.Linear):
    """
    A Linear layer with built-in nonstandard initializations. Called just
    like torch.nn.Linear.
    Implements the initializers in 1.11.4, plus some additional ones found
    in the code.
    """

    def __init__(
        self,
        feature_in: int,
        feature_out: int,
        initializer: str = 'linear',
        use_bias: bool = True,
        bias_init: float = 0.,
    ):
        super(Linear, self).__init__(feature_in, feature_out, bias=use_bias)

        self.use_bias = use_bias
        if initializer == 'linear':
            glorot_uniform_af(self.weight, gain=1.0)
        elif initializer == 'relu':
            glorot_uniform_af(self.weight, gain=2.0)
        elif initializer == 'zeros':
            nn.init.zeros_(self.weight)
        if self.use_bias:
            with torch.no_grad():
                self.bias.fill_(bias_init)


class SelfAttention(nn.Module):
    """
    Multi-Head SelfAttention dealing with [batch_size1, batch_size2, len, dim] tensors
    """

    def __init__(self, qkv_dim, c, n_head, out_dim, gating=True, last_bias_fuse=False):
        super(SelfAttention, self).__init__()
        self.qkv_dim = qkv_dim
        self.c = c
        self.n_head = n_head
        self.out_dim = out_dim
        self.gating = gating
        self.last_bias_fuse = last_bias_fuse

        self.scaling = self.c**(-0.5)

        # self.to_qkv = Linear(qkv_dim, 3 * n_head * c, initializer='linear')
        self.to_q = Linear(qkv_dim, n_head * c, initializer='linear', use_bias=False)
        self.to_k = Linear(qkv_dim, n_head * c, initializer='linear', use_bias=False)
        self.to_v = Linear(qkv_dim, n_head * c, initializer='linear', use_bias=False)

        if gating:
            self.gating_bias = nn.parameter.Parameter(data=torch.ones((n_head * c,)))
            self.gating_linear = Linear(qkv_dim, n_head * c, initializer='zero', use_bias=False)

        self.o_linear = Linear(n_head * c, out_dim, initializer='zero', use_bias=(not last_bias_fuse))

    def forward(self, in_data, nonbatched_bias=None):
        """
        :param in_data: [batch_size1, batch_size2, len_qkv, qkv_dim]
        :param bias: None or [batch_size1, batch_size2, n_head, len_q, len_kv]
        :param nonbatched_bias: None or [batch_size1, n_head, len_q, len_kv]
        """

        # qkv = self.to_qkv(in_data).chunk(3, dim=-1)
        # q, k, v = map(lambda t: rearrange(t, 'b1 b2 n (h d) -> b1 b2 h n d', h=self.n_head), qkv)

        q = self.to_q(in_data)
        k = self.to_k(in_data)
        v = self.to_v(in_data)

        # q, k, v = map(lambda t: rearrange(t, 'b1 b2 n (h d) -> b1 b2 h n d', h=self.n_head),
        #               [q, k, v])
        q, k, v = map(lambda t: t.view(t.shape[0], t.shape[1], t.shape[2], self.n_head, -1).permute(0, 1, 3, 2, 4),
                      [q, k, v])

        q = q * self.scaling

        logits = torch.matmul(q, k.transpose(-1, -2))

        if nonbatched_bias is not None:
            logits += nonbatched_bias.unsqueeze(1)
        weights = torch.softmax(logits, dim=-1)
        # weights = softmax(logits)

        weighted_avg = torch.matmul(weights, v)
        # weighted_avg = rearrange(weighted_avg, 'b1 b2 h n d -> b1 b2 n (h d)')
        weighted_avg = weighted_avg.permute(0, 1, 3, 2, 4)
        weighted_avg = weighted_avg.reshape(weighted_avg.shape[0], weighted_avg.shape[1], weighted_avg.shape[2], -1)

        if self.gating:
            gate_values = self.gating_linear(in_data)
            weighted_avg = bias_sigmod_ele(gate_values, self.gating_bias, weighted_avg)

        output = self.o_linear(weighted_avg)

        return output
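A shape sketch for SelfAttention (an illustration, assuming the evoformer_openfold package is importable): inputs are [batch_size1, batch_size2, len, qkv_dim] and the optional nonbatched_bias is [batch_size1, n_head, len_q, len_kv], exactly as the docstring states.

import torch

attn = SelfAttention(qkv_dim=64, c=32, n_head=8, out_dim=64, gating=True)
x = torch.randn(2, 5, 10, 64)        # [batch_size1, batch_size2, len, qkv_dim]
bias = torch.randn(2, 8, 10, 10)     # [batch_size1, n_head, len_q, len_kv]
out = attn(x, nonbatched_bias=bias)
print(out.shape)                     # torch.Size([2, 5, 10, 64])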
evoformer_openfold/triangle.py  (deleted, 100644 → 0)

import math

import torch
import torch.nn as nn
from torch.nn import LayerNorm

from .kernel import bias_dropout_add, bias_ele_dropout_residual
from .ops import Linear, SelfAttention, Transition


def permute_final_dims(tensor, inds):
    zero_index = -1 * len(inds)
    first_inds = list(range(len(tensor.shape[:zero_index])))
    return tensor.permute(first_inds + [zero_index + i for i in inds])


class TriangleMultiplicationOutgoing(nn.Module):

    def __init__(self, d_pair, p_drop, c=128):
        super(TriangleMultiplicationOutgoing, self).__init__()
        self.d_pair = d_pair
        self.c = c

        self.layernorm1 = LayerNorm(d_pair)
        self.left_projection = Linear(d_pair, c)
        self.right_projection = Linear(d_pair, c)
        self.left_gate = Linear(d_pair, c, initializer='zeros', bias_init=1.)
        self.right_gate = Linear(d_pair, c, initializer='zeros', bias_init=1.)

        self.output_gate = Linear(d_pair, d_pair, initializer='zeros', bias_init=1.)
        self.layernorm2 = LayerNorm(c)
        self.output_projection = Linear(d_pair, d_pair, initializer='zeros', use_bias=False)
        self.output_bias = nn.parameter.Parameter(data=torch.zeros((d_pair,)), requires_grad=True)

        self.p_drop = p_drop

    def forward(self, Z_raw):
        Z = self.layernorm1(Z_raw)
        left_proj_act = self.left_projection(Z)
        right_proj_act = self.right_projection(Z)

        left_proj_act = left_proj_act * torch.sigmoid(self.left_gate(Z))
        right_proj_act = right_proj_act * torch.sigmoid(self.right_gate(Z))

        g = torch.sigmoid(self.output_gate(Z))
        # p = torch.matmul(
        #     permute_final_dims(left_proj_act, (2, 0, 1)),
        #     permute_final_dims(right_proj_act, (2, 1, 0)),
        # )
        # ab = permute_final_dims(p, (1, 2, 0))

        ab = torch.einsum('bikd,bjkd->bijd', left_proj_act, right_proj_act)
        ab = self.output_projection(self.layernorm2(ab))
        dropout_mask = torch.ones_like(Z[:, 0:1, :, :]).to(Z.device).to(Z.dtype)
        return bias_ele_dropout_residual(ab, self.output_bias, g, dropout_mask, Z_raw, prob=self.p_drop)


class TriangleMultiplicationIncoming(nn.Module):

    def __init__(self, d_pair, p_drop, c=128):
        super(TriangleMultiplicationIncoming, self).__init__()
        self.d_pair = d_pair
        self.c = c

        self.layernorm1 = LayerNorm(d_pair)
        self.left_projection = Linear(d_pair, c)
        self.right_projection = Linear(d_pair, c)
        self.left_gate = Linear(d_pair, c, initializer='zeros', bias_init=1.)
        self.right_gate = Linear(d_pair, c, initializer='zeros', bias_init=1.)

        self.output_gate = Linear(d_pair, d_pair, initializer='zeros', bias_init=1.)
        self.layernorm2 = LayerNorm(c)
        self.output_projection = Linear(d_pair, d_pair, initializer='zeros', use_bias=False)
        self.output_bias = nn.parameter.Parameter(data=torch.zeros((d_pair,)), requires_grad=True)

        self.p_drop = p_drop

    def forward(self, Z_raw):
        Z = self.layernorm1(Z_raw)
        left_proj_act = self.left_projection(Z)
        right_proj_act = self.right_projection(Z)

        left_proj_act = left_proj_act * torch.sigmoid(self.left_gate(Z))
        right_proj_act = right_proj_act * torch.sigmoid(self.right_gate(Z))

        g = torch.sigmoid(self.output_gate(Z))
        # p = torch.matmul(
        #     permute_final_dims(left_proj_act, (2, 1, 0)),
        #     permute_final_dims(right_proj_act, (2, 0, 1)),
        # )
        # ab = permute_final_dims(p, (1, 2, 0))

        ab = torch.einsum('bkid,bkjd->bijd', left_proj_act, right_proj_act)
        ab = self.output_projection(self.layernorm2(ab))
        dropout_mask = torch.ones_like(Z[:, 0:1, :, :]).to(Z.device).to(Z.dtype)
        return bias_ele_dropout_residual(ab, self.output_bias, g, dropout_mask, Z_raw, prob=self.p_drop)


class TriangleAttentionStartingNode(nn.Module):

    def __init__(self, d_pair, p_drop, c=32, n_head=4):
        super(TriangleAttentionStartingNode, self).__init__()
        self.d_pair = d_pair
        self.c = c
        self.n_head = n_head
        self.p_drop = p_drop

        self.layernorm1 = LayerNorm(d_pair)
        _init_weights = torch.nn.init.normal_(torch.zeros([d_pair, n_head]), std=1.0 / math.sqrt(d_pair))
        self.linear_b_weights = nn.parameter.Parameter(data=_init_weights)

        self.attention = SelfAttention(qkv_dim=d_pair,
                                       c=c,
                                       n_head=n_head,
                                       out_dim=d_pair,
                                       gating=True,
                                       last_bias_fuse=True)

        self.out_bias = nn.parameter.Parameter(data=torch.zeros((d_pair,)), requires_grad=True)

    def forward(self, Z_raw):
        Z = self.layernorm1(Z_raw)
        b = torch.einsum('bqkc,ch->bhqk', Z, self.linear_b_weights)

        Z = self.attention(Z, b)

        dropout_mask = torch.ones_like(Z[:, 0:1, :, :]).to(Z.device).to(Z.dtype)
        return bias_dropout_add(Z, self.out_bias, dropout_mask, Z_raw, prob=self.p_drop)


class TriangleAttentionEndingNode(nn.Module):

    def __init__(self, d_pair, p_drop, c=32, n_head=4):
        super(TriangleAttentionEndingNode, self).__init__()
        self.d_pair = d_pair
        self.c = c
        self.n_head = n_head
        self.p_drop = p_drop

        self.layernorm1 = LayerNorm(d_pair)
        _init_weights = torch.nn.init.normal_(torch.zeros([d_pair, n_head]), std=1.0 / math.sqrt(d_pair))
        self.linear_b_weights = nn.parameter.Parameter(data=_init_weights)

        self.attention = SelfAttention(qkv_dim=d_pair,
                                       c=c,
                                       n_head=n_head,
                                       out_dim=d_pair,
                                       gating=True,
                                       last_bias_fuse=True)

        self.out_bias = nn.parameter.Parameter(data=torch.zeros((d_pair,)), requires_grad=True)

    def forward(self, Z_raw):
        Z = Z_raw.transpose(-2, -3)
        Z = self.layernorm1(Z)
        b = torch.einsum('bqkc,ch->bhqk', Z, self.linear_b_weights)

        Z = self.attention(Z, b)
        Z = Z.transpose(-2, -3)

        dropout_mask = torch.ones_like(Z[:, :, 0:1, :]).to(Z.device).to(Z.dtype)
        return bias_dropout_add(Z, self.out_bias, dropout_mask, Z_raw, prob=self.p_drop)


class PairStack(nn.Module):

    def __init__(self, d_pair, p_drop=0.25):
        super(PairStack, self).__init__()

        self.TriangleMultiplicationOutgoing = TriangleMultiplicationOutgoing(d_pair, p_drop=p_drop)
        self.TriangleMultiplicationIncoming = TriangleMultiplicationIncoming(d_pair, p_drop=p_drop)
        self.TriangleAttentionStartingNode = TriangleAttentionStartingNode(d_pair, p_drop=p_drop)
        self.TriangleAttentionEndingNode = TriangleAttentionEndingNode(d_pair, p_drop=p_drop)
        self.PairTransition = Transition(d=d_pair)

    def forward(self, pair):
        pair = self.TriangleMultiplicationOutgoing(pair)
        pair = self.TriangleMultiplicationIncoming(pair)
        pair = self.TriangleAttentionStartingNode(pair)
        pair = self.TriangleAttentionEndingNode(pair)
        pair = self.PairTransition(pair)
        return pair
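A usage sketch for the pair stack (illustrative, assuming the package is importable): every sub-layer is residual, so the [batch, n_res, n_res, d_pair] shape of the pair representation is preserved end to end.

import torch

stack = PairStack(d_pair=128)
pair = torch.randn(1, 64, 64, 128)    # [batch, n_res, n_res, d_pair]
out = stack(pair)
print(out.shape)                      # torch.Size([1, 64, 64, 128])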
openfold/evoformer.py  (view file @ f7d8092c)

@@ -284,104 +284,6 @@ class EvoformerBlock(nn.Module):
         return m, z
 
 
-class ExtraMSABlock(nn.Module):
-    """
-        Almost identical to the standard EvoformerBlock, except in that the
-        ExtraMSABlock uses GlobalAttention for MSA column attention and
-        requires more fine-grained control over checkpointing. Separated from
-        its twin to preserve the TorchScript-ability of the latter.
-    """
-
-    def __init__(self,
-        c_m: int,
-        c_z: int,
-        c_hidden_msa_att: int,
-        c_hidden_opm: int,
-        c_hidden_mul: int,
-        c_hidden_pair_att: int,
-        no_heads_msa: int,
-        no_heads_pair: int,
-        transition_n: int,
-        msa_dropout: float,
-        pair_dropout: float,
-        inf: float,
-        eps: float,
-        ckpt: bool,
-        is_multimer: bool,
-    ):
-        super(ExtraMSABlock, self).__init__()
-
-        self.ckpt = ckpt
-
-        self.msa_att_row = MSARowAttentionWithPairBias(
-            c_m=c_m,
-            c_z=c_z,
-            c_hidden=c_hidden_msa_att,
-            no_heads=no_heads_msa,
-            inf=inf,
-        )
-
-        self.msa_att_col = MSAColumnGlobalAttention(
-            c_in=c_m,
-            c_hidden=c_hidden_msa_att,
-            no_heads=no_heads_msa,
-            inf=inf,
-            eps=eps,
-        )
-
-        self.msa_dropout_layer = DropoutRowwise(msa_dropout)
-
-        self.core = EvoformerBlockCore(
-            c_m=c_m,
-            c_z=c_z,
-            c_hidden_opm=c_hidden_opm,
-            c_hidden_mul=c_hidden_mul,
-            c_hidden_pair_att=c_hidden_pair_att,
-            no_heads_msa=no_heads_msa,
-            no_heads_pair=no_heads_pair,
-            transition_n=transition_n,
-            pair_dropout=pair_dropout,
-            inf=inf,
-            eps=eps,
-        )
-
-        self.is_multimer = is_multimer
-
-    def forward(self,
-        m: torch.Tensor,
-        z: torch.Tensor,
-        msa_mask: torch.Tensor,
-        pair_mask: torch.Tensor,
-        chunk_size: Optional[int] = None,
-        _chunk_logits: Optional[int] = 1024,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        m = m + self.msa_dropout_layer(
-            self.msa_att_row(
-                m.clone(),
-                z=z.clone(),
-                mask=msa_mask,
-                chunk_size=chunk_size,
-                _chunk_logits=_chunk_logits if torch.is_grad_enabled() else None,
-                _checkpoint_chunks=self.ckpt if torch.is_grad_enabled() else False,
-            )
-        )
-
-        def fn(m, z):
-            m = m + self.msa_att_col(m, mask=msa_mask, chunk_size=chunk_size)
-
-            m, z = self.core(
-                m, z, msa_mask=msa_mask, pair_mask=pair_mask, chunk_size=chunk_size
-            )
-
-            return m, z
-
-        if (torch.is_grad_enabled() and self.ckpt):
-            checkpoint_fn = get_checkpoint_fn()
-            m, z = checkpoint_fn(fn, m, z)
-        else:
-            m, z = fn(m, z)
-
-        return m, z
-
-
 class EvoformerStack(nn.Module):
     """
     Main Evoformer trunk.

@@ -527,99 +429,3 @@ class EvoformerStack(nn.Module):
         s = self.linear(m[..., 0, :, :])
 
         return m, z, s
-
-
-class ExtraMSAStack(nn.Module):
-    """
-    Implements Algorithm 18.
-    """
-
-    def __init__(self,
-        c_m: int,
-        c_z: int,
-        c_hidden_msa_att: int,
-        c_hidden_opm: int,
-        c_hidden_mul: int,
-        c_hidden_pair_att: int,
-        no_heads_msa: int,
-        no_heads_pair: int,
-        no_blocks: int,
-        transition_n: int,
-        msa_dropout: float,
-        pair_dropout: float,
-        inf: float,
-        eps: float,
-        ckpt: bool,
-        clear_cache_between_blocks: bool = False,
-        is_multimer: bool = False,
-        **kwargs,
-    ):
-        super(ExtraMSAStack, self).__init__()
-
-        self.clear_cache_between_blocks = clear_cache_between_blocks
-        self.blocks = nn.ModuleList()
-        for _ in range(no_blocks):
-            block = ExtraMSABlock(
-                c_m=c_m,
-                c_z=c_z,
-                c_hidden_msa_att=c_hidden_msa_att,
-                c_hidden_opm=c_hidden_opm,
-                c_hidden_mul=c_hidden_mul,
-                c_hidden_pair_att=c_hidden_pair_att,
-                no_heads_msa=no_heads_msa,
-                no_heads_pair=no_heads_pair,
-                transition_n=transition_n,
-                msa_dropout=msa_dropout,
-                pair_dropout=pair_dropout,
-                inf=inf,
-                eps=eps,
-                ckpt=ckpt,
-                is_multimer=is_multimer,
-            )
-            self.blocks.append(block)
-
-    def forward(self,
-        m: torch.Tensor,
-        z: torch.Tensor,
-        chunk_size: int,
-        msa_mask: Optional[torch.Tensor] = None,
-        pair_mask: Optional[torch.Tensor] = None,
-        _mask_trans: bool = True,
-    ) -> torch.Tensor:
-        """
-        Args:
-            m:
-                [*, N_extra, N_res, C_m] extra MSA embedding
-            z:
-                [*, N_res, N_res, C_z] pair embedding
-            msa_mask:
-                Optional [*, N_extra, N_res] MSA mask
-            pair_mask:
-                Optional [*, N_res, N_res] pair mask
-        Returns:
-            [*, N_res, N_res, C_z] pair update
-        """
-        #checkpoint_fn = get_checkpoint_fn()
-        #blocks = [
-        #    partial(b, msa_mask=msa_mask, pair_mask=pair_mask, chunk_size=chunk_size, _chunk_logits=None) for b in self.blocks
-        #]

-        #def dodo(b, *args):
-        #    torch.cuda.empty_cache()
-        #    return b(*args)
-
-        #blocks = [partial(dodo, b) for b in blocks]
-
-        #for b in blocks:
-        #    if(torch.is_grad_enabled()):
-        #        m, z = checkpoint_fn(b, *(m, z))
-        #    else:
-        #        m, z = b(m, z)
-
-        for b in self.blocks:
-            m, z = b(m, z, msa_mask, pair_mask, chunk_size=chunk_size)
-            if (self.clear_cache_between_blocks):
-                torch.cuda.empty_cache()
-
-        return z
\ No newline at end of file