Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
69af9310
Commit
69af9310
authored
Dec 29, 2022
by
oahzxl
Browse files
add evoformer openfold init
parent
cb2dd1a1
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
570 additions
and
0 deletions
+570
-0
evoformer_openfold/evoformer.py
evoformer_openfold/evoformer.py
+59
-0
evoformer_openfold/initializer.py
evoformer_openfold/initializer.py
+29
-0
evoformer_openfold/kernel.py
evoformer_openfold/kernel.py
+19
-0
evoformer_openfold/msa.py
evoformer_openfold/msa.py
+95
-0
evoformer_openfold/ops.py
evoformer_openfold/ops.py
+176
-0
evoformer_openfold/triangle.py
evoformer_openfold/triangle.py
+192
-0
No files found.
evoformer_openfold/evoformer.py
0 → 100644
View file @
69af9310
import
torch
import
torch.nn
as
nn
from
.msa
import
MSAStack
from
.ops
import
OutProductMean
from
.triangle
import
PairStack
def print_memory(init_mem, text=None):
    """
    Print the current and peak CUDA memory relative to ``init_mem``, then reset the peak counter.

    Both readings are converted from bytes by dividing by 1024**2 before ``init_mem`` is
    subtracted, so ``init_mem`` is expected in the same unit (presumably MiB -- confirm with callers).

    :param init_mem: baseline value subtracted from both readings
    :param text: optional label printed before the numbers; an empty string is used when None
    """
    bytes_per_unit = 1024**2
    current = torch.cuda.memory_allocated() / bytes_per_unit - init_mem
    peak = torch.cuda.max_memory_allocated() / bytes_per_unit - init_mem
    label = text if text is not None else ""
    print("%s now:%.2f max:%.2f" % (label, current, peak))
    # Start a fresh peak measurement window for the next call.
    torch.cuda.reset_peak_memory_stats()
class EvoformerBlock(nn.Module):
    """
    One Evoformer layer: an MSA stack, an outer-product communication step that updates
    the pair representation from the node representation, and a pair stack.
    """

    def __init__(self, d_node, d_pair):
        """
        :param d_node: feature width passed to the MSA stack and to the communication input
        :param d_pair: feature width passed to the pair stack and to the communication output
        """
        super().__init__()
        self.msa_stack = MSAStack(d_node, d_pair, p_drop=0.15)
        self.communication = OutProductMean(n_feat=d_node, n_feat_out=d_pair, n_feat_proj=32)
        self.pair_stack = PairStack(d_pair=d_pair)

    def forward(self, node, pair):
        """Update ``node`` from (node, pair), add the communication term to ``pair``, refine ``pair``."""
        updated_node = self.msa_stack(node, pair)
        # The pair update is residual: the communication output is added to the incoming pair.
        updated_pair = pair + self.communication(updated_node)
        updated_pair = self.pair_stack(updated_pair)
        return updated_node, updated_pair
class Evoformer(nn.Module):
    """Sequential stack of ``EvoformerBlock`` modules (this implementation builds exactly one)."""

    def __init__(self, d_node, d_pair):
        """
        :param d_node: node feature width forwarded to every block
        :param d_pair: pair feature width forwarded to every block
        """
        super().__init__()
        # The stack depth is fixed at a single block.
        self.blocks = nn.ModuleList([EvoformerBlock(d_node, d_pair)])

    def forward(self, node, pair):
        """Thread ``(node, pair)`` through each block in order and return the final pair of tensors."""
        for block in self.blocks:
            node, pair = block(node, pair)
        return node, pair
def evoformer_tiny():
    """Build an ``Evoformer`` with ``d_node=64`` and ``d_pair=32``."""
    model = Evoformer(d_node=64, d_pair=32)
    return model
def evoformer_base():
    """Build an ``Evoformer`` with ``d_node=256`` and ``d_pair=128``."""
    model = Evoformer(d_node=256, d_pair=128)
    return model
def evoformer_large():
    """Build an ``Evoformer`` with ``d_node=512`` and ``d_pair=256``."""
    model = Evoformer(d_node=512, d_pair=256)
    return model
# Names exported by ``from <module> import *``: the model class and its three size factories.
# ``evoformer_tiny`` was previously defined but left out of this list.
__all__ = ['Evoformer', 'evoformer_tiny', 'evoformer_base', 'evoformer_large']
evoformer_openfold/initializer.py
0 → 100755
View file @
69af9310
import
math
import
numpy
as
np
import
torch.nn
as
nn
def glorot_uniform_af(x, gain=1.0):
    """
    Fill ``x`` in place with Glorot (Xavier) uniform values and return it.

    The fans are read from the LAST two dimensions of ``x`` and, when ``x`` has more than
    two dimensions, both are multiplied by the product of all leading dimensions. This
    differs from PyTorch's xavier initializer in which dimensions it reads:

        PyTorch layout:  [feature_out, feature_in, n_head ...]
        Jax layout:      [... n_head, feature_in, feature_out]

    The original AlphaFold2 code applies the Jax-style initializer to tensors laid out as
    [feature_in, n_head, feature_out]; this function keeps that behaviour for
    [feature_in, n_head, ..., feature_out] tensors.

    :param x: tensor with at least two dimensions; modified in place
    :param gain: multiplier applied to the standard deviation
    :return: the same tensor ``x``
    """
    *leading_dims, fan_in, fan_out = x.shape

    if leading_dims:
        receptive_field_size = np.prod(leading_dims)
        fan_in *= receptive_field_size
        fan_out *= receptive_field_size

    std = gain * math.sqrt(2.0 / float(fan_in + fan_out))
    # A uniform distribution on [-a, a] has standard deviation a / sqrt(3).
    bound = math.sqrt(3.0) * std

    nn.init.uniform_(x, -bound, bound)
    return x
evoformer_openfold/kernel.py
0 → 100644
View file @
69af9310
import
torch
import
torch.nn.functional
as
F
def bias_sigmod_ele(y, bias, z):
    """
    Gate ``z`` element-wise with ``sigmoid(y + bias)``.

    :param y: pre-activation gate values
    :param bias: value added to ``y`` before the sigmoid
    :param z: tensor that is multiplied by the gate
    :return: ``sigmoid(y + bias) * z``
    """
    gate = torch.sigmoid(y + bias)
    return gate * z
def bias_dropout_add(x: torch.Tensor,
                     bias: torch.Tensor,
                     dropmask: torch.Tensor,
                     residual: torch.Tensor,
                     prob: float,
                     training: bool = False) -> torch.Tensor:
    """
    Compute ``residual + (x + bias) * dropout(dropmask)``.

    The dropout is applied to ``dropmask`` rather than to the activations, so a mask that is
    broadcast against ``x + bias`` drops whole slices together.

    :param x: activations to which ``bias`` is added
    :param bias: bias added to ``x``
    :param dropmask: mask tensor passed through dropout and multiplied into ``x + bias``
    :param residual: tensor added to the masked result
    :param prob: dropout probability
    :param training: whether dropout is active. Defaults to False, which reproduces the
        previous hard-coded behaviour (the mask passes through unchanged and ``prob`` has no
        effect); pass the owning module's ``self.training`` to enable dropout while training.
    :return: ``residual + (x + bias) * mask``
    """
    mask = F.dropout(dropmask, p=prob, training=training)
    return residual + (x + bias) * mask
def bias_ele_dropout_residual(ab: torch.Tensor,
                              b: torch.Tensor,
                              g: torch.Tensor,
                              dropout_mask: torch.Tensor,
                              Z_raw: torch.Tensor,
                              prob: float,
                              training: bool = True) -> torch.Tensor:
    """
    Compute ``Z_raw + dropout(dropout_mask) * (g * (ab + b))``.

    The dropout is applied to ``dropout_mask`` rather than to the activations, so a mask that
    is broadcast against the gated term drops whole slices together.

    :param ab: activations to which ``b`` is added
    :param b: bias added to ``ab``
    :param g: gate multiplied into ``ab + b``
    :param dropout_mask: mask tensor passed through dropout and multiplied into the gated term
    :param Z_raw: residual tensor added to the result
    :param prob: dropout probability
    :param training: whether dropout is active. Defaults to True, which reproduces the
        previous hard-coded behaviour (dropout is applied on every call, including inference);
        pass the owning module's ``self.training`` to get deterministic evaluation.
    :return: ``Z_raw + mask * (g * (ab + b))``
    """
    mask = F.dropout(dropout_mask, p=prob, training=training)
    return Z_raw + mask * (g * (ab + b))
\ No newline at end of file
evoformer_openfold/msa.py
0 → 100644
View file @
69af9310
import
math
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
from
einops
import
rearrange
from
torch.nn
import
LayerNorm
from
.kernel
import
bias_dropout_add
from
.ops
import
SelfAttention
,
Transition
class MSARowAttentionWithPairBias(nn.Module):
    """
    Gated self-attention over ``M`` whose attention logits receive a per-head bias that is
    linearly projected from the layer-normalised pair tensor ``Z``.
    """

    def __init__(self, d_node, d_pair, c=32, n_head=8, p_drop=0.15):
        """
        :param d_node: feature width of ``M`` (input and output of the attention)
        :param d_pair: feature width of ``Z``
        :param c: per-head channel width passed to ``SelfAttention``
        :param n_head: number of attention heads
        :param p_drop: dropout probability forwarded to ``bias_dropout_add``
        """
        super().__init__()
        self.d_node, self.d_pair = d_node, d_pair
        self.c, self.n_head = c, n_head
        self.p_drop = p_drop

        self.layernormM = LayerNorm(d_node)
        self.layernormZ = LayerNorm(d_pair)

        # Pair-bias projection weight of shape [n_head, d_pair], normal with std 1/sqrt(d_pair).
        bias_weight = torch.zeros([n_head, d_pair])
        torch.nn.init.normal_(bias_weight, std=1.0 / math.sqrt(d_pair))
        self.linear_b_weights = nn.parameter.Parameter(data=bias_weight, requires_grad=True)

        # last_bias_fuse=True: the attention's output layer has no bias; out_bias below is
        # added inside bias_dropout_add instead.
        self.attention = SelfAttention(qkv_dim=d_node,
                                       c=c,
                                       n_head=n_head,
                                       out_dim=d_node,
                                       gating=True,
                                       last_bias_fuse=True)
        self.out_bias = nn.parameter.Parameter(data=torch.zeros((d_node,)), requires_grad=True)

    def forward(self, M_raw, Z):
        """
        :param M_raw: node tensor; also used as the residual
        :param Z: pair tensor, 4-D with features last (the bias is permuted 'b q k h -> b h q k')
        :return: ``M_raw + (attention(M, bias) + out_bias) * mask``
        """
        normed_node = self.layernormM(M_raw)
        normed_pair = self.layernormZ(Z)

        # Project pair features to one value per head, then move the head axis to position 1.
        pair_bias = F.linear(normed_pair, self.linear_b_weights).permute(0, 3, 1, 2)

        attended = self.attention(normed_node, pair_bias)

        # ones_like already matches the device and dtype of its argument; the mask has size 1
        # along dim 1 so it broadcasts across that dimension.
        dropout_mask = torch.ones_like(attended[:, 0:1, :, :])
        return bias_dropout_add(attended, self.out_bias, dropout_mask, M_raw, prob=self.p_drop)
class MSAColumnAttention(nn.Module):
    """
    Gated self-attention applied after swapping dims -2 and -3 of the input, with the
    result swapped back and added to the input as a residual.
    """

    def __init__(self, d_node, c=32, n_head=8):
        """
        :param d_node: feature width of the input (and output) tensor
        :param c: per-head channel width passed to ``SelfAttention``
        :param n_head: number of attention heads
        """
        super().__init__()
        self.d_node, self.c, self.n_head = d_node, c, n_head

        self.layernormM = LayerNorm(d_node)
        self.attention = SelfAttention(qkv_dim=d_node, c=c, n_head=n_head, out_dim=d_node, gating=True)

    def forward(self, M_raw):
        """Return ``M_raw`` plus the attention output computed on the transposed, normalised input."""
        swapped = M_raw.transpose(-2, -3)
        attended = self.attention(self.layernormM(swapped))
        # Undo the transpose so the update lines up with the residual.
        return M_raw + attended.transpose(-2, -3)
class MSAStack(nn.Module):
    """Row attention with pair bias, then column attention, then a transition, applied in that order."""

    def __init__(self, d_node, d_pair, p_drop=0.15):
        """
        :param d_node: node feature width used by all three sub-modules
        :param d_pair: pair feature width used by the row attention's bias projection
        :param p_drop: dropout probability forwarded to the row attention
        """
        super().__init__()
        self.MSARowAttentionWithPairBias = MSARowAttentionWithPairBias(d_node=d_node,
                                                                       d_pair=d_pair,
                                                                       p_drop=p_drop)
        self.MSAColumnAttention = MSAColumnAttention(d_node=d_node)
        self.MSATransition = Transition(d=d_node)

    def forward(self, node, pair):
        """Return the updated ``node``; ``pair`` is only read by the row-attention step."""
        row_out = self.MSARowAttentionWithPairBias(node, pair)
        col_out = self.MSAColumnAttention(row_out)
        return self.MSATransition(col_out)
evoformer_openfold/ops.py
0 → 100755
View file @
69af9310
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
from
einops
import
rearrange
from
torch.nn
import
LayerNorm
from
.initializer
import
glorot_uniform_af
from
.kernel
import
bias_sigmod_ele
class DropoutRowwise(nn.Module):
    """Dropout whose mask has size 1 along dim 1, so one mask is shared across that dimension."""

    def __init__(self, p):
        """:param p: dropout probability handed to ``nn.Dropout``"""
        super().__init__()
        self.p = p
        self.dropout = nn.Dropout(p=p)

    def forward(self, x):
        """
        :param x: 4-D tensor (it is indexed with four subscripts)
        :return: ``x`` multiplied by the broadcast dropout mask
        """
        # Build a ones-template with dim 1 collapsed, drop entries in it, then broadcast.
        template = torch.ones_like(x[:, 0:1, :, :])
        shared_mask = self.dropout(template)
        return shared_mask * x
class DropoutColumnwise(nn.Module):
    """Dropout whose mask has size 1 along dim 2, so one mask is shared across that dimension."""

    def __init__(self, p):
        """:param p: dropout probability handed to ``nn.Dropout``"""
        super().__init__()
        self.p = p
        self.dropout = nn.Dropout(p=p)

    def forward(self, x):
        """
        :param x: 4-D tensor (it is indexed with four subscripts)
        :return: ``x`` multiplied by the broadcast dropout mask
        """
        # Build a ones-template with dim 2 collapsed, drop entries in it, then broadcast.
        template = torch.ones_like(x[:, :, 0:1, :])
        shared_mask = self.dropout(template)
        return shared_mask * x
class Transition(nn.Module):
    """Residual two-layer feed-forward block: LayerNorm -> Linear(d, n*d) -> ReLU -> Linear(n*d, d)."""

    def __init__(self, d, n=4):
        """
        :param d: input and output feature width
        :param n: expansion factor of the hidden layer (hidden width is ``n * d``)
        """
        super().__init__()
        hidden_width = n * d
        self.norm = LayerNorm(d)
        self.linear1 = Linear(d, hidden_width, initializer='relu')
        # The second layer is requested with zero-initialised weights.
        self.linear2 = Linear(hidden_width, d, initializer='zeros')

    def forward(self, src):
        """Return ``src`` plus the feed-forward update computed from the normalised ``src``."""
        hidden = F.relu(self.linear1(self.norm(src)))
        return src + self.linear2(hidden)
class OutProductMean(nn.Module):
    """
    Outer-product communication from the node tensor to the pair tensor.

    The layer-normalised input is projected twice to ``n_feat_proj`` channels; the outer
    product of the two projections is summed over dim 1 of the input, flattened over the two
    channel axes and linearly mapped to ``n_feat_out`` channels.

    NOTE(review): despite the class name, no division by the size of the summed dimension is
    performed here -- confirm whether a mean is expected by callers.
    """

    def __init__(self, n_feat=64, n_feat_out=128, n_feat_proj=32):
        """
        :param n_feat: feature width of the input tensor
        :param n_feat_out: feature width of the output tensor
        :param n_feat_proj: channel width of each of the two projections
        """
        super().__init__()
        self.layernormM = LayerNorm(n_feat)
        self.linear_a = Linear(n_feat, n_feat_proj)
        self.linear_b = Linear(n_feat, n_feat_proj)

        # Use the 'zeros' spelling that Linear's initializer dispatch matches; the previous
        # 'zero' matched no branch, so the weights were never zero-initialised.
        self.o_linear = Linear(n_feat_proj * n_feat_proj, n_feat_out, initializer='zeros', use_bias=True)

    def forward(self, M):
        """
        :param M: 4-D tensor indexed as 'b s i d' by the einsum below
        :return: tensor indexed as [b, i, j, n_feat_out]
        """
        M = self.layernormM(M)
        left_act = self.linear_a(M)
        right_act = self.linear_b(M)

        # Outer product over the channel axes, summed over s: result is [b, i, j, d, e].
        O = torch.einsum('bsid,bsje->bijde', left_act, right_act).contiguous()
        # Flatten the two channel axes: 'b i j d e -> b i j (d e)'.
        O = O.reshape(O.shape[0], O.shape[1], O.shape[2], -1)
        Z = self.o_linear(O)

        return Z
class Linear(nn.Linear):
    """
    A Linear layer with built-in nonstandard initializations. Called just
    like torch.nn.Linear.

    Implements the initializers in 1.11.4, plus some additional ones found
    in the code.

    Supported ``initializer`` values:
        'linear'          -- ``glorot_uniform_af`` with gain 1.0
        'relu'            -- ``glorot_uniform_af`` with gain 2.0
        'zeros' / 'zero'  -- all-zero weights ('zero' is accepted as an alias because call
                             sites use that spelling)
    """

    def __init__(
        self,
        feature_in: int,
        feature_out: int,
        initializer: str = 'linear',
        use_bias: bool = True,
        bias_init: float = 0.,
    ):
        """
        :param feature_in: size of each input sample
        :param feature_out: size of each output sample
        :param initializer: name of the weight initializer (see class docstring)
        :param use_bias: whether the layer has a bias term
        :param bias_init: constant the bias is filled with when ``use_bias`` is True
        :raises ValueError: if ``initializer`` is not one of the supported names
        """
        super().__init__(feature_in, feature_out, bias=use_bias)

        self.use_bias = use_bias

        if initializer == 'linear':
            glorot_uniform_af(self.weight, gain=1.0)
        elif initializer == 'relu':
            glorot_uniform_af(self.weight, gain=2.0)
        elif initializer in ('zeros', 'zero'):
            nn.init.zeros_(self.weight)
        else:
            # Previously an unknown name fell through silently and left the default
            # nn.Linear initialisation in place, which hid the 'zero'/'zeros' mismatch.
            raise ValueError(
                "Unknown initializer %r; expected 'linear', 'relu', 'zeros' or 'zero'." % (initializer,))

        if self.use_bias:
            with torch.no_grad():
                self.bias.fill_(bias_init)
class SelfAttention(nn.Module):
    """
    Multi-Head SelfAttention dealing with [batch_size1, batch_size2, len, dim] tensors
    """

    def __init__(self, qkv_dim, c, n_head, out_dim, gating=True, last_bias_fuse=False):
        """
        :param qkv_dim: feature width of the input
        :param c: per-head channel width
        :param n_head: number of heads
        :param out_dim: feature width of the output
        :param gating: if True, the attention output is gated by ``sigmoid(gating_linear(x) + gating_bias)``
        :param last_bias_fuse: if True, the output projection is created without a bias term
        """
        super().__init__()
        self.qkv_dim = qkv_dim
        self.c = c
        self.n_head = n_head
        self.out_dim = out_dim
        self.gating = gating
        self.last_bias_fuse = last_bias_fuse

        self.scaling = self.c**(-0.5)

        self.to_q = Linear(qkv_dim, n_head * c, initializer='linear', use_bias=False)
        self.to_k = Linear(qkv_dim, n_head * c, initializer='linear', use_bias=False)
        self.to_v = Linear(qkv_dim, n_head * c, initializer='linear', use_bias=False)

        # 'zeros' is the spelling Linear's initializer dispatch matches; the previous 'zero'
        # matched no branch, so these weights were never zero-initialised.
        if gating:
            self.gating_bias = nn.parameter.Parameter(data=torch.ones((n_head * c,)))
            self.gating_linear = Linear(qkv_dim, n_head * c, initializer='zeros', use_bias=False)

        self.o_linear = Linear(n_head * c, out_dim, initializer='zeros', use_bias=(not last_bias_fuse))

    def _split_heads(self, t):
        """Reshape [b1, b2, n, n_head * c] to [b1, b2, n_head, n, c]."""
        return t.view(t.shape[0], t.shape[1], t.shape[2], self.n_head, -1).permute(0, 1, 3, 2, 4)

    def forward(self, in_data, nonbatched_bias=None):
        """
        :param in_data: [batch_size1, batch_size2, len_qkv, qkv_dim]
        :param nonbatched_bias: None or [batch_size1, n_head, len_q, len_kv]; it is broadcast
            over batch_size2 and added to the attention logits
        :return: [batch_size1, batch_size2, len_qkv, out_dim]
        """
        q = self._split_heads(self.to_q(in_data))
        k = self._split_heads(self.to_k(in_data))
        v = self._split_heads(self.to_v(in_data))

        q = q * self.scaling

        logits = torch.matmul(q, k.transpose(-1, -2))

        if nonbatched_bias is not None:
            # Insert the batch_size2 axis so the bias broadcasts across it.
            logits += nonbatched_bias.unsqueeze(1)

        weights = torch.softmax(logits, dim=-1)

        weighted_avg = torch.matmul(weights, v)
        # Merge heads back: 'b1 b2 h n d -> b1 b2 n (h d)'.
        weighted_avg = weighted_avg.permute(0, 1, 3, 2, 4)
        weighted_avg = weighted_avg.reshape(weighted_avg.shape[0], weighted_avg.shape[1], weighted_avg.shape[2], -1)

        if self.gating:
            gate_values = self.gating_linear(in_data)
            weighted_avg = bias_sigmod_ele(gate_values, self.gating_bias, weighted_avg)

        output = self.o_linear(weighted_avg)
        return output
evoformer_openfold/triangle.py
0 → 100644
View file @
69af9310
import
math
import
torch
import
torch.nn
as
nn
from
torch.nn
import
LayerNorm
from
.kernel
import
bias_dropout_add
,
bias_ele_dropout_residual
from
.ops
import
Linear
,
SelfAttention
,
Transition
def permute_final_dims(tensor, inds):
    """
    Permute only the last ``len(inds)`` dimensions of ``tensor``.

    :param tensor: input tensor
    :param inds: new order of the trailing dimensions, each index counted from the start of
        the trailing group (0 is the first of the last ``len(inds)`` dimensions)
    :return: the permuted view; all earlier dimensions keep their positions
    """
    offset = -len(inds)
    # Leading dimensions stay where they are.
    leading = list(range(len(tensor.shape[:offset])))
    # Trailing dimensions are addressed with negative indices relative to the end.
    trailing = [offset + i for i in inds]
    return tensor.permute(leading + trailing)
class TriangleMultiplicationOutgoing(nn.Module):
    """
    Gated triangle multiplicative update that contracts the third axis of the two gated
    projections (einsum ``'bikd,bjkd->bijd'``) and adds the gated, projected result to the input.
    """

    def __init__(self, d_pair, p_drop, c=128):
        """
        :param d_pair: feature width of the pair tensor (input and output)
        :param p_drop: dropout probability forwarded to ``bias_ele_dropout_residual``
        :param c: channel width of the left/right projections
        """
        super().__init__()
        self.d_pair = d_pair
        self.c = c

        self.layernorm1 = LayerNorm(d_pair)
        self.left_projection = Linear(d_pair, c)
        self.right_projection = Linear(d_pair, c)
        self.left_gate = Linear(d_pair, c, initializer='zeros', bias_init=1.)
        self.right_gate = Linear(d_pair, c, initializer='zeros', bias_init=1.)

        self.output_gate = Linear(d_pair, d_pair, initializer='zeros', bias_init=1.)
        self.layernorm2 = LayerNorm(c)
        # The contracted tensor has c channels, so the projection maps c -> d_pair. It was
        # previously Linear(d_pair, d_pair), which only worked when c == d_pair.
        self.output_projection = Linear(c, d_pair, initializer='zeros', use_bias=False)
        self.output_bias = nn.parameter.Parameter(data=torch.zeros((d_pair,)), requires_grad=True)

        self.p_drop = p_drop

    def forward(self, Z_raw):
        """
        :param Z_raw: 4-D pair tensor with features last; also used as the residual
        :return: ``Z_raw + mask * (g * (projection + output_bias))``
        """
        Z = self.layernorm1(Z_raw)
        left_proj_act = self.left_projection(Z)
        right_proj_act = self.right_projection(Z)

        left_proj_act = left_proj_act * torch.sigmoid(self.left_gate(Z))
        right_proj_act = right_proj_act * torch.sigmoid(self.right_gate(Z))

        g = torch.sigmoid(self.output_gate(Z))

        # ab[b, i, j, d] = sum_k left[b, i, k, d] * right[b, j, k, d]
        ab = torch.einsum('bikd,bjkd->bijd', left_proj_act, right_proj_act)
        ab = self.output_projection(self.layernorm2(ab))

        # ones_like already matches Z's device and dtype; the mask has size 1 along dim 1.
        dropout_mask = torch.ones_like(Z[:, 0:1, :, :])
        return bias_ele_dropout_residual(ab, self.output_bias, g, dropout_mask, Z_raw, prob=self.p_drop)
class TriangleMultiplicationIncoming(nn.Module):
    """
    Gated triangle multiplicative update that contracts the second axis of the two gated
    projections (einsum ``'bkid,bkjd->bijd'``) and adds the gated, projected result to the input.
    """

    def __init__(self, d_pair, p_drop, c=128):
        """
        :param d_pair: feature width of the pair tensor (input and output)
        :param p_drop: dropout probability forwarded to ``bias_ele_dropout_residual``
        :param c: channel width of the left/right projections
        """
        super().__init__()
        self.d_pair = d_pair
        self.c = c

        self.layernorm1 = LayerNorm(d_pair)
        self.left_projection = Linear(d_pair, c)
        self.right_projection = Linear(d_pair, c)
        self.left_gate = Linear(d_pair, c, initializer='zeros', bias_init=1.)
        self.right_gate = Linear(d_pair, c, initializer='zeros', bias_init=1.)

        self.output_gate = Linear(d_pair, d_pair, initializer='zeros', bias_init=1.)
        self.layernorm2 = LayerNorm(c)
        # The contracted tensor has c channels, so the projection maps c -> d_pair. It was
        # previously Linear(d_pair, d_pair), which only worked when c == d_pair.
        self.output_projection = Linear(c, d_pair, initializer='zeros', use_bias=False)
        self.output_bias = nn.parameter.Parameter(data=torch.zeros((d_pair,)), requires_grad=True)

        self.p_drop = p_drop

    def forward(self, Z_raw):
        """
        :param Z_raw: 4-D pair tensor with features last; also used as the residual
        :return: ``Z_raw + mask * (g * (projection + output_bias))``
        """
        Z = self.layernorm1(Z_raw)
        left_proj_act = self.left_projection(Z)
        right_proj_act = self.right_projection(Z)

        left_proj_act = left_proj_act * torch.sigmoid(self.left_gate(Z))
        right_proj_act = right_proj_act * torch.sigmoid(self.right_gate(Z))

        g = torch.sigmoid(self.output_gate(Z))

        # ab[b, i, j, d] = sum_k left[b, k, i, d] * right[b, k, j, d]
        ab = torch.einsum('bkid,bkjd->bijd', left_proj_act, right_proj_act)
        ab = self.output_projection(self.layernorm2(ab))

        # ones_like already matches Z's device and dtype; the mask has size 1 along dim 1.
        dropout_mask = torch.ones_like(Z[:, 0:1, :, :])
        return bias_ele_dropout_residual(ab, self.output_bias, g, dropout_mask, Z_raw, prob=self.p_drop)
class TriangleAttentionStartingNode(nn.Module):
    """
    Gated self-attention over the pair tensor whose logits receive a per-head bias projected
    from the same layer-normalised pair tensor; the result is added to the input.
    """

    def __init__(self, d_pair, p_drop, c=32, n_head=4):
        """
        :param d_pair: feature width of the pair tensor (input and output)
        :param p_drop: dropout probability forwarded to ``bias_dropout_add``
        :param c: per-head channel width passed to ``SelfAttention``
        :param n_head: number of attention heads
        """
        super().__init__()
        self.d_pair, self.c, self.n_head = d_pair, c, n_head
        self.p_drop = p_drop

        self.layernorm1 = LayerNorm(d_pair)

        # Bias projection weight of shape [d_pair, n_head], normal with std 1/sqrt(d_pair).
        bias_weight = torch.zeros([d_pair, n_head])
        torch.nn.init.normal_(bias_weight, std=1.0 / math.sqrt(d_pair))
        self.linear_b_weights = nn.parameter.Parameter(data=bias_weight)

        # last_bias_fuse=True: the attention's output layer has no bias; out_bias below is
        # added inside bias_dropout_add instead.
        self.attention = SelfAttention(qkv_dim=d_pair,
                                       c=c,
                                       n_head=n_head,
                                       out_dim=d_pair,
                                       gating=True,
                                       last_bias_fuse=True)
        self.out_bias = nn.parameter.Parameter(data=torch.zeros((d_pair,)), requires_grad=True)

    def forward(self, Z_raw):
        """
        :param Z_raw: 4-D pair tensor indexed 'b q k c' by the bias einsum; also the residual
        :return: ``Z_raw + (attention(Z, bias) + out_bias) * mask``
        """
        normed = self.layernorm1(Z_raw)
        # Project features to one bias value per head: [b, q, k, c] x [c, h] -> [b, h, q, k].
        pair_bias = torch.einsum('bqkc,ch->bhqk', normed, self.linear_b_weights)

        attended = self.attention(normed, pair_bias)

        # ones_like already matches the device and dtype of its argument; the mask has size 1
        # along dim 1 so it broadcasts across that dimension.
        dropout_mask = torch.ones_like(attended[:, 0:1, :, :])
        return bias_dropout_add(attended, self.out_bias, dropout_mask, Z_raw, prob=self.p_drop)
class TriangleAttentionEndingNode(nn.Module):
    """
    Same attention as the starting-node variant, but computed on the pair tensor with dims
    -2 and -3 swapped; the output is swapped back before being added to the input.
    """

    def __init__(self, d_pair, p_drop, c=32, n_head=4):
        """
        :param d_pair: feature width of the pair tensor (input and output)
        :param p_drop: dropout probability forwarded to ``bias_dropout_add``
        :param c: per-head channel width passed to ``SelfAttention``
        :param n_head: number of attention heads
        """
        super().__init__()
        self.d_pair, self.c, self.n_head = d_pair, c, n_head
        self.p_drop = p_drop

        self.layernorm1 = LayerNorm(d_pair)

        # Bias projection weight of shape [d_pair, n_head], normal with std 1/sqrt(d_pair).
        bias_weight = torch.zeros([d_pair, n_head])
        torch.nn.init.normal_(bias_weight, std=1.0 / math.sqrt(d_pair))
        self.linear_b_weights = nn.parameter.Parameter(data=bias_weight)

        # last_bias_fuse=True: the attention's output layer has no bias; out_bias below is
        # added inside bias_dropout_add instead.
        self.attention = SelfAttention(qkv_dim=d_pair,
                                       c=c,
                                       n_head=n_head,
                                       out_dim=d_pair,
                                       gating=True,
                                       last_bias_fuse=True)
        self.out_bias = nn.parameter.Parameter(data=torch.zeros((d_pair,)), requires_grad=True)

    def forward(self, Z_raw):
        """
        :param Z_raw: 4-D pair tensor with features last; also used as the residual
        :return: ``Z_raw + (attention output, transposed back, + out_bias) * mask``
        """
        swapped = self.layernorm1(Z_raw.transpose(-2, -3))
        # Project features to one bias value per head: [b, q, k, c] x [c, h] -> [b, h, q, k].
        pair_bias = torch.einsum('bqkc,ch->bhqk', swapped, self.linear_b_weights)

        attended = self.attention(swapped, pair_bias)
        # Undo the transpose so the update lines up with the residual.
        attended = attended.transpose(-2, -3)

        # ones_like already matches the device and dtype of its argument; the mask has size 1
        # along dim 2 so it broadcasts across that dimension.
        dropout_mask = torch.ones_like(attended[:, :, 0:1, :])
        return bias_dropout_add(attended, self.out_bias, dropout_mask, Z_raw, prob=self.p_drop)
class PairStack(nn.Module):
    """
    Pair-representation pipeline: outgoing and incoming triangle multiplication, starting- and
    ending-node triangle attention, then a transition, applied in that order.
    """

    def __init__(self, d_pair, p_drop=0.25):
        """
        :param d_pair: feature width of the pair tensor
        :param p_drop: dropout probability forwarded to the four triangle sub-modules
        """
        super().__init__()
        self.TriangleMultiplicationOutgoing = TriangleMultiplicationOutgoing(d_pair, p_drop=p_drop)
        self.TriangleMultiplicationIncoming = TriangleMultiplicationIncoming(d_pair, p_drop=p_drop)
        self.TriangleAttentionStartingNode = TriangleAttentionStartingNode(d_pair, p_drop=p_drop)
        self.TriangleAttentionEndingNode = TriangleAttentionEndingNode(d_pair, p_drop=p_drop)
        self.PairTransition = Transition(d=d_pair)

    def forward(self, pair):
        """Feed ``pair`` through the five stages in sequence and return the result."""
        stages = (
            self.TriangleMultiplicationOutgoing,
            self.TriangleMultiplicationIncoming,
            self.TriangleAttentionStartingNode,
            self.TriangleAttentionEndingNode,
            self.PairTransition,
        )
        for stage in stages:
            pair = stage(pair)
        return pair
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment