OpenDAS / ColossalAI · Commits · 5c4df01a

Commit 5c4df01a, authored Dec 29, 2022 by oahzxl

update openfold

Parent: 289f3a45
Showing 2 changed files with 12 additions and 84 deletions (+12 / -84):

openfold/evoformer.py    +9   -20
openfold/msa.py          +3   -64
openfold/evoformer.py (view file @ 5c4df01a)
@@ -182,33 +182,28 @@ class EvoformerBlockCore(nn.Module):
         self,
         m: torch.Tensor,
         z: torch.Tensor,
-        msa_mask: torch.Tensor,
-        pair_mask: torch.Tensor,
         chunk_size: Optional[int] = None,
-        _mask_trans: bool = True,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
 
         # DeepMind doesn't mask these transitions in the source, so _mask_trans
         # should be disabled to better approximate the exact activations of
         # the original.
-        msa_trans_mask = msa_mask if _mask_trans else None
-        pair_trans_mask = pair_mask if _mask_trans else None
         m = m + self.msa_transition(
-            m, mask=msa_trans_mask, chunk_size=chunk_size
+            m, chunk_size=chunk_size
         )
         z = z + self.outer_product_mean(
-            m, mask=msa_mask, chunk_size=chunk_size
+            m, chunk_size=chunk_size
         )
-        z = z + self.ps_dropout_row_layer(self.tri_mul_out(z, mask=pair_mask))
-        z = z + self.ps_dropout_row_layer(self.tri_mul_in(z, mask=pair_mask))
+        z = z + self.ps_dropout_row_layer(self.tri_mul_out(z))
+        z = z + self.ps_dropout_row_layer(self.tri_mul_in(z))
         z = z + self.ps_dropout_row_layer(
-            self.tri_att_start(z, mask=pair_mask, chunk_size=chunk_size)
+            self.tri_att_start(z, chunk_size=chunk_size)
         )
         z = z + self.ps_dropout_col_layer(
-            self.tri_att_end(z, mask=pair_mask, chunk_size=chunk_size)
+            self.tri_att_end(z, chunk_size=chunk_size)
         )
         z = z + self.pair_transition(
-            z, mask=pair_trans_mask, chunk_size=chunk_size
+            z, chunk_size=chunk_size
         )
 
         return m, z

@@ -274,22 +269,16 @@ class EvoformerBlock(nn.Module):
     def forward(self,
         m: torch.Tensor,
         z: torch.Tensor,
-        msa_mask: torch.Tensor,
-        pair_mask: torch.Tensor,
         chunk_size: Optional[int] = None,
-        _mask_trans: bool = True,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         m = m + self.msa_dropout_layer(
-            self.msa_att_row(m, z=z, mask=msa_mask, chunk_size=chunk_size)
+            self.msa_att_row(m, z=z, chunk_size=chunk_size)
         )
-        m = m + self.msa_att_col(m, mask=msa_mask, chunk_size=chunk_size)
+        m = m + self.msa_att_col(m, chunk_size=chunk_size)
         m, z = self.core(
             m,
             z,
-            msa_mask=msa_mask,
-            pair_mask=pair_mask,
             chunk_size=chunk_size,
-            _mask_trans=_mask_trans,
         )
 
         return m, z
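Taken together, the evoformer.py changes remove msa_mask, pair_mask, and _mask_trans from EvoformerBlockCore.forward and EvoformerBlock.forward and drop the mask= keywords from every sub-module call. A minimal, self-contained sketch of the resulting call pattern follows; TinyEvoformerCore and its nn.Linear sub-layers are hypothetical stand-ins for illustration, not modules from this repository.

# Hypothetical stand-in for illustration only; not code from this repository.
from typing import Optional, Tuple

import torch
import torch.nn as nn


class TinyEvoformerCore(nn.Module):
    # Mirrors the post-commit signature: no msa_mask / pair_mask / _mask_trans.
    def __init__(self, c_m: int, c_z: int):
        super().__init__()
        self.msa_transition = nn.Linear(c_m, c_m)   # stand-in for the MSA transition
        self.pair_transition = nn.Linear(c_z, c_z)  # stand-in for the pair transition

    def forward(
        self,
        m: torch.Tensor,
        z: torch.Tensor,
        chunk_size: Optional[int] = None,  # kept for signature parity; unused in this toy
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Residual updates without mask= keywords, as in the diff above.
        m = m + self.msa_transition(m)
        z = z + self.pair_transition(z)
        return m, z


m = torch.randn(1, 8, 16, 32)   # [*, N_seq, N_res, C_m]
z = torch.randn(1, 16, 16, 64)  # [*, N_res, N_res, C_z]
m, z = TinyEvoformerCore(32, 64)(m, z, chunk_size=None)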
openfold/msa.py (view file @ 5c4df01a)
@@ -136,45 +136,6 @@ class MSAAttention(nn.Module):
 
         return m, mask_bias, z
 
-    @torch.jit.ignore
-    def _chunked_msa_attn(self,
-        m: torch.Tensor,
-        z: Optional[torch.Tensor],
-        mask: Optional[torch.Tensor],
-        chunk_logits: int,
-        checkpoint: bool,
-    ) -> torch.Tensor:
-        MSA_DIM = -4
-
-        def _get_qkv(m, z):
-            m, mask_bias, z = self._prep_inputs(m, z, mask)
-            q, k, v = self.mha._prep_qkv(m, m)
-            return m, q, k, v, mask_bias, z
-
-        checkpoint_fn = get_checkpoint_fn()
-
-        if (torch.is_grad_enabled() and checkpoint):
-            m, q, k, v, mask_bias, z = checkpoint_fn(_get_qkv, m, z)
-        else:
-            m, q, k, v, mask_bias, z = _get_qkv(m, z)
-
-        o = _attention_chunked_trainable(
-            query=q,
-            key=k,
-            value=v,
-            biases=[mask_bias, z],
-            chunk_size=chunk_logits,
-            chunk_dim=MSA_DIM,
-            checkpoint=checkpoint,
-        )
-
-        if (torch.is_grad_enabled() and checkpoint):
-            # Storing an additional m here is far from ideal
-            m = checkpoint_fn(self.mha._wrap_up, o, m)
-        else:
-            m = self.mha._wrap_up(o, m)
-
-        return m
-
     def forward(self,
         m: torch.Tensor,
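The hunk above deletes _chunked_msa_attn, which wrapped the query/key/value preparation in activation checkpointing (via get_checkpoint_fn) before running chunked attention. A rough, self-contained sketch of that recompute-in-backward pattern follows, using plain torch.utils.checkpoint rather than the repository's helper; the tensor shapes and the _project function are illustrative assumptions.

# Illustrative sketch only; torch.utils.checkpoint stands in for get_checkpoint_fn,
# and _project is a made-up placeholder for the q/k/v preparation.
import torch
from torch.utils.checkpoint import checkpoint

proj = torch.nn.Linear(32, 32)

def _project(m: torch.Tensor) -> torch.Tensor:
    # Activations inside this call are recomputed during backward instead of stored.
    return proj(m)

m = torch.randn(4, 16, 32, requires_grad=True)

if torch.is_grad_enabled():
    out = checkpoint(_project, m, use_reentrant=False)
else:
    out = _project(m)

out.sum().backward()  # gradients still flow; intermediates were recomputed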
@@ -199,12 +160,6 @@ class MSAAttention(nn.Module):
             cost of slower execution. Chunking is not performed by default.
         """
-        if (_chunk_logits is not None):
-            return self._chunked_msa_attn(
-                m=m, z=z, mask=mask,
-                chunk_logits=_chunk_logits,
-                checkpoint=_checkpoint_chunks
-            )
 
         m, mask_bias, z = self._prep_inputs(m, z, mask)
 
         biases = [mask_bias]

@@ -306,15 +261,11 @@ class MSAColumnAttention(nn.Module):
         """
         # [*, N_res, N_seq, C_in]
         m = m.transpose(-2, -3)
-        if mask is not None:
-            mask = mask.transpose(-1, -2)
 
-        m = self._msa_att(m, mask=mask, chunk_size=chunk_size)
+        m = self._msa_att(m, chunk_size=chunk_size)
 
         # [*, N_seq, N_res, C_in]
         m = m.transpose(-2, -3)
-        if mask is not None:
-            mask = mask.transpose(-1, -2)
 
         return m

@@ -344,12 +295,10 @@ class MSAColumnGlobalAttention(nn.Module):
     @torch.jit.ignore
     def _chunk(self,
         m: torch.Tensor,
-        mask: torch.Tensor,
         chunk_size: int,
     ) -> torch.Tensor:
         mha_input = {
             "m": m,
-            "mask": mask,
         }
         return chunk_layer(
             self.global_attention,

@@ -361,30 +310,20 @@ class MSAColumnGlobalAttention(nn.Module):
     def forward(
         self,
         m: torch.Tensor,
-        mask: Optional[torch.Tensor] = None,
         chunk_size: Optional[int] = None,
     ) -> torch.Tensor:
         n_seq, n_res, c_in = m.shape[-3:]
 
-        if mask is None:
-            # [*, N_seq, N_res]
-            mask = torch.ones(
-                m.shape[:-1],
-                dtype=m.dtype,
-                device=m.device,
-            ).detach()
-
         # [*, N_res, N_seq, C_in]
         m = m.transpose(-2, -3)
-        mask = mask.transpose(-1, -2)
 
         # [*, N_res, N_seq, C_in]
         m = self.layer_norm_m(m)
 
         if chunk_size is not None:
-            m = self._chunk(m, mask, chunk_size)
+            m = self._chunk(m, chunk_size)
         else:
-            m = self.global_attention(m=m, mask=mask)
+            m = self.global_attention(m=m)
 
         # [*, N_seq, N_res, C_in]
         m = m.transpose(-2, -3)
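The chunk_size path kept by this commit still splits work into slices via _chunk and chunk_layer. Below is a hypothetical, self-contained sketch of that idea and of the transpose-chunk-transpose pattern used by the column attention above; toy_chunk_apply and the doubling lambda are invented for illustration and are not the repository's chunk_layer.

# Hypothetical helper for illustration; not the repository's chunk_layer.
import torch

def toy_chunk_apply(fn, m: torch.Tensor, chunk_size: int, dim: int = -3) -> torch.Tensor:
    # Split along `dim`, apply `fn` slice by slice, then re-concatenate.
    # This trades peak memory for extra kernel launches, which is what chunking buys above.
    return torch.cat([fn(c) for c in m.split(chunk_size, dim=dim)], dim=dim)

# Column attention transposes [*, N_seq, N_res, C] to [*, N_res, N_seq, C],
# chunks over the residue dimension, then transposes back (see MSAColumnAttention).
m = torch.randn(2, 8, 32, 16)            # [*, N_seq, N_res, C_in]
m_t = m.transpose(-2, -3)                # [*, N_res, N_seq, C_in]
out = toy_chunk_apply(lambda x: x * 2.0, m_t, chunk_size=8)
out = out.transpose(-2, -3)              # back to [*, N_seq, N_res, C_in]
assert out.shape == m.shape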