OpenDAS / ColossalAI · Commits

Commit 8208fd02
Authored Jan 18, 2023 by jiaruifang

    Merge branch 'main' of https://github.com/hpcaitech/ColossalAI into dev0116

Parents: 438ea608, d565a248

Changes: 37 (showing 17 changed files with 300 additions and 2742 deletions, +300 / -2742)
tests/test_autochunk/evoformer/ops.py                                 +0    -176
tests/test_autochunk/evoformer/triangle.py                            +0    -192
tests/test_autochunk/openfold/checkpointing.py                        +0    -84
tests/test_autochunk/openfold/dropout.py                              +0    -78
tests/test_autochunk/openfold/evoformer.py                            +0    -431
tests/test_autochunk/openfold/msa.py                                  +0    -331
tests/test_autochunk/openfold/outer_product_mean.py                   +0    -129
tests/test_autochunk/openfold/pair_transition.py                      +0    -99
tests/test_autochunk/openfold/primitives.py                           +0    -529
tests/test_autochunk/openfold/tensor_utils.py                         +0    -408
tests/test_autochunk/openfold/triangular_attention.py                 +0    -139
tests/test_autochunk/openfold/triangular_multiplicative_update.py     +0    -127
tests/test_autochunk/test_evoformer_codegen.py                        +164  -0
tests/test_autochunk/test_simple_evoformer_codegen.py                 +13   -7
tests/test_autochunk/test_simple_evoformer_search.py                  +13   -7
tests/test_tensor/common_utils/_utils.py                              +12   -5
tests/test_zero/low_level_zero/test_zero_tp.py                        +98   -0
tests/test_autochunk/evoformer/ops.py (deleted, 100755 → 0)

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from torch.nn import LayerNorm

from .initializer import glorot_uniform_af
from .kernel import bias_sigmod_ele


class DropoutRowwise(nn.Module):

    def __init__(self, p):
        super(DropoutRowwise, self).__init__()
        self.p = p
        self.dropout = nn.Dropout(p=p)

    def forward(self, x):
        dropout_mask = torch.ones_like(x[:, 0:1, :, :])
        dropout_mask = self.dropout(dropout_mask)
        return dropout_mask * x


class DropoutColumnwise(nn.Module):

    def __init__(self, p):
        super(DropoutColumnwise, self).__init__()
        self.p = p
        self.dropout = nn.Dropout(p=p)

    def forward(self, x):
        dropout_mask = torch.ones_like(x[:, :, 0:1, :])
        dropout_mask = self.dropout(dropout_mask)
        return dropout_mask * x


class Transition(nn.Module):

    def __init__(self, d, n=4):
        super(Transition, self).__init__()
        self.norm = LayerNorm(d)
        self.linear1 = Linear(d, n * d, initializer='relu')
        self.linear2 = Linear(n * d, d, initializer='zeros')

    def forward(self, src):
        x = self.norm(src)
        x = self.linear2(F.relu(self.linear1(x)))
        return src + x


class OutProductMean(nn.Module):

    def __init__(self, n_feat=64, n_feat_out=128, n_feat_proj=32):
        super(OutProductMean, self).__init__()
        self.layernormM = LayerNorm(n_feat)
        self.linear_a = Linear(n_feat, n_feat_proj)
        self.linear_b = Linear(n_feat, n_feat_proj)
        self.o_linear = Linear(n_feat_proj * n_feat_proj, n_feat_out, initializer='zero', use_bias=True)

    def forward(self, M):
        M = self.layernormM(M)
        left_act = self.linear_a(M)
        right_act = self.linear_b(M)
        o = torch.einsum('bsid,bsje->bijde', left_act, right_act).contiguous()
        # O = rearrange(O, 'b i j d e -> b i j (d e)')
        o = o.reshape(o.shape[0], o.shape[1], o.shape[2], -1)
        Z = self.o_linear(o)
        return Z


class Linear(nn.Linear):
    """
    A Linear layer with built-in nonstandard initializations. Called just
    like torch.nn.Linear.
    Implements the initializers in 1.11.4, plus some additional ones found
    in the code.
    """

    def __init__(
        self,
        feature_in: int,
        feature_out: int,
        initializer: str = 'linear',
        use_bias: bool = True,
        bias_init: float = 0.,
    ):
        super(Linear, self).__init__(feature_in, feature_out, bias=use_bias)
        self.use_bias = use_bias
        if initializer == 'linear':
            glorot_uniform_af(self.weight, gain=1.0)
        elif initializer == 'relu':
            glorot_uniform_af(self.weight, gain=2.0)
        elif initializer == 'zeros':
            nn.init.zeros_(self.weight)
        if self.use_bias:
            with torch.no_grad():
                self.bias.fill_(bias_init)


class SelfAttention(nn.Module):
    """
    Multi-Head SelfAttention dealing with [batch_size1, batch_size2, len, dim] tensors
    """

    def __init__(self, qkv_dim, c, n_head, out_dim, gating=True, last_bias_fuse=False):
        super(SelfAttention, self).__init__()
        self.qkv_dim = qkv_dim
        self.c = c
        self.n_head = n_head
        self.out_dim = out_dim
        self.gating = gating
        self.last_bias_fuse = last_bias_fuse

        self.scaling = self.c**(-0.5)

        # self.to_qkv = Linear(qkv_dim, 3 * n_head * c, initializer='linear')
        self.to_q = Linear(qkv_dim, n_head * c, initializer='linear', use_bias=False)
        self.to_k = Linear(qkv_dim, n_head * c, initializer='linear', use_bias=False)
        self.to_v = Linear(qkv_dim, n_head * c, initializer='linear', use_bias=False)

        if gating:
            self.gating_bias = nn.parameter.Parameter(data=torch.ones((n_head * c,)))
            self.gating_linear = Linear(qkv_dim, n_head * c, initializer='zero', use_bias=False)

        self.o_linear = Linear(n_head * c, out_dim, initializer='zero', use_bias=(not last_bias_fuse))

    def forward(self, in_data, nonbatched_bias=None):
        """
        :param in_data: [batch_size1, batch_size2, len_qkv, qkv_dim]
        :param bias: None or [batch_size1, batch_size2, n_head, len_q, len_kv]
        :param nonbatched_bias: None or [batch_size1, n_head, len_q, len_kv]
        """
        # qkv = self.to_qkv(in_data).chunk(3, dim=-1)
        # q, k, v = map(lambda t: rearrange(t, 'b1 b2 n (h d) -> b1 b2 h n d', h=self.n_head), qkv)
        q = self.to_q(in_data)
        k = self.to_k(in_data)
        v = self.to_v(in_data)

        # q, k, v = map(lambda t: rearrange(t, 'b1 b2 n (h d) -> b1 b2 h n d', h=self.n_head),
        #               [q, k, v])
        q, k, v = map(lambda t: t.view(t.shape[0], t.shape[1], t.shape[2], self.n_head, -1).permute(0, 1, 3, 2, 4),
                      [q, k, v])

        q = q * self.scaling

        logits = torch.matmul(q, k.transpose(-1, -2))

        if nonbatched_bias is not None:
            logits += nonbatched_bias.unsqueeze(1)
        weights = torch.softmax(logits, dim=-1)
        # weights = softmax(logits)

        weighted_avg = torch.matmul(weights, v)
        # weighted_avg = rearrange(weighted_avg, 'b1 b2 h n d -> b1 b2 n (h d)')
        weighted_avg = weighted_avg.permute(0, 1, 3, 2, 4)
        weighted_avg = weighted_avg.reshape(weighted_avg.shape[0], weighted_avg.shape[1], weighted_avg.shape[2], -1)

        if self.gating:
            gate_values = self.gating_linear(in_data)
            weighted_avg = bias_sigmod_ele(gate_values, self.gating_bias, weighted_avg)

        output = self.o_linear(weighted_avg)
        return output
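Not part of the commit: a minimal sketch showing that the view + permute head split used in SelfAttention.forward matches the einops rearrange pattern mentioned in the commented-out lines. The shapes below are illustrative assumptions, not values from the repository.

import torch
from einops import rearrange

b1, b2, n, h, d = 2, 3, 5, 4, 8
t = torch.randn(b1, b2, n, h * d)

# 'b1 b2 n (h d) -> b1 b2 h n d'
via_view = t.view(t.shape[0], t.shape[1], t.shape[2], h, -1).permute(0, 1, 3, 2, 4)
via_rearrange = rearrange(t, 'b1 b2 n (h d) -> b1 b2 h n d', h=h)

assert torch.equal(via_view, via_rearrange)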
tests/test_autochunk/evoformer/triangle.py (deleted, 100644 → 0)

import math

import torch
import torch.nn as nn
from torch.nn import LayerNorm

from .kernel import bias_dropout_add, bias_ele_dropout_residual
from .ops import Linear, SelfAttention, Transition


def permute_final_dims(tensor, inds):
    zero_index = -1 * len(inds)
    first_inds = list(range(len(tensor.shape[:zero_index])))
    return tensor.permute(first_inds + [zero_index + i for i in inds])


class TriangleMultiplicationOutgoing(nn.Module):

    def __init__(self, d_pair, p_drop, c=128):
        super(TriangleMultiplicationOutgoing, self).__init__()
        self.d_pair = d_pair
        self.c = c

        self.layernorm1 = LayerNorm(d_pair)
        self.left_projection = Linear(d_pair, c)
        self.right_projection = Linear(d_pair, c)
        self.left_gate = Linear(d_pair, c, initializer='zeros', bias_init=1.)
        self.right_gate = Linear(d_pair, c, initializer='zeros', bias_init=1.)

        self.output_gate = Linear(d_pair, d_pair, initializer='zeros', bias_init=1.)
        self.layernorm2 = LayerNorm(c)
        self.output_projection = Linear(d_pair, d_pair, initializer='zeros', use_bias=False)
        self.output_bias = nn.parameter.Parameter(data=torch.zeros((d_pair,)), requires_grad=True)

        self.p_drop = p_drop

    def forward(self, Z_raw):
        Z = self.layernorm1(Z_raw)
        left_proj_act = self.left_projection(Z)
        right_proj_act = self.right_projection(Z)

        left_proj_act = left_proj_act * torch.sigmoid(self.left_gate(Z))
        right_proj_act = right_proj_act * torch.sigmoid(self.right_gate(Z))

        g = torch.sigmoid(self.output_gate(Z))
        # p = torch.matmul(
        #     permute_final_dims(left_proj_act, (2, 0, 1)),
        #     permute_final_dims(right_proj_act, (2, 1, 0)),
        # )
        # ab = permute_final_dims(p, (1, 2, 0))

        ab = torch.einsum('bikd,bjkd->bijd', left_proj_act, right_proj_act)
        ab = self.output_projection(self.layernorm2(ab))
        dropout_mask = torch.ones_like(Z[:, 0:1, :, :]).to(Z.device).to(Z.dtype)
        return bias_ele_dropout_residual(ab, self.output_bias, g, dropout_mask, Z_raw, prob=self.p_drop)


class TriangleMultiplicationIncoming(nn.Module):

    def __init__(self, d_pair, p_drop, c=128):
        super(TriangleMultiplicationIncoming, self).__init__()
        self.d_pair = d_pair
        self.c = c

        self.layernorm1 = LayerNorm(d_pair)
        self.left_projection = Linear(d_pair, c)
        self.right_projection = Linear(d_pair, c)
        self.left_gate = Linear(d_pair, c, initializer='zeros', bias_init=1.)
        self.right_gate = Linear(d_pair, c, initializer='zeros', bias_init=1.)

        self.output_gate = Linear(d_pair, d_pair, initializer='zeros', bias_init=1.)
        self.layernorm2 = LayerNorm(c)
        self.output_projection = Linear(d_pair, d_pair, initializer='zeros', use_bias=False)
        self.output_bias = nn.parameter.Parameter(data=torch.zeros((d_pair,)), requires_grad=True)

        self.p_drop = p_drop

    def forward(self, Z_raw):
        Z = self.layernorm1(Z_raw)
        left_proj_act = self.left_projection(Z)
        right_proj_act = self.right_projection(Z)

        left_proj_act = left_proj_act * torch.sigmoid(self.left_gate(Z))
        right_proj_act = right_proj_act * torch.sigmoid(self.right_gate(Z))

        g = torch.sigmoid(self.output_gate(Z))
        # p = torch.matmul(
        #     permute_final_dims(left_proj_act, (2, 1, 0)),
        #     permute_final_dims(right_proj_act, (2, 0, 1)),
        # )
        # ab = permute_final_dims(p, (1, 2, 0))

        ab = torch.einsum('bkid,bkjd->bijd', left_proj_act, right_proj_act)
        ab = self.output_projection(self.layernorm2(ab))
        dropout_mask = torch.ones_like(Z[:, 0:1, :, :]).to(Z.device).to(Z.dtype)
        return bias_ele_dropout_residual(ab, self.output_bias, g, dropout_mask, Z_raw, prob=self.p_drop)


class TriangleAttentionStartingNode(nn.Module):

    def __init__(self, d_pair, p_drop, c=32, n_head=4):
        super(TriangleAttentionStartingNode, self).__init__()
        self.d_pair = d_pair
        self.c = c
        self.n_head = n_head
        self.p_drop = p_drop

        self.layernorm1 = LayerNorm(d_pair)
        _init_weights = torch.nn.init.normal_(torch.zeros([d_pair, n_head]), std=1.0 / math.sqrt(d_pair))
        self.linear_b_weights = nn.parameter.Parameter(data=_init_weights)
        self.attention = SelfAttention(qkv_dim=d_pair, c=c, n_head=n_head, out_dim=d_pair,
                                       gating=True, last_bias_fuse=True)
        self.out_bias = nn.parameter.Parameter(data=torch.zeros((d_pair,)), requires_grad=True)

    def forward(self, Z_raw):
        Z = self.layernorm1(Z_raw)
        b = torch.einsum('bqkc,ch->bhqk', Z, self.linear_b_weights)
        Z = self.attention(Z, b)
        dropout_mask = torch.ones_like(Z[:, 0:1, :, :]).to(Z.device).to(Z.dtype)
        return bias_dropout_add(Z, self.out_bias, dropout_mask, Z_raw, prob=self.p_drop)


class TriangleAttentionEndingNode(nn.Module):

    def __init__(self, d_pair, p_drop, c=32, n_head=4):
        super(TriangleAttentionEndingNode, self).__init__()
        self.d_pair = d_pair
        self.c = c
        self.n_head = n_head
        self.p_drop = p_drop

        self.layernorm1 = LayerNorm(d_pair)
        _init_weights = torch.nn.init.normal_(torch.zeros([d_pair, n_head]), std=1.0 / math.sqrt(d_pair))
        self.linear_b_weights = nn.parameter.Parameter(data=_init_weights)
        self.attention = SelfAttention(qkv_dim=d_pair, c=c, n_head=n_head, out_dim=d_pair,
                                       gating=True, last_bias_fuse=True)
        self.out_bias = nn.parameter.Parameter(data=torch.zeros((d_pair,)), requires_grad=True)

    def forward(self, Z_raw):
        Z = Z_raw.transpose(-2, -3)
        Z = self.layernorm1(Z)
        b = torch.einsum('bqkc,ch->bhqk', Z, self.linear_b_weights)
        Z = self.attention(Z, b)
        Z = Z.transpose(-2, -3)
        dropout_mask = torch.ones_like(Z[:, :, 0:1, :]).to(Z.device).to(Z.dtype)
        return bias_dropout_add(Z, self.out_bias, dropout_mask, Z_raw, prob=self.p_drop)


class PairStack(nn.Module):

    def __init__(self, d_pair, p_drop=0.25):
        super(PairStack, self).__init__()
        self.TriangleMultiplicationOutgoing = TriangleMultiplicationOutgoing(d_pair, p_drop=p_drop)
        self.TriangleMultiplicationIncoming = TriangleMultiplicationIncoming(d_pair, p_drop=p_drop)
        self.TriangleAttentionStartingNode = TriangleAttentionStartingNode(d_pair, p_drop=p_drop)
        self.TriangleAttentionEndingNode = TriangleAttentionEndingNode(d_pair, p_drop=p_drop)
        self.PairTransition = Transition(d=d_pair)

    def forward(self, pair):
        pair = self.TriangleMultiplicationOutgoing(pair)
        pair = self.TriangleMultiplicationIncoming(pair)
        pair = self.TriangleAttentionStartingNode(pair)
        pair = self.TriangleAttentionEndingNode(pair)
        pair = self.PairTransition(pair)
        return pair
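Not part of the commit: a minimal, self-contained check that the einsum used in TriangleMultiplicationOutgoing.forward is equivalent to the matmul form kept in the comments above it. Tensor sizes are illustrative assumptions.

import torch

def permute_final_dims(tensor, inds):
    zero_index = -1 * len(inds)
    first_inds = list(range(len(tensor.shape[:zero_index])))
    return tensor.permute(first_inds + [zero_index + i for i in inds])

b, i, j, k, d = 2, 5, 5, 5, 7
left = torch.randn(b, i, k, d)
right = torch.randn(b, j, k, d)

# einsum form used in the forward pass
ab_einsum = torch.einsum('bikd,bjkd->bijd', left, right)

# matmul form from the commented-out lines
p = torch.matmul(permute_final_dims(left, (2, 0, 1)), permute_final_dims(right, (2, 1, 0)))
ab_matmul = permute_final_dims(p, (1, 2, 0))

assert torch.allclose(ab_einsum, ab_matmul, atol=1e-5)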
tests/test_autochunk/openfold/checkpointing.py (deleted, 100644 → 0)

# Copyright 2021 AlQuraishi Laboratory
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.utils.checkpoint
from typing import Any, Tuple, List, Callable, Optional

BLOCK_ARG = Any
BLOCK_ARGS = List[BLOCK_ARG]


def get_checkpoint_fn():
    checkpoint = torch.utils.checkpoint.checkpoint
    return checkpoint


@torch.jit.ignore
def checkpoint_blocks(
    blocks: List[Callable],
    args: BLOCK_ARGS,
    blocks_per_ckpt: Optional[int],
) -> BLOCK_ARGS:
    """
    Chunk a list of blocks and run each chunk with activation
    checkpointing. We define a "block" as a callable whose only inputs are
    the outputs of the previous block.
    Implements Subsection 1.11.8
    Args:
        blocks:
            List of blocks
        args:
            Tuple of arguments for the first block.
        blocks_per_ckpt:
            Size of each chunk. A higher value corresponds to fewer
            checkpoints, and trades memory for speed. If None, no checkpointing
            is performed.
    Returns:
        The output of the final block
    """

    def wrap(a):
        return (a,) if type(a) is not tuple else a

    def exec(b, a):
        for block in b:
            a = wrap(block(*a))
        return a

    def chunker(s, e):

        def exec_sliced(*a):
            return exec(blocks[s:e], a)

        return exec_sliced

    # Avoids mishaps when the blocks take just one argument
    args = wrap(args)

    if blocks_per_ckpt is None:
        return exec(blocks, args)
    elif blocks_per_ckpt < 1 or blocks_per_ckpt > len(blocks):
        raise ValueError("blocks_per_ckpt must be between 1 and len(blocks)")

    checkpoint = get_checkpoint_fn()

    for s in range(0, len(blocks), blocks_per_ckpt):
        e = s + blocks_per_ckpt
        args = checkpoint(chunker(s, e), *args)
        args = wrap(args)

    return args
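Not part of the commit: a minimal usage sketch of chunked activation checkpointing, assuming checkpoint_blocks is in scope (defined as in the file above). The block sizes and module choice are illustrative assumptions.

import torch
import torch.nn as nn

# with get_checkpoint_fn / checkpoint_blocks defined as above
blocks = [nn.Linear(16, 16) for _ in range(4)]
x = torch.randn(8, 16, requires_grad=True)

# Run four blocks, recomputing activations two blocks at a time during the
# backward pass instead of storing every intermediate.
out = checkpoint_blocks(blocks, args=(x,), blocks_per_ckpt=2)[0]
out.sum().backward()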
tests/test_autochunk/openfold/dropout.py (deleted, 100644 → 0)

# Copyright 2021 AlQuraishi Laboratory
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn
from functools import partialmethod
from typing import Union, List


class Dropout(nn.Module):
    """
    Implementation of dropout with the ability to share the dropout mask
    along a particular dimension.
    If not in training mode, this module computes the identity function.
    """

    def __init__(self, r: float, batch_dim: Union[int, List[int]]):
        """
        Args:
            r:
                Dropout rate
            batch_dim:
                Dimension(s) along which the dropout mask is shared
        """
        super(Dropout, self).__init__()

        self.r = r
        if type(batch_dim) == int:
            batch_dim = [batch_dim]
        self.batch_dim = batch_dim
        self.dropout = nn.Dropout(self.r)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x:
                Tensor to which dropout is applied. Can have any shape
                compatible with self.batch_dim
        """
        shape = list(x.shape)
        if self.batch_dim is not None:
            for bd in self.batch_dim:
                shape[bd] = 1
        mask = x.new_ones(shape)
        mask = self.dropout(mask)
        x *= mask
        return x


class DropoutRowwise(Dropout):
    """
    Convenience class for rowwise dropout as described in subsection
    1.11.6.
    """

    __init__ = partialmethod(Dropout.__init__, batch_dim=-3)


class DropoutColumnwise(Dropout):
    """
    Convenience class for columnwise dropout as described in subsection
    1.11.6.
    """

    __init__ = partialmethod(Dropout.__init__, batch_dim=-2)
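Not part of the commit: a self-contained sketch of the shared-mask dropout idea, reproducing what Dropout.forward does for the rowwise case. Shapes and the dropout rate are illustrative assumptions.

import torch
import torch.nn as nn

x = torch.ones(1, 4, 6, 8)      # [*, N_seq, N_res, C]
drop = nn.Dropout(0.5)          # module is in training mode by default

shape = list(x.shape)
shape[-3] = 1                   # DropoutRowwise: share the mask across N_seq
mask = drop(x.new_ones(shape))
out = x * mask

# Every N_seq slice sees the same mask, so entries are kept (scaled by
# 1 / (1 - p)) or dropped together across the shared dimension.
assert (out == out[:, 0:1]).all()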
tests/test_autochunk/openfold/evoformer.py (deleted, 100644 → 0)

# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import torch
import torch.nn as nn
from typing import Tuple, Optional
from functools import partial

from .primitives import Linear, LayerNorm
from .dropout import DropoutRowwise, DropoutColumnwise
from .msa import (
    MSARowAttentionWithPairBias,
    MSAColumnAttention,
    MSAColumnGlobalAttention,
)
from .outer_product_mean import OuterProductMean
from .pair_transition import PairTransition
from .triangular_attention import (
    TriangleAttentionStartingNode,
    TriangleAttentionEndingNode,
)
from .triangular_multiplicative_update import (
    TriangleMultiplicationOutgoing,
    TriangleMultiplicationIncoming,
)
from .checkpointing import checkpoint_blocks, get_checkpoint_fn
from .tensor_utils import chunk_layer


class MSATransition(nn.Module):
    """
    Feed-forward network applied to MSA activations after attention.
    Implements Algorithm 9
    """

    def __init__(self, c_m, n):
        """
        Args:
            c_m:
                MSA channel dimension
            n:
                Factor multiplied to c_m to obtain the hidden channel
                dimension
        """
        super(MSATransition, self).__init__()

        self.c_m = c_m
        self.n = n

        self.layer_norm = LayerNorm(self.c_m)
        self.linear_1 = Linear(self.c_m, self.n * self.c_m, init="relu")
        self.relu = nn.ReLU()
        self.linear_2 = Linear(self.n * self.c_m, self.c_m, init="final")

    def _transition(self, m, mask):
        m = self.linear_1(m)
        m = self.relu(m)
        m = self.linear_2(m) * mask
        return m

    @torch.jit.ignore
    def _chunk(
        self,
        m: torch.Tensor,
        mask: torch.Tensor,
        chunk_size: int,
    ) -> torch.Tensor:
        return chunk_layer(
            self._transition,
            {"m": m, "mask": mask},
            chunk_size=chunk_size,
            no_batch_dims=len(m.shape[:-2]),
        )

    def forward(
        self,
        m: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        chunk_size: Optional[int] = None,
    ) -> torch.Tensor:
        """
        Args:
            m:
                [*, N_seq, N_res, C_m] MSA activation
            mask:
                [*, N_seq, N_res, C_m] MSA mask
        Returns:
            m:
                [*, N_seq, N_res, C_m] MSA activation update
        """
        # DISCREPANCY: DeepMind forgets to apply the MSA mask here.
        if mask is None:
            mask = m.new_ones(m.shape[:-1])

        # [*, N_seq, N_res, 1]
        mask = mask.unsqueeze(-1)

        m = self.layer_norm(m)

        if chunk_size is not None:
            m = self._chunk(m, mask, chunk_size)
        else:
            m = self._transition(m, mask)

        return m


class EvoformerBlockCore(nn.Module):

    def __init__(
        self,
        c_m: int,
        c_z: int,
        c_hidden_opm: int,
        c_hidden_mul: int,
        c_hidden_pair_att: int,
        no_heads_msa: int,
        no_heads_pair: int,
        transition_n: int,
        pair_dropout: float,
        inf: float,
        eps: float,
        _is_extra_msa_stack: bool = False,
        is_multimer: bool = False,
    ):
        super(EvoformerBlockCore, self).__init__()
        self.is_multimer = is_multimer

        self.msa_transition = MSATransition(c_m=c_m, n=transition_n)
        self.outer_product_mean = OuterProductMean(c_m, c_z, c_hidden_opm)

        self.tri_mul_out = TriangleMultiplicationOutgoing(c_z, c_hidden_mul)
        self.tri_mul_in = TriangleMultiplicationIncoming(c_z, c_hidden_mul)

        self.tri_att_start = TriangleAttentionStartingNode(c_z, c_hidden_pair_att, no_heads_pair, inf=inf)
        self.tri_att_end = TriangleAttentionEndingNode(c_z, c_hidden_pair_att, no_heads_pair, inf=inf)

        self.pair_transition = PairTransition(c_z, transition_n)

        self.ps_dropout_row_layer = DropoutRowwise(pair_dropout)
        self.ps_dropout_col_layer = DropoutColumnwise(pair_dropout)

    def forward(
        self,
        m: torch.Tensor,
        z: torch.Tensor,
        chunk_size: Optional[int] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # DeepMind doesn't mask these transitions in the source, so _mask_trans
        # should be disabled to better approximate the exact activations of
        # the original.
        m = m + self.msa_transition(m, chunk_size=chunk_size)
        z = z + self.outer_product_mean(m, chunk_size=chunk_size)
        z = z + self.ps_dropout_row_layer(self.tri_mul_out(z))
        z = z + self.ps_dropout_row_layer(self.tri_mul_in(z))
        z = z + self.ps_dropout_row_layer(self.tri_att_start(z, chunk_size=chunk_size))
        z = z + self.ps_dropout_col_layer(self.tri_att_end(z, chunk_size=chunk_size))
        z = z + self.pair_transition(z, chunk_size=chunk_size)

        return m, z


class EvoformerBlock(nn.Module):

    def __init__(
        self,
        c_m: int,
        c_z: int,
        c_hidden_msa_att: int,
        c_hidden_opm: int,
        c_hidden_mul: int,
        c_hidden_pair_att: int,
        no_heads_msa: int,
        no_heads_pair: int,
        transition_n: int,
        msa_dropout: float,
        pair_dropout: float,
        inf: float,
        eps: float,
        is_multimer: bool,
    ):
        super(EvoformerBlock, self).__init__()

        self.msa_att_row = MSARowAttentionWithPairBias(
            c_m=c_m,
            c_z=c_z,
            c_hidden=c_hidden_msa_att,
            no_heads=no_heads_msa,
            inf=inf,
        )

        self.msa_att_col = MSAColumnAttention(c_m, c_hidden_msa_att, no_heads_msa, inf=inf)

        self.msa_dropout_layer = DropoutRowwise(msa_dropout)

        self.core = EvoformerBlockCore(
            c_m=c_m,
            c_z=c_z,
            c_hidden_opm=c_hidden_opm,
            c_hidden_mul=c_hidden_mul,
            c_hidden_pair_att=c_hidden_pair_att,
            no_heads_msa=no_heads_msa,
            no_heads_pair=no_heads_pair,
            transition_n=transition_n,
            pair_dropout=pair_dropout,
            inf=inf,
            eps=eps,
        )

        self.outer_product_mean = OuterProductMean(c_m, c_z, c_hidden_opm)
        self.is_multimer = is_multimer

    def forward(
        self,
        m: torch.Tensor,
        z: torch.Tensor,
        chunk_size: Optional[int] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        m = m + self.msa_dropout_layer(self.msa_att_row(m, z=z, chunk_size=chunk_size))
        m = m + self.msa_att_col(m, chunk_size=chunk_size)
        m, z = self.core(m, z, chunk_size=chunk_size)

        return m, z


class EvoformerStack(nn.Module):
    """
    Main Evoformer trunk.
    Implements Algorithm 6.
    """

    def __init__(
        self,
        c_m: int,
        c_z: int,
        c_hidden_msa_att: int,
        c_hidden_opm: int,
        c_hidden_mul: int,
        c_hidden_pair_att: int,
        c_s: int,
        no_heads_msa: int,
        no_heads_pair: int,
        no_blocks: int,
        transition_n: int,
        msa_dropout: float,
        pair_dropout: float,
        blocks_per_ckpt: int,
        inf: float,
        eps: float,
        clear_cache_between_blocks: bool = False,
        is_multimer: bool = False,
        **kwargs,
    ):
        """
        Args:
            c_m:
                MSA channel dimension
            c_z:
                Pair channel dimension
            c_hidden_msa_att:
                Hidden dimension in MSA attention
            c_hidden_opm:
                Hidden dimension in outer product mean module
            c_hidden_mul:
                Hidden dimension in multiplicative updates
            c_hidden_pair_att:
                Hidden dimension in triangular attention
            c_s:
                Channel dimension of the output "single" embedding
            no_heads_msa:
                Number of heads used for MSA attention
            no_heads_pair:
                Number of heads used for pair attention
            no_blocks:
                Number of Evoformer blocks in the stack
            transition_n:
                Factor by which to multiply c_m to obtain the MSATransition
                hidden dimension
            msa_dropout:
                Dropout rate for MSA activations
            pair_dropout:
                Dropout used for pair activations
            blocks_per_ckpt:
                Number of Evoformer blocks in each activation checkpoint
            clear_cache_between_blocks:
                Whether to clear CUDA's GPU memory cache between blocks of the
                stack. Slows down each block but can reduce fragmentation
        """
        super(EvoformerStack, self).__init__()

        self.blocks_per_ckpt = blocks_per_ckpt
        self.clear_cache_between_blocks = clear_cache_between_blocks

        self.blocks = nn.ModuleList()

        for _ in range(no_blocks):
            block = EvoformerBlock(
                c_m=c_m,
                c_z=c_z,
                c_hidden_msa_att=c_hidden_msa_att,
                c_hidden_opm=c_hidden_opm,
                c_hidden_mul=c_hidden_mul,
                c_hidden_pair_att=c_hidden_pair_att,
                no_heads_msa=no_heads_msa,
                no_heads_pair=no_heads_pair,
                transition_n=transition_n,
                msa_dropout=msa_dropout,
                pair_dropout=pair_dropout,
                inf=inf,
                eps=eps,
                is_multimer=is_multimer,
            )
            self.blocks.append(block)

        self.linear = Linear(c_m, c_s)

    def forward(
        self,
        m: torch.Tensor,
        z: torch.Tensor,
        msa_mask: torch.Tensor,
        pair_mask: torch.Tensor,
        chunk_size: int,
        _mask_trans: bool = True,
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        """
        Args:
            m:
                [*, N_seq, N_res, C_m] MSA embedding
            z:
                [*, N_res, N_res, C_z] pair embedding
            msa_mask:
                [*, N_seq, N_res] MSA mask
            pair_mask:
                [*, N_res, N_res] pair mask
        Returns:
            m:
                [*, N_seq, N_res, C_m] MSA embedding
            z:
                [*, N_res, N_res, C_z] pair embedding
            s:
                [*, N_res, C_s] single embedding (or None if extra MSA stack)
        """
        blocks = [
            partial(
                b,
                msa_mask=msa_mask,
                pair_mask=pair_mask,
                chunk_size=chunk_size,
                _mask_trans=_mask_trans,
            ) for b in self.blocks
        ]

        if (self.clear_cache_between_blocks):

            def block_with_cache_clear(block, *args):
                torch.cuda.empty_cache()
                return block(*args)

            blocks = [partial(block_with_cache_clear, b) for b in blocks]

        m, z = checkpoint_blocks(
            blocks,
            args=(m, z),
            blocks_per_ckpt=self.blocks_per_ckpt if self.training else None,
        )

        s = self.linear(m[..., 0, :, :])

        return m, z, s
tests/test_autochunk/openfold/msa.py (deleted, 100644 → 0)

# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import torch
import torch.nn as nn
from typing import Optional, List, Tuple

from .primitives import (
    Linear,
    LayerNorm,
    Attention,
    GlobalAttention,
    _attention_chunked_trainable,
)
from .checkpointing import get_checkpoint_fn
from .tensor_utils import (
    chunk_layer,
    permute_final_dims,
    flatten_final_dims,
)


class MSAAttention(nn.Module):

    def __init__(
        self,
        c_in,
        c_hidden,
        no_heads,
        pair_bias=False,
        c_z=None,
        inf=1e9,
    ):
        """
        Args:
            c_in:
                Input channel dimension
            c_hidden:
                Per-head hidden channel dimension
            no_heads:
                Number of attention heads
            pair_bias:
                Whether to use pair embedding bias
            c_z:
                Pair embedding channel dimension. Ignored unless pair_bias
                is true
            inf:
                A large number to be used in computing the attention mask
        """
        super(MSAAttention, self).__init__()

        self.c_in = c_in
        self.c_hidden = c_hidden
        self.no_heads = no_heads
        self.pair_bias = pair_bias
        self.c_z = c_z
        self.inf = inf

        self.layer_norm_m = LayerNorm(self.c_in)

        self.layer_norm_z = None
        self.linear_z = None
        if self.pair_bias:
            self.layer_norm_z = LayerNorm(self.c_z)
            self.linear_z = Linear(self.c_z, self.no_heads, bias=False, init="normal")

        self.mha = Attention(self.c_in, self.c_in, self.c_in, self.c_hidden, self.no_heads)

    @torch.jit.ignore
    def _chunk(
        self,
        m: torch.Tensor,
        biases: List[torch.Tensor],
        chunk_size: int,
    ) -> torch.Tensor:
        return chunk_layer(
            self.mha,
            {"q_x": m, "kv_x": m, "biases": biases},
            chunk_size=chunk_size,
            no_batch_dims=len(m.shape[:-2]),
        )

    def _prep_inputs(self, m: torch.Tensor, z: Optional[torch.Tensor],
                     mask: Optional[torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        # [*, N_seq, N_res, C_m]
        m = self.layer_norm_m(m)

        n_seq, n_res = m.shape[-3:-1]
        if mask is None:
            # [*, N_seq, N_res]
            mask = m.new_ones(m.shape[:-3] + (n_seq, n_res),)

        # [*, N_seq, 1, 1, N_res]
        mask_bias = (self.inf * (mask - 1))[..., :, None, None, :]

        # This step simply returns a larger view of the bias, and does not
        # consume additional memory.
        # [*, N_seq, no_heads, N_res, N_res]
        #bias = bias.expand(
        #    ((-1,) * len(bias.shape[:-4])) + (-1, self.no_heads, n_res, -1)
        #)

        if (self.pair_bias and
                z is not None and                   # For the
                self.layer_norm_z is not None and   # benefit of
                self.linear_z is not None           # TorchScript
           ):
            # [*, N_res, N_res, C_z]
            z = self.layer_norm_z(z)

            # [*, N_res, N_res, no_heads]
            z = self.linear_z(z)

            # [*, 1, no_heads, N_res, N_res]
            z = permute_final_dims(z, (2, 0, 1)).unsqueeze(-4)

        return m, mask_bias, z

    def forward(
        self,
        m: torch.Tensor,
        z: Optional[torch.Tensor] = None,
        mask: Optional[torch.Tensor] = None,
        chunk_size: Optional[int] = None,
        _chunk_logits: Optional[int] = None,
        _checkpoint_chunks: Optional[bool] = None,
    ) -> torch.Tensor:
        """
        Args:
            m:
                [*, N_seq, N_res, C_m] MSA embedding
            z:
                [*, N_res, N_res, C_z] pair embedding. Required only if
                pair_bias is True
            mask:
                [*, N_seq, N_res] MSA mask
            chunk_size:
                Size of chunks into which the inputs are split along their
                batch dimensions. A low value decreases memory overhead at the
                cost of slower execution. Chunking is not performed by default.
        """
        m, mask_bias, z = self._prep_inputs(m, z, mask)

        biases = [mask_bias]
        if (z is not None):
            biases.append(z)

        if chunk_size is not None:
            m = self._chunk(m, biases, chunk_size)
        else:
            m = self.mha(q_x=m, kv_x=m, biases=biases)

        return m


class MSARowAttentionWithPairBias(MSAAttention):
    """
    Implements Algorithm 7.
    """

    def __init__(self, c_m, c_z, c_hidden, no_heads, inf=1e9):
        """
        Args:
            c_m:
                Input channel dimension
            c_z:
                Pair embedding channel dimension
            c_hidden:
                Per-head hidden channel dimension
            no_heads:
                Number of attention heads
            inf:
                Large number used to construct attention masks
        """
        super(MSARowAttentionWithPairBias, self).__init__(
            c_m,
            c_hidden,
            no_heads,
            pair_bias=True,
            c_z=c_z,
            inf=inf,
        )


class MSAColumnAttention(nn.Module):
    """
    Implements Algorithm 8.
    By rights, this should also be a subclass of MSAAttention. Alas,
    most inheritance isn't supported by TorchScript.
    """

    def __init__(self, c_m, c_hidden, no_heads, inf=1e9):
        """
        Args:
            c_m:
                MSA channel dimension
            c_hidden:
                Per-head hidden channel dimension
            no_heads:
                Number of attention heads
            inf:
                Large number used to construct attention masks
        """
        super(MSAColumnAttention, self).__init__()

        self.c_m = c_m
        self.c_hidden = c_hidden
        self.no_heads = no_heads
        self.inf = inf

        self._msa_att = MSAAttention(
            c_in=c_m,
            c_hidden=c_hidden,
            no_heads=no_heads,
            pair_bias=False,
            c_z=None,
            inf=inf,
        )

    def forward(self,
                m: torch.Tensor,
                mask: Optional[torch.Tensor] = None,
                chunk_size: Optional[int] = None) -> torch.Tensor:
        """
        Args:
            m:
                [*, N_seq, N_res, C_m] MSA embedding
            mask:
                [*, N_seq, N_res] MSA mask
            chunk_size:
                Size of chunks into which the inputs are split along their
                batch dimensions. A low value decreases memory overhead at the
                cost of slower execution. Chunking is not performed by default.
        """
        # [*, N_res, N_seq, C_in]
        m = m.transpose(-2, -3)

        m = self._msa_att(m, chunk_size=chunk_size)

        # [*, N_seq, N_res, C_in]
        m = m.transpose(-2, -3)

        return m


class MSAColumnGlobalAttention(nn.Module):

    def __init__(
        self,
        c_in,
        c_hidden,
        no_heads,
        inf=1e9,
        eps=1e-10,
    ):
        super(MSAColumnGlobalAttention, self).__init__()

        self.c_in = c_in
        self.c_hidden = c_hidden
        self.no_heads = no_heads
        self.inf = inf
        self.eps = eps

        self.layer_norm_m = nn.LayerNorm(c_in)

        self.global_attention = GlobalAttention(
            c_in=c_in,
            c_hidden=c_hidden,
            no_heads=no_heads,
            inf=inf,
            eps=eps,
        )

    @torch.jit.ignore
    def _chunk(
        self,
        m: torch.Tensor,
        chunk_size: int,
    ) -> torch.Tensor:
        mha_input = {
            "m": m,
        }
        return chunk_layer(
            self.global_attention,
            mha_input,
            chunk_size=chunk_size,
            no_batch_dims=len(m.shape[:-2]),
        )

    def forward(
        self,
        m: torch.Tensor,
        chunk_size: Optional[int] = None,
    ) -> torch.Tensor:
        n_seq, n_res, c_in = m.shape[-3:]

        # [*, N_res, N_seq, C_in]
        m = m.transpose(-2, -3)

        # [*, N_res, N_seq, C_in]
        m = self.layer_norm_m(m)

        if chunk_size is not None:
            m = self._chunk(m, chunk_size)
        else:
            m = self.global_attention(m=m)

        # [*, N_seq, N_res, C_in]
        m = m.transpose(-2, -3)

        return m
tests/test_autochunk/openfold/outer_product_mean.py (deleted, 100644 → 0)

# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
from typing import Optional

import torch
import torch.nn as nn

from .primitives import Linear
from .tensor_utils import chunk_layer


class OuterProductMean(nn.Module):
    """
    Implements Algorithm 10.
    """

    def __init__(self, c_m, c_z, c_hidden, eps=1e-3):
        """
        Args:
            c_m:
                MSA embedding channel dimension
            c_z:
                Pair embedding channel dimension
            c_hidden:
                Hidden channel dimension
        """
        super(OuterProductMean, self).__init__()

        self.c_m = c_m
        self.c_z = c_z
        self.c_hidden = c_hidden
        self.eps = eps

        self.layer_norm = nn.LayerNorm(c_m)
        self.linear_1 = Linear(c_m, c_hidden)
        self.linear_2 = Linear(c_m, c_hidden)
        self.linear_out = Linear(c_hidden**2, c_z, init="final")

    def _opm(self, a, b):
        # [*, N_res, N_res, C, C]
        outer = torch.einsum("...bac,...dae->...bdce", a, b)

        # [*, N_res, N_res, C * C]
        outer = outer.reshape(outer.shape[:-2] + (-1,))

        # [*, N_res, N_res, C_z]
        outer = self.linear_out(outer)

        return outer

    @torch.jit.ignore
    def _chunk(self, a: torch.Tensor, b: torch.Tensor, chunk_size: int) -> torch.Tensor:
        # Since the "batch dim" in this case is not a true batch dimension
        # (in that the shape of the output depends on it), we need to
        # iterate over it ourselves
        a_reshape = a.reshape((-1,) + a.shape[-3:])
        b_reshape = b.reshape((-1,) + b.shape[-3:])
        out = []
        for a_prime, b_prime in zip(a_reshape, b_reshape):
            outer = chunk_layer(
                partial(self._opm, b=b_prime),
                {"a": a_prime},
                chunk_size=chunk_size,
                no_batch_dims=1,
            )
            out.append(outer)
        outer = torch.stack(out, dim=0)
        outer = outer.reshape(a.shape[:-3] + outer.shape[1:])

        return outer

    def forward(self,
                m: torch.Tensor,
                mask: Optional[torch.Tensor] = None,
                chunk_size: Optional[int] = None) -> torch.Tensor:
        """
        Args:
            m:
                [*, N_seq, N_res, C_m] MSA embedding
            mask:
                [*, N_seq, N_res] MSA mask
        Returns:
            [*, N_res, N_res, C_z] pair embedding update
        """
        if mask is None:
            mask = m.new_ones(m.shape[:-1])

        # [*, N_seq, N_res, C_m]
        m = self.layer_norm(m)

        # [*, N_seq, N_res, C]
        mask = mask.unsqueeze(-1)
        a = self.linear_1(m) * mask
        b = self.linear_2(m) * mask

        a = a.transpose(-2, -3)
        b = b.transpose(-2, -3)

        if chunk_size is not None:
            outer = self._chunk(a, b, chunk_size)
        else:
            outer = self._opm(a, b)

        # [*, N_res, N_res, 1]
        norm = torch.einsum("...abc,...adc->...bdc", mask, mask)

        # [*, N_res, N_res, C_z]
        outer = outer / (self.eps + norm)

        return outer
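Not part of the commit: a self-contained shape check of the outer-product einsum used in _opm above. Dimension sizes are illustrative assumptions.

import torch

n_seq, n_res, c = 3, 5, 4
a = torch.randn(n_res, n_seq, c)    # after the transpose in forward: [*, N_res, N_seq, C]
b = torch.randn(n_res, n_seq, c)

# Outer product over the hidden channels, contracted over the sequence axis.
outer = torch.einsum("...bac,...dae->...bdce", a, b)
assert outer.shape == (n_res, n_res, c, c)

# Flattened to [N_res, N_res, C * C] before linear_out, as in _opm.
flat = outer.reshape(outer.shape[:-2] + (-1,))
assert flat.shape == (n_res, n_res, c * c)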
tests/test_autochunk/openfold/pair_transition.py (deleted, 100644 → 0)

# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional

import torch
import torch.nn as nn

from .primitives import Linear, LayerNorm
from .tensor_utils import chunk_layer


class PairTransition(nn.Module):
    """
    Implements Algorithm 15.
    """

    def __init__(self, c_z, n):
        """
        Args:
            c_z:
                Pair transition channel dimension
            n:
                Factor by which c_z is multiplied to obtain hidden channel
                dimension
        """
        super(PairTransition, self).__init__()

        self.c_z = c_z
        self.n = n

        self.layer_norm = LayerNorm(self.c_z)
        self.linear_1 = Linear(self.c_z, self.n * self.c_z, init="relu")
        self.relu = nn.ReLU()
        self.linear_2 = Linear(self.n * self.c_z, c_z, init="final")

    def _transition(self, z, mask):
        # [*, N_res, N_res, C_hidden]
        z = self.linear_1(z)
        z = self.relu(z)

        # [*, N_res, N_res, C_z]
        z = self.linear_2(z) * mask

        return z

    @torch.jit.ignore
    def _chunk(
        self,
        z: torch.Tensor,
        mask: torch.Tensor,
        chunk_size: int,
    ) -> torch.Tensor:
        return chunk_layer(
            self._transition,
            {"z": z, "mask": mask},
            chunk_size=chunk_size,
            no_batch_dims=len(z.shape[:-2]),
        )

    def forward(
        self,
        z: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        chunk_size: Optional[int] = None,
    ) -> torch.Tensor:
        """
        Args:
            z:
                [*, N_res, N_res, C_z] pair embedding
        Returns:
            [*, N_res, N_res, C_z] pair embedding update
        """
        # DISCREPANCY: DeepMind forgets to apply the mask in this module.
        if mask is None:
            mask = z.new_ones(z.shape[:-1])

        # [*, N_res, N_res, 1]
        mask = mask.unsqueeze(-1)

        # [*, N_res, N_res, C_z]
        z = self.layer_norm(z)

        if chunk_size is not None:
            z = self._chunk(z, mask, chunk_size)
        else:
            z = self._transition(z=z, mask=mask)

        return z
tests/test_autochunk/openfold/primitives.py (deleted, 100644 → 0)

# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
import math
from typing import Optional, Callable, List, Tuple, Sequence
import numpy as np
import torch
import torch.nn as nn

from .checkpointing import get_checkpoint_fn
from .tensor_utils import (
    permute_final_dims,
    flatten_final_dims,
    _chunk_slice,
)


def _prod(nums):
    out = 1
    for n in nums:
        out = out * n
    return out


def _calculate_fan(linear_weight_shape, fan="fan_in"):
    fan_out, fan_in = linear_weight_shape

    if fan == "fan_in":
        f = fan_in
    elif fan == "fan_out":
        f = fan_out
    elif fan == "fan_avg":
        f = (fan_in + fan_out) / 2
    else:
        raise ValueError("Invalid fan option")

    return f


def glorot_uniform_init_(weights):
    nn.init.xavier_uniform_(weights, gain=1)


def final_init_(weights):
    with torch.no_grad():
        weights.fill_(0.0)


def gating_init_(weights):
    with torch.no_grad():
        weights.fill_(0.0)


def normal_init_(weights):
    torch.nn.init.kaiming_normal_(weights, nonlinearity="linear")


def ipa_point_weights_init_(weights):
    with torch.no_grad():
        softplus_inverse_1 = 0.541324854612918
        weights.fill_(softplus_inverse_1)


class Linear(nn.Linear):
    """
    A Linear layer with built-in nonstandard initializations. Called just
    like torch.nn.Linear.
    Implements the initializers in 1.11.4, plus some additional ones found
    in the code.
    """

    def __init__(
        self,
        in_dim: int,
        out_dim: int,
        bias: bool = True,
        init: str = "default",
        init_fn: Optional[Callable[[torch.Tensor, torch.Tensor], None]] = None,
    ):
        """
        Args:
            in_dim:
                The final dimension of inputs to the layer
            out_dim:
                The final dimension of layer outputs
            bias:
                Whether to learn an additive bias. True by default
            init:
                The initializer to use. Choose from:
                "default": LeCun fan-in truncated normal initialization
                "relu": He initialization w/ truncated normal distribution
                "glorot": Fan-average Glorot uniform initialization
                "gating": Weights=0, Bias=1
                "normal": Normal initialization with std=1/sqrt(fan_in)
                "final": Weights=0, Bias=0
                Overridden by init_fn if the latter is not None.
            init_fn:
                A custom initializer taking weight and bias as inputs.
                Overrides init if not None.
        """
        super(Linear, self).__init__(in_dim, out_dim, bias=bias)

        if bias:
            with torch.no_grad():
                self.bias.fill_(0)

        if init_fn is not None:
            init_fn(self.weight, self.bias)
        else:
            if init == "default":
                normal_init_(self.weight)
            elif init == "relu":
                normal_init_(self.weight)
            elif init == "glorot":
                glorot_uniform_init_(self.weight)
            elif init == "gating":
                gating_init_(self.weight)
                if bias:
                    with torch.no_grad():
                        self.bias.fill_(1.0)
            elif init == "normal":
                normal_init_(self.weight)
            elif init == "final":
                final_init_(self.weight)
            else:
                raise ValueError("Invalid init string.")


class LayerNorm(nn.Module):

    def __init__(self, c_in, eps=1e-5):
        super(LayerNorm, self).__init__()

        self.c_in = (c_in,)
        self.eps = eps

        self.weight = nn.Parameter(torch.ones(c_in))
        self.bias = nn.Parameter(torch.zeros(c_in))

    def forward(self, x):
        out = nn.functional.layer_norm(
            x,
            self.c_in,
            self.weight,
            self.bias,
            self.eps,
        )

        return out


@torch.jit.ignore
def softmax(t: torch.Tensor, dim: int = -1) -> torch.Tensor:
    """
    Softmax, but without automatic casting to fp32 when the input is of
    type bfloat16
    """
    s = torch.nn.functional.softmax(t, dim=dim)

    return s


#@torch.jit.script
def _attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor,
               biases: List[torch.Tensor]) -> torch.Tensor:
    # [*, H, Q, C_hidden]
    query = permute_final_dims(query, (1, 0, 2))

    # [*, H, C_hidden, K]
    key = permute_final_dims(key, (1, 2, 0))

    # [*, H, V, C_hidden]
    value = permute_final_dims(value, (1, 0, 2))

    # [*, H, Q, K]
    a = torch.matmul(query, key)

    for b in biases:
        a += b

    a = softmax(a, -1)

    # [*, H, Q, C_hidden]
    a = torch.matmul(a, value)

    # [*, Q, H, C_hidden]
    a = a.transpose(-2, -3)

    return a


@torch.jit.ignore
def _attention_chunked_trainable(
    query,
    key,
    value,
    biases,
    chunk_size,
    chunk_dim,
    checkpoint,
):
    if (checkpoint and len(biases) > 2):
        raise ValueError("Checkpointed version only permits two bias terms")

    def _checkpointable_attention(q, k, v, b1, b2):
        bs = [b for b in [b1, b2] if b is not None]
        return _attention(q, k, v, bs)

    o_chunks = []
    checkpoint_fn = get_checkpoint_fn()
    count = query.shape[chunk_dim]
    for start in range(0, count, chunk_size):
        end = start + chunk_size
        idx = [slice(None)] * len(query.shape)
        idx[chunk_dim] = slice(start, end)
        idx_tup = tuple(idx)
        q_chunk = query[idx_tup]
        k_chunk = key[idx_tup]
        v_chunk = value[idx_tup]

        def _slice_bias(b):
            idx[chunk_dim] = (slice(start, end) if b.shape[chunk_dim] != 1 else slice(None))
            return b[tuple(idx)]

        if (checkpoint):
            bias_1_chunk, bias_2_chunk = [
                _slice_bias(b) if b is not None else None for b in (biases + [None, None])[:2]
            ]

            o_chunk = checkpoint_fn(_checkpointable_attention, q_chunk, k_chunk, v_chunk, bias_1_chunk, bias_2_chunk)
        else:
            bias_chunks = [_slice_bias(b) for b in biases]

            o_chunk = _attention(q_chunk, k_chunk, v_chunk, bias_chunks)

        o_chunks.append(o_chunk)

    o = torch.cat(o_chunks, dim=chunk_dim)
    return o


class Attention(nn.Module):
    """
    Standard multi-head attention using AlphaFold's default layer
    initialization. Allows multiple bias vectors.
    """

    def __init__(
        self,
        c_q: int,
        c_k: int,
        c_v: int,
        c_hidden: int,
        no_heads: int,
        gating: bool = True,
    ):
        """
        Args:
            c_q:
                Input dimension of query data
            c_k:
                Input dimension of key data
            c_v:
                Input dimension of value data
            c_hidden:
                Per-head hidden dimension
            no_heads:
                Number of attention heads
            gating:
                Whether the output should be gated using query data
        """
        super(Attention, self).__init__()

        self.c_q = c_q
        self.c_k = c_k
        self.c_v = c_v
        self.c_hidden = c_hidden
        self.no_heads = no_heads
        self.gating = gating

        # DISCREPANCY: c_hidden is not the per-head channel dimension, as
        # stated in the supplement, but the overall channel dimension.

        self.linear_q = Linear(self.c_q, self.c_hidden * self.no_heads, bias=False, init="glorot")
        self.linear_k = Linear(self.c_k, self.c_hidden * self.no_heads, bias=False, init="glorot")
        self.linear_v = Linear(self.c_v, self.c_hidden * self.no_heads, bias=False, init="glorot")
        self.linear_o = Linear(self.c_hidden * self.no_heads, self.c_q, init="final")

        self.linear_g = None
        if self.gating:
            self.linear_g = Linear(self.c_q, self.c_hidden * self.no_heads, init="gating")

        self.sigmoid = nn.Sigmoid()

    def _prep_qkv(self, q_x: torch.Tensor,
                  kv_x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        # [*, Q/K/V, H * C_hidden]
        q = self.linear_q(q_x)
        k = self.linear_k(kv_x)
        v = self.linear_v(kv_x)

        # [*, Q/K, H, C_hidden]
        q = q.view(q.shape[:-1] + (self.no_heads, -1))
        k = k.view(k.shape[:-1] + (self.no_heads, -1))
        v = v.view(v.shape[:-1] + (self.no_heads, -1))

        q /= math.sqrt(self.c_hidden)

        return q, k, v

    def _wrap_up(self, o: torch.Tensor, q_x: torch.Tensor) -> torch.Tensor:
        if (self.linear_g is not None):
            g = self.sigmoid(self.linear_g(q_x))

            # [*, Q, H, C_hidden]
            g = g.view(g.shape[:-1] + (self.no_heads, -1))
            o = o * g

        # [*, Q, H * C_hidden]
        o = flatten_final_dims(o, 2)

        # [*, Q, C_q]
        o = self.linear_o(o)

        return o

    def forward(
        self,
        q_x: torch.Tensor,
        kv_x: torch.Tensor,
        biases: Optional[List[torch.Tensor]] = None,
        use_lma: bool = False,
        q_chunk_size: Optional[int] = None,
        kv_chunk_size: Optional[int] = None,
    ) -> torch.Tensor:
        """
        Args:
            q_x:
                [*, Q, C_q] query data
            kv_x:
                [*, K, C_k] key data
            biases:
                List of biases that broadcast to [*, H, Q, K]
            use_lma:
                Whether to use low-memory attention
            q_chunk_size:
                Query chunk size (for LMA)
            kv_chunk_size:
                Key/Value chunk size (for LMA)
        Returns
            [*, Q, C_q] attention update
        """
        if (biases is None):
            biases = []
        if (use_lma and (q_chunk_size is None or kv_chunk_size is None)):
            raise ValueError("If use_lma is specified, q_chunk_size and kv_chunk_size must "
                             "be provided")

        q, k, v = self._prep_qkv(q_x, kv_x)

        if (use_lma):
            biases = [b.expand(b.shape[:-2] + (q_x.shape[-2],) + (kv_x.shape[-2],)) for b in biases]

            o = _lma(q, k, v, biases, q_chunk_size, kv_chunk_size)
        else:
            o = _attention(q, k, v, biases)

        o = self._wrap_up(o, q_x)

        return o


class GlobalAttention(nn.Module):

    def __init__(self, c_in, c_hidden, no_heads, inf, eps):
        super(GlobalAttention, self).__init__()

        self.c_in = c_in
        self.c_hidden = c_hidden
        self.no_heads = no_heads
        self.inf = inf
        self.eps = eps

        self.linear_q = Linear(c_in, c_hidden * no_heads, bias=False, init="glorot")

        self.linear_k = Linear(
            c_in,
            c_hidden,
            bias=False,
            init="glorot",
        )
        self.linear_v = Linear(
            c_in,
            c_hidden,
            bias=False,
            init="glorot",
        )
        self.linear_g = Linear(c_in, c_hidden * no_heads, init="gating")
        self.linear_o = Linear(c_hidden * no_heads, c_in, init="final")

        self.sigmoid = nn.Sigmoid()

    def forward(self, m: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        # [*, N_res, C_in]
        q = torch.sum(m * mask.unsqueeze(-1), dim=-2) / (torch.sum(mask, dim=-1)[..., None] + self.eps)

        # [*, N_res, H * C_hidden]
        q = self.linear_q(q)
        q *= (self.c_hidden**(-0.5))

        # [*, N_res, H, C_hidden]
        q = q.view(q.shape[:-1] + (self.no_heads, -1))

        # [*, N_res, N_seq, C_hidden]
        k = self.linear_k(m)
        v = self.linear_v(m)

        # [*, N_res, H, N_seq]
        a = torch.matmul(
            q,
            k.transpose(-1, -2),    # [*, N_res, C_hidden, N_seq]
        )
        bias = (self.inf * (mask - 1))[..., :, None, :]
        a += bias
        a = softmax(a)

        # [*, N_res, H, C_hidden]
        o = torch.matmul(
            a,
            v,
        )

        # [*, N_res, N_seq, C_hidden]
        g = self.sigmoid(self.linear_g(m))

        # [*, N_res, N_seq, H, C_hidden]
        g = g.view(g.shape[:-1] + (self.no_heads, -1))

        # [*, N_res, N_seq, H, C_hidden]
        o = o.unsqueeze(-3) * g

        # [*, N_res, N_seq, H * C_hidden]
        o = o.reshape(o.shape[:-2] + (-1,))

        # [*, N_res, N_seq, C_in]
        m = self.linear_o(o)

        return m


def _lma(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    biases: List[torch.Tensor],
    q_chunk_size: int,
    kv_chunk_size: int,
):
    no_q, no_kv = q.shape[-3], k.shape[-3]

    # [*, Q, H, C_hidden]
    o = q.new_zeros(q.shape)
    for q_s in range(0, no_q, q_chunk_size):
        q_chunk = q[..., q_s:q_s + q_chunk_size, :, :]
        large_bias_chunks = [b[..., q_s:q_s + q_chunk_size, :] for b in biases]

        maxes = []
        weights = []
        values = []
        for kv_s in range(0, no_kv, kv_chunk_size):
            k_chunk = k[..., kv_s:kv_s + kv_chunk_size, :, :]
            v_chunk = v[..., kv_s:kv_s + kv_chunk_size, :, :]
            small_bias_chunks = [b[..., kv_s:kv_s + kv_chunk_size] for b in large_bias_chunks]

            a = torch.einsum(
                "...qhd,...khd->...hqk",
                q_chunk,
                k_chunk,
            )

            for b in small_bias_chunks:
                a += b

            a = a.transpose(-2, -3)

            max_a = torch.max(a, dim=-1, keepdim=True)[0]
            exp_a = torch.exp(a - max_a)
            exp_v = torch.einsum("...vhf,...qhv->...qhf", v_chunk, exp_a)

            maxes.append(max_a.detach().squeeze(-1))
            weights.append(torch.sum(exp_a, dim=-1))
            values.append(exp_v)

        chunk_max = torch.stack(maxes, dim=-3)
        chunk_weights = torch.stack(weights, dim=-3)
        chunk_values = torch.stack(values, dim=-4)

        global_max = torch.max(chunk_max, dim=-3, keepdim=True)[0]
        max_diffs = torch.exp(chunk_max - global_max)
        chunk_values *= max_diffs.unsqueeze(-1)
        chunk_weights *= max_diffs

        all_values = torch.sum(chunk_values, dim=-4)
        all_weights = torch.sum(chunk_weights.unsqueeze(-1), dim=-4)

        q_chunk_out = all_values / all_weights

        o[..., q_s:q_s + q_chunk_size, :, :] = q_chunk_out

    return o
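Not part of the commit: a self-contained sketch of the permute / matmul / bias / softmax sequence that _attention performs, written with plain permutes so it runs on its own. The sizes and the zero bias are illustrative assumptions.

import torch

h, n_q, n_k, c = 4, 6, 6, 8
q = torch.randn(n_q, h, c)               # [Q, H, C_hidden], as _prep_qkv produces
k = torch.randn(n_k, h, c)
v = torch.randn(n_k, h, c)
mask_bias = torch.zeros(1, n_q, n_k)     # any bias that broadcasts to [H, Q, K]

a = torch.matmul(q.permute(1, 0, 2), k.permute(1, 2, 0))   # [H, Q, K]
a = a + mask_bias
a = torch.softmax(a, dim=-1)
out = torch.matmul(a, v.permute(1, 0, 2)).transpose(-2, -3)  # [Q, H, C_hidden]
assert out.shape == (n_q, h, c)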
tests/test_autochunk/openfold/tensor_utils.py
deleted
100644 → 0
View file @
438ea608
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
from typing import Tuple, List, Callable, Any, Dict, Sequence, Optional

import torch
import torch.nn as nn


def permute_final_dims(tensor: torch.Tensor, inds: List[int]):
    zero_index = -1 * len(inds)
    first_inds = list(range(len(tensor.shape[:zero_index])))
    return tensor.permute(first_inds + [zero_index + i for i in inds])


def flatten_final_dims(t: torch.Tensor, no_dims: int):
    return t.reshape(t.shape[:-no_dims] + (-1,))


def masked_mean(mask, value, dim, eps=1e-4):
    mask = mask.expand(*value.shape)
    return torch.sum(mask * value, dim=dim) / (eps + torch.sum(mask, dim=dim))


def pts_to_distogram(pts, min_bin=2.3125, max_bin=21.6875, no_bins=64):
    boundaries = torch.linspace(min_bin, max_bin, no_bins - 1, device=pts.device)
    dists = torch.sqrt(torch.sum((pts.unsqueeze(-2) - pts.unsqueeze(-3))**2, dim=-1))
    return torch.bucketize(dists, boundaries)


def dict_multimap(fn, dicts):
    first = dicts[0]
    new_dict = {}
    for k, v in first.items():
        all_v = [d[k] for d in dicts]
        if type(v) is dict:
            new_dict[k] = dict_multimap(fn, all_v)
        else:
            new_dict[k] = fn(all_v)

    return new_dict


def one_hot(x, v_bins):
    reshaped_bins = v_bins.view(((1,) * len(x.shape)) + (len(v_bins),))
    diffs = x[..., None] - reshaped_bins
    am = torch.argmin(torch.abs(diffs), dim=-1)
    return nn.functional.one_hot(am, num_classes=len(v_bins)).float()


def batched_gather(data, inds, dim=0, no_batch_dims=0):
    ranges = []
    for i, s in enumerate(data.shape[:no_batch_dims]):
        r = torch.arange(s)
        r = r.view(*(*((1,) * i), -1, *((1,) * (len(inds.shape) - i - 1))))
        ranges.append(r)

    remaining_dims = [slice(None) for _ in range(len(data.shape) - no_batch_dims)]
    remaining_dims[dim - no_batch_dims if dim >= 0 else dim] = inds
    ranges.extend(remaining_dims)
    return data[ranges]


# With tree_map, a poor man's JAX tree_map
def dict_map(fn, dic, leaf_type):
    new_dict = {}
    for k, v in dic.items():
        if type(v) is dict:
            new_dict[k] = dict_map(fn, v, leaf_type)
        else:
            new_dict[k] = tree_map(fn, v, leaf_type)

    return new_dict


def tree_map(fn, tree, leaf_type):
    if isinstance(tree, dict):
        return dict_map(fn, tree, leaf_type)
    elif isinstance(tree, list):
        return [tree_map(fn, x, leaf_type) for x in tree]
    elif isinstance(tree, tuple):
        return tuple([tree_map(fn, x, leaf_type) for x in tree])
    elif isinstance(tree, leaf_type):
        return fn(tree)
    else:
        print(type(tree))
        raise ValueError("Not supported")


tensor_tree_map = partial(tree_map, leaf_type=torch.Tensor)


def _fetch_dims(tree):
    shapes = []
    tree_type = type(tree)
    if tree_type is dict:
        for v in tree.values():
            shapes.extend(_fetch_dims(v))
    elif tree_type is list or tree_type is tuple:
        for t in tree:
            shapes.extend(_fetch_dims(t))
    elif tree_type is torch.Tensor:
        shapes.append(tree.shape)
    else:
        raise ValueError("Not supported")

    return shapes


@torch.jit.ignore
def _flat_idx_to_idx(
    flat_idx: int,
    dims: Tuple[int],
) -> Tuple[int]:
    idx = []
    for d in reversed(dims):
        idx.append(flat_idx % d)
        flat_idx = flat_idx // d

    return tuple(reversed(idx))


@torch.jit.ignore
def _get_minimal_slice_set(
    start: Sequence[int],
    end: Sequence[int],
    dims: int,
    start_edges: Optional[Sequence[bool]] = None,
    end_edges: Optional[Sequence[bool]] = None,
) -> Sequence[Tuple[int]]:
    """
    Produces an ordered sequence of tensor slices that, when used in
    sequence on a tensor with shape dims, yields tensors that contain every
    leaf in the contiguous range [start, end]. Care is taken to yield a
    short sequence of slices, and perhaps even the shortest possible (I'm
    pretty sure it's the latter).

    end is INCLUSIVE.
    """

    # start_edges and end_edges both indicate whether, starting from any given
    # dimension, the start/end index is at the top/bottom edge of the
    # corresponding tensor, modeled as a tree
    def reduce_edge_list(ll):
        tally = 1
        for i in range(len(ll)):
            reversed_idx = -1 * (i + 1)
            ll[reversed_idx] *= tally
            tally = ll[reversed_idx]

    if (start_edges is None):
        start_edges = [s == 0 for s in start]
        reduce_edge_list(start_edges)
    if (end_edges is None):
        end_edges = [e == (d - 1) for e, d in zip(end, dims)]
        reduce_edge_list(end_edges)

    # Base cases. Either start/end are empty and we're done, or the final,
    # one-dimensional tensor can be simply sliced
    if (len(start) == 0):
        return [tuple()]
    elif (len(start) == 1):
        return [(slice(start[0], end[0] + 1),)]

    slices = []
    path = []

    # Dimensions common to start and end can be selected directly
    for s, e in zip(start, end):
        if (s == e):
            path.append(slice(s, s + 1))
        else:
            break

    path = tuple(path)
    divergence_idx = len(path)

    # start == end, and we're done
    if (divergence_idx == len(dims)):
        return [tuple(path)]

    def upper():
        sdi = start[divergence_idx]
        return [
            path + (slice(sdi, sdi + 1),) + s for s in _get_minimal_slice_set(
                start[divergence_idx + 1:],
                [d - 1 for d in dims[divergence_idx + 1:]],
                dims[divergence_idx + 1:],
                start_edges=start_edges[divergence_idx + 1:],
                end_edges=[1 for _ in end_edges[divergence_idx + 1:]],
            )
        ]

    def lower():
        edi = end[divergence_idx]
        return [
            path + (slice(edi, edi + 1),) + s for s in _get_minimal_slice_set(
                [0 for _ in start[divergence_idx + 1:]],
                end[divergence_idx + 1:],
                dims[divergence_idx + 1:],
                start_edges=[1 for _ in start_edges[divergence_idx + 1:]],
                end_edges=end_edges[divergence_idx + 1:],
            )
        ]

    # If both start and end are at the edges of the subtree rooted at
    # divergence_idx, we can just select the whole subtree at once
    if (start_edges[divergence_idx] and end_edges[divergence_idx]):
        slices.append(path + (slice(start[divergence_idx], end[divergence_idx] + 1),))
    # If just start is at the edge, we can grab almost all of the subtree,
    # treating only the ragged bottom edge as an edge case
    elif (start_edges[divergence_idx]):
        slices.append(path + (slice(start[divergence_idx], end[divergence_idx]),))
        slices.extend(lower())
    # Analogous to the previous case, but the top is ragged this time
    elif (end_edges[divergence_idx]):
        slices.extend(upper())
        slices.append(path + (slice(start[divergence_idx] + 1, end[divergence_idx] + 1),))
    # If both sides of the range are ragged, we need to handle both sides
    # separately. If there's contiguous meat in between them, we can index it
    # in one big chunk
    else:
        slices.extend(upper())
        middle_ground = end[divergence_idx] - start[divergence_idx]
        if (middle_ground > 1):
            slices.append(path + (slice(start[divergence_idx] + 1, end[divergence_idx]),))
        slices.extend(lower())

    return [tuple(s) for s in slices]


@torch.jit.ignore
def _chunk_slice(
    t: torch.Tensor,
    flat_start: int,
    flat_end: int,
    no_batch_dims: int,
) -> torch.Tensor:
    """
    Equivalent to

        t.reshape((-1,) + t.shape[no_batch_dims:])[flat_start:flat_end]

    but without the need for the initial reshape call, which can be
    memory-intensive in certain situations. The only reshape operations
    in this function are performed on sub-tensors that scale with
    (flat_end - flat_start), the chunk size.
    """

    batch_dims = t.shape[:no_batch_dims]
    start_idx = list(_flat_idx_to_idx(flat_start, batch_dims))
    # _get_minimal_slice_set is inclusive
    end_idx = list(_flat_idx_to_idx(flat_end - 1, batch_dims))

    # Get an ordered list of slices to perform
    slices = _get_minimal_slice_set(
        start_idx,
        end_idx,
        batch_dims,
    )

    sliced_tensors = [t[s] for s in slices]

    return torch.cat([s.view((-1,) + t.shape[no_batch_dims:]) for s in sliced_tensors])


def chunk_layer(
    layer: Callable,
    inputs: Dict[str, Any],
    chunk_size: int,
    no_batch_dims: int,
    low_mem: bool = False,
) -> Any:
    """
    Implements the "chunking" procedure described in section 1.11.8.

    Layer outputs and inputs are assumed to be simple "pytrees,"
    consisting only of (arbitrarily nested) lists, tuples, and dicts with
    torch.Tensor leaves.

    Args:
        layer:
            The layer to be applied chunk-wise
        inputs:
            A (non-nested) dictionary of keyworded inputs. All leaves must
            be tensors and must share the same batch dimensions.
        chunk_size:
            The number of sub-batches per chunk. If multiple batch
            dimensions are specified, a "sub-batch" is defined as a single
            indexing of all batch dimensions simultaneously (s.t. the
            number of sub-batches is the product of the batch dimensions).
        no_batch_dims:
            How many of the initial dimensions of each input tensor can
            be considered batch dimensions.
        low_mem:
            Avoids flattening potentially large input tensors. Unnecessary
            in most cases, and is ever so slightly slower than the default
            setting.
    Returns:
        The reassembled output of the layer on the inputs.
    """
    if not (len(inputs) > 0):
        raise ValueError("Must provide at least one input")

    initial_dims = [shape[:no_batch_dims] for shape in _fetch_dims(inputs)]
    orig_batch_dims = tuple([max(s) for s in zip(*initial_dims)])

    def _prep_inputs(t):
        # TODO: make this more memory efficient. This sucks
        if (not low_mem):
            if not sum(t.shape[:no_batch_dims]) == no_batch_dims:
                t = t.expand(orig_batch_dims + t.shape[no_batch_dims:])
            t = t.reshape(-1, *t.shape[no_batch_dims:])
        else:
            t = t.expand(orig_batch_dims + t.shape[no_batch_dims:])
        return t

    prepped_inputs = tensor_tree_map(_prep_inputs, inputs)

    flat_batch_dim = 1
    for d in orig_batch_dims:
        flat_batch_dim *= d

    no_chunks = flat_batch_dim // chunk_size + (flat_batch_dim % chunk_size != 0)

    i = 0
    out = None
    for _ in range(no_chunks):
        # Chunk the input
        if (not low_mem):
            select_chunk = (lambda t: t[i:i + chunk_size] if t.shape[0] != 1 else t)
        else:
            select_chunk = (partial(
                _chunk_slice,
                flat_start=i,
                flat_end=min(flat_batch_dim, i + chunk_size),
                no_batch_dims=len(orig_batch_dims),
            ))

        chunks = tensor_tree_map(select_chunk, prepped_inputs)

        # Run the layer on the chunk
        output_chunk = layer(**chunks)

        # Allocate space for the output
        if out is None:
            allocate = lambda t: t.new_zeros((flat_batch_dim,) + t.shape[1:])
            out = tensor_tree_map(allocate, output_chunk)

        # Put the chunk in its pre-allocated space
        out_type = type(output_chunk)
        if out_type is dict:

            def assign(d1, d2):
                for k, v in d1.items():
                    if type(v) is dict:
                        assign(v, d2[k])
                    else:
                        v[i:i + chunk_size] = d2[k]

            assign(out, output_chunk)
        elif out_type is tuple:
            for x1, x2 in zip(out, output_chunk):
                x1[i:i + chunk_size] = x2
        elif out_type is torch.Tensor:
            out[i:i + chunk_size] = output_chunk
        else:
            raise ValueError("Not supported")

        i += chunk_size

    reshape = lambda t: t.view(orig_batch_dims + t.shape[1:])
    out = tensor_tree_map(reshape, out)

    return out
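The chunk_layer docstring above describes the chunking contract; a minimal usage sketch (not from the repository, toy names and sizes only) may make it concrete. With two batch dimensions, the inputs are flattened to a single sub-batch axis, the layer runs on slices of chunk_size sub-batches, and the output is reassembled to the original batch shape:

import torch

def toy_layer(x, y):
    # any pointwise layer works here; x and y share batch dims (4, 8)
    return x + y

x = torch.randn(4, 8, 16)
y = torch.randn(4, 8, 16)
out = chunk_layer(toy_layer, {"x": x, "y": y}, chunk_size=8, no_batch_dims=2)
assert out.shape == (4, 8, 16)
assert torch.allclose(out, x + y)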
tests/test_autochunk/openfold/triangular_attention.py  deleted 100644 → 0  View file @ 438ea608
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partialmethod, partial
import math
from typing import Optional, List

import torch
import torch.nn as nn

from .primitives import Linear, LayerNorm, Attention
from .tensor_utils import (
    chunk_layer,
    permute_final_dims,
    flatten_final_dims,
)


class TriangleAttention(nn.Module):

    def __init__(self, c_in, c_hidden, no_heads, starting, inf=1e9):
        """
        Args:
            c_in:
                Input channel dimension
            c_hidden:
                Overall hidden channel dimension (not per-head)
            no_heads:
                Number of attention heads
        """
        super(TriangleAttention, self).__init__()

        self.c_in = c_in
        self.c_hidden = c_hidden
        self.no_heads = no_heads
        self.starting = starting
        self.inf = inf

        self.layer_norm = LayerNorm(self.c_in)

        self.linear = Linear(c_in, self.no_heads, bias=False, init="normal")

        self.mha = Attention(self.c_in, self.c_in, self.c_in, self.c_hidden, self.no_heads)

    @torch.jit.ignore
    def _chunk(
        self,
        x: torch.Tensor,
        biases: List[torch.Tensor],
        chunk_size: int,
    ) -> torch.Tensor:
        mha_inputs = {
            "q_x": x,
            "kv_x": x,
            "biases": biases,
        }
        return chunk_layer(
            partial(self.mha),
            mha_inputs,
            chunk_size=chunk_size,
            no_batch_dims=len(x.shape[:-2]),
        )

    def forward(self,
                x: torch.Tensor,
                mask: Optional[torch.Tensor] = None,
                chunk_size: Optional[int] = None) -> torch.Tensor:
        """
        Args:
            x:
                [*, I, J, C_in] input tensor (e.g. the pair representation)
        Returns:
            [*, I, J, C_in] output tensor
        """
        if mask is None:
            # [*, I, J]
            mask = x.new_ones(x.shape[:-1],)

        # Shape annotations assume self.starting. Else, I and J are flipped
        if not self.starting:
            x = x.transpose(-2, -3)
            mask = mask.transpose(-1, -2)

        # [*, I, J, C_in]
        x = self.layer_norm(x)

        # [*, I, 1, 1, J]
        mask_bias = (self.inf * (mask - 1))[..., :, None, None, :]

        # [*, H, I, J]
        triangle_bias = permute_final_dims(self.linear(x), (2, 0, 1))

        # [*, 1, H, I, J]
        triangle_bias = triangle_bias.unsqueeze(-4)

        biases = [mask_bias, triangle_bias]

        if chunk_size is not None:
            x = self._chunk(x, biases, chunk_size)
        else:
            x = self.mha(q_x=x, kv_x=x, biases=biases)

        if not self.starting:
            x = x.transpose(-2, -3)

        return x


class TriangleAttentionStartingNode(TriangleAttention):
    """
    Implements Algorithm 13.
    """

    __init__ = partialmethod(TriangleAttention.__init__, starting=True)


class TriangleAttentionEndingNode(TriangleAttention):
    """
    Implements Algorithm 14.
    """

    __init__ = partialmethod(TriangleAttention.__init__, starting=False)
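A hedged usage sketch, not from the repository and with illustrative channel sizes only, assuming the Linear, LayerNorm, and Attention primitives imported above are available. The starting-node variant consumes and returns a [*, I, J, C_in] pair tensor, and chunk_size only changes how the attention is batched, not its result:

import torch

attn = TriangleAttentionStartingNode(c_in=8, c_hidden=4, no_heads=2)
z = torch.randn(1, 16, 16, 8)          # [*, I, J, C_in]
out_full = attn(z)                      # full attention over each row
out_chunked = attn(z, chunk_size=4)     # same computation, applied chunk-wise via chunk_layer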
tests/test_autochunk/openfold/triangular_multiplicative_update.py  deleted 100644 → 0  View file @ 438ea608
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partialmethod
from typing import Optional

import torch
import torch.nn as nn

from .primitives import Linear, LayerNorm
from .tensor_utils import permute_final_dims


class TriangleMultiplicativeUpdate(nn.Module):
    """
    Implements Algorithms 11 and 12.
    """

    def __init__(self, c_z, c_hidden, _outgoing=True):
        """
        Args:
            c_z:
                Input channel dimension
            c:
                Hidden channel dimension
        """
        super(TriangleMultiplicativeUpdate, self).__init__()
        self.c_z = c_z
        self.c_hidden = c_hidden
        self._outgoing = _outgoing

        self.linear_a_p = Linear(self.c_z, self.c_hidden)
        self.linear_a_g = Linear(self.c_z, self.c_hidden, init="gating")
        self.linear_b_p = Linear(self.c_z, self.c_hidden)
        self.linear_b_g = Linear(self.c_z, self.c_hidden, init="gating")
        self.linear_g = Linear(self.c_z, self.c_z, init="gating")
        self.linear_z = Linear(self.c_hidden, self.c_z, init="final")

        self.layer_norm_in = LayerNorm(self.c_z)
        self.layer_norm_out = LayerNorm(self.c_hidden)

        self.sigmoid = nn.Sigmoid()

    def _combine_projections(
        self,
        a: torch.Tensor,
        b: torch.Tensor,
    ) -> torch.Tensor:
        raise NotImplementedError("This method needs to be overridden")

    def forward(self, z: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Args:
            x:
                [*, N_res, N_res, C_z] input tensor
            mask:
                [*, N_res, N_res] input mask
        Returns:
            [*, N_res, N_res, C_z] output tensor
        """
        if mask is None:
            mask = z.new_ones(z.shape[:-1])

        mask = mask.unsqueeze(-1)

        z = self.layer_norm_in(z)
        a = self.linear_a_p(z) * self.sigmoid(self.linear_a_g(z))
        a = a * mask
        b = self.linear_b_p(z) * self.sigmoid(self.linear_b_g(z))
        b = b * mask
        x = self._combine_projections(a, b)
        x = self.layer_norm_out(x)
        x = self.linear_z(x)
        g = self.sigmoid(self.linear_g(z))
        z = x * g

        return z


class TriangleMultiplicationOutgoing(TriangleMultiplicativeUpdate):
    """
    Implements Algorithm 11.
    """

    def _combine_projections(
        self,
        a: torch.Tensor,    # [*, N_i, N_k, C]
        b: torch.Tensor,    # [*, N_j, N_k, C]
    ):
        # [*, C, N_i, N_j]
        p = torch.matmul(
            permute_final_dims(a, (2, 0, 1)),
            permute_final_dims(b, (2, 1, 0)),
        )

        # [*, N_i, N_j, C]
        return permute_final_dims(p, (1, 2, 0))


class TriangleMultiplicationIncoming(TriangleMultiplicativeUpdate):
    """
    Implements Algorithm 12.
    """

    def _combine_projections(
        self,
        a: torch.Tensor,    # [*, N_k, N_i, C]
        b: torch.Tensor,    # [*, N_k, N_j, C]
    ):
        # [*, C, N_i, N_j]
        p = torch.matmul(
            permute_final_dims(a, (2, 1, 0)),
            permute_final_dims(b, (2, 0, 1)),
        )

        # [*, N_i, N_j, C]
        return permute_final_dims(p, (1, 2, 0))
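A hedged usage sketch, not from the repository and with illustrative sizes only, assuming the primitives above are importable. Both variants consume and produce a [*, N_res, N_res, C_z] pair tensor; only the direction of the matmul over the third residue index differs:

import torch

outgoing = TriangleMultiplicationOutgoing(c_z=8, c_hidden=4)
incoming = TriangleMultiplicationIncoming(c_z=8, c_hidden=4)
z = torch.randn(1, 16, 16, 8)   # [*, N_res, N_res, C_z]
z = outgoing(z)                  # "outgoing" edges update
z = incoming(z)                  # "incoming" edges update, same shape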
tests/test_autochunk/test_evoformer_codegen.py  0 → 100644  View file @ 8208fd02
from functools import partial

import pytest
import torch
import torch.fx
import torch.multiprocessing as mp

try:
    from fastfold.model.nn.evoformer import EvoformerBlock
    HAS_REPO = True
except:
    HAS_REPO = False

import colossalai
from colossalai.core import global_context as gpc
from colossalai.fx._compatibility import is_compatible_with_meta
from colossalai.fx.codegen.activation_checkpoint_codegen import CODEGEN_AVAILABLE
from colossalai.fx.graph_module import ColoGraphModule
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
from colossalai.utils import free_port

if CODEGEN_AVAILABLE and is_compatible_with_meta():
    from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen
    from colossalai.fx.profiler import MetaTensor
    from colossalai.fx.tracer.experimental import ColoTracer, symbolic_trace


def _test_fwd(model: torch.nn.Module, gm: ColoGraphModule, node, pair, node_mask, pair_mask):
    # for memory test
    # torch.cuda.reset_peak_memory_stats()
    # now_mem = torch.cuda.memory_allocated() / 1024**2
    # with torch.no_grad():
    #     node1 = node.clone()
    #     pair1 = pair.clone()
    #     gm(node1, pair1)
    # new_now_mem = torch.cuda.memory_allocated() / 1024**2
    # new_max_mem = torch.cuda.max_memory_allocated() / 1024**2
    # print(
    #     "autochunk now mem:%.2f max mem:%.2f"
    #     % (new_now_mem - now_mem, new_max_mem - now_mem)
    # )

    # test forward
    model = model.cuda()
    with torch.no_grad():
        non_fx_out = model(node, pair, node_mask, pair_mask)
        fx_out = gm(node, pair, node_mask, pair_mask)

    assert torch.allclose(non_fx_out[0], fx_out[0],
                          atol=1e-4), "fx_out doesn't comply with original output, diff is %.2e" % torch.mean(
                              torch.abs(non_fx_out[0] - fx_out[0]))
    assert torch.allclose(non_fx_out[1], fx_out[1],
                          atol=1e-4), "fx_out doesn't comply with original output, diff is %.2e" % torch.mean(
                              torch.abs(non_fx_out[1] - fx_out[1]))


def _build_openfold():
    model = EvoformerBlock(
        c_m=256,
        c_z=128,
        c_hidden_msa_att=32,
        c_hidden_opm=32,
        c_hidden_mul=128,
        c_hidden_pair_att=32,
        no_heads_msa=8,
        no_heads_pair=4,
        transition_n=4,
        msa_dropout=0.15,
        pair_dropout=0.15,
        inf=1e4,
        eps=1e-4,
        is_multimer=False,
    ).eval().cuda()
    return model


def _test_evoformer_codegen(rank, msa_len, pair_len, max_memory):
    # launch colossalai
    colossalai.launch(
        config={},
        rank=rank,
        world_size=1,
        host="localhost",
        port=free_port(),
        backend="nccl",
    )

    # build model and input
    model = _build_openfold()
    node = torch.randn(1, msa_len, pair_len, 256).cuda()
    node_mask = torch.randn(1, msa_len, pair_len).cuda()
    pair = torch.randn(1, pair_len, pair_len, 128).cuda()
    pair_mask = torch.randn(1, pair_len, pair_len).cuda()

    # trace the meta graph and setup codegen
    meta_graph = symbolic_trace(
        model,
        meta_args={
            "m": node.to(torch.device("meta")),
            "z": pair.to(torch.device("meta")),
            "msa_mask": node_mask.to(torch.device("meta")),
            "pair_mask": pair_mask.to(torch.device("meta")),
        },
        concrete_args={
            "chunk_size": None,
            "_mask_trans": True,
        },
    )
    interp = MetaInfoProp(meta_graph)
    interp.propagate(
        MetaTensor(node, fake_device="cuda:0"),
        MetaTensor(pair, fake_device="cuda:0"),
        MetaTensor(node_mask, fake_device="cuda:0"),
        MetaTensor(pair_mask, fake_device="cuda:0"),
    )
    # codegen = AutoChunkCodeGen(meta_graph, max_memory=max_memory)

    # trace and recompile
    # MetaInfoProp requires symbolic_trace but CodeGen requires ColoTracer
    graph = ColoTracer().trace(
        model,
        meta_args={
            "m": node.to(torch.device("meta")),
            "z": pair.to(torch.device("meta")),
            "msa_mask": node_mask.to(torch.device("meta")),
            "pair_mask": pair_mask.to(torch.device("meta")),
        },
        concrete_args={
            "chunk_size": None,
            "_mask_trans": True,
        },
    )
    # graph.set_codegen(codegen)
    gm = ColoGraphModule(model, graph)
    gm.recompile()

    # assert we have inserted chunk
    code = graph.python_code("self").src
    assert "chunk_size" in code
    # print(code)

    _test_fwd(model, gm, node, pair, node_mask, pair_mask)
    gpc.destroy()


@pytest.mark.skipif(
    not (CODEGEN_AVAILABLE and is_compatible_with_meta() and HAS_REPO),
    reason="torch version is lower than 1.12.0",
)
@pytest.mark.parametrize("max_memory", [None, 20, 25, 30])
@pytest.mark.parametrize("msa_len", [32])
@pytest.mark.parametrize("pair_len", [64])
def test_evoformer_codegen(msa_len, pair_len, max_memory):
    run_func = partial(
        _test_evoformer_codegen,
        msa_len=msa_len,
        pair_len=pair_len,
        max_memory=max_memory,
    )
    mp.spawn(run_func, nprocs=1)


if __name__ == "__main__":
    _test_evoformer_codegen(0, 32, 64, 25)
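Note that in this version of the test the autochunk codegen path itself is disabled: the AutoChunkCodeGen construction and the set_codegen call are left commented out, so only the plain ColoTracer graph is recompiled and checked. A hedged sketch of what re-enabling it would look like, assuming the commented-out API shown above:

# hypothetical re-enabling of the commented-out lines; API as commented in the test
codegen = AutoChunkCodeGen(meta_graph, max_memory=max_memory)
graph.set_codegen(codegen)
gm = ColoGraphModule(model, graph)
gm.recompile()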
tests/test_autochunk/test_autochunk_codegen.py → tests/test_autochunk/test_simple_evoformer_codegen.py  View file @ 8208fd02
@@ -5,6 +5,12 @@ import torch
 import torch.fx
 import torch.multiprocessing as mp

+try:
+    from simple_evoformer import base_evoformer
+    HAS_REPO = True
+except:
+    HAS_REPO = False
+
 import colossalai
 from colossalai.core import global_context as gpc
 from colossalai.fx import ColoTracer
@@ -13,7 +19,6 @@ from colossalai.fx.codegen.activation_checkpoint_codegen import CODEGEN_AVAILABLE
 from colossalai.fx.graph_module import ColoGraphModule
 from colossalai.fx.passes.meta_info_prop import MetaInfoProp
 from colossalai.utils import free_port
-from tests.test_autochunk.evoformer.evoformer import evoformer_base

 if CODEGEN_AVAILABLE and is_compatible_with_meta():
     from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen
@@ -48,7 +53,7 @@ def _test_fwd(model: torch.nn.Module, gm: ColoGraphModule, node, pair):
                           torch.abs(non_fx_out[1] - fx_out[1]))


-def _test_autochunk_codegen(rank, msa_len, pair_len, max_memory):
+def _test_simple_evoformer_codegen(rank, msa_len, pair_len, max_memory):
     # launch colossalai
     colossalai.launch(
         config={},
@@ -60,7 +65,7 @@ def _test_autochunk_codegen(rank, msa_len, pair_len, max_memory):
     )

     # build model and input
-    model = evoformer_base().cuda()
+    model = base_evoformer().cuda()
     node = torch.randn(1, msa_len, pair_len, 256).cuda()
     pair = torch.randn(1, pair_len, pair_len, 128).cuda()
@@ -95,13 +100,14 @@ def _test_autochunk_codegen(rank, msa_len, pair_len, max_memory):
     gpc.destroy()


-@pytest.mark.skipif(not (CODEGEN_AVAILABLE and is_compatible_with_meta()), reason='torch version is lower than 1.12.0')
+@pytest.mark.skipif(not (CODEGEN_AVAILABLE and is_compatible_with_meta() and HAS_REPO),
+                    reason='torch version is lower than 1.12.0')
 @pytest.mark.parametrize("max_memory", [None, 20, 25, 30])
 @pytest.mark.parametrize("msa_len", [32])
 @pytest.mark.parametrize("pair_len", [64])
-def test_autochunk_codegen(msa_len, pair_len, max_memory):
+def test_simple_evoformer_codegen(msa_len, pair_len, max_memory):
     run_func = partial(
-        _test_autochunk_codegen,
+        _test_simple_evoformer_codegen,
         msa_len=msa_len,
         pair_len=pair_len,
         max_memory=max_memory,
@@ -110,4 +116,4 @@ def test_autochunk_codegen(msa_len, pair_len, max_memory):
 if __name__ == "__main__":
-    _test_autochunk_codegen(0, 32, 64, 25)
+    _test_simple_evoformer_codegen(0, 32, 64, 25)
tests/test_autochunk/test_autochunk_search.py → tests/test_autochunk/test_simple_evoformer_search.py  View file @ 8208fd02
@@ -5,13 +5,18 @@ import torch
 import torch.fx
 import torch.multiprocessing as mp

+try:
+    from simple_evoformer import base_evoformer
+    HAS_REPO = True
+except:
+    HAS_REPO = False
+
 import colossalai
 from colossalai.core import global_context as gpc
 from colossalai.fx._compatibility import is_compatible_with_meta
 from colossalai.fx.codegen.activation_checkpoint_codegen import CODEGEN_AVAILABLE
 from colossalai.fx.passes.meta_info_prop import MetaInfoProp
 from colossalai.utils import free_port
-from tests.test_autochunk.evoformer.evoformer import evoformer_base

 if CODEGEN_AVAILABLE and is_compatible_with_meta():
     from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen
@@ -57,7 +62,7 @@ def assert_chunk_infos(chunk_infos, max_memory, msa_len, pair_len):
     )


-def _test_autochunk_search(rank, msa_len, pair_len, max_memory):
+def _test_simple_evoformer_search(rank, msa_len, pair_len, max_memory):
     # launch colossalai
     colossalai.launch(
         config={},
@@ -69,7 +74,7 @@ def _test_autochunk_search(rank, msa_len, pair_len, max_memory):
     )

     # build model and input
-    model = evoformer_base().cuda()
+    model = base_evoformer().cuda()
     node = torch.randn(1, msa_len, pair_len, 256).cuda()
     pair = torch.randn(1, pair_len, pair_len, 128).cuda()
@@ -84,13 +89,14 @@ def _test_autochunk_search(rank, msa_len, pair_len, max_memory):
     gpc.destroy()


-@pytest.mark.skipif(not (CODEGEN_AVAILABLE and is_compatible_with_meta()), reason="torch version is lower than 1.12.0")
+@pytest.mark.skipif(not (CODEGEN_AVAILABLE and is_compatible_with_meta() and HAS_REPO),
+                    reason="torch version is lower than 1.12.0")
 @pytest.mark.parametrize("max_memory", [None, 20, 25, 30])
 @pytest.mark.parametrize("msa_len", [32])
 @pytest.mark.parametrize("pair_len", [64])
-def test_autochunk_search(msa_len, pair_len, max_memory):
+def test_simple_evoformer_search(msa_len, pair_len, max_memory):
     run_func = partial(
-        _test_autochunk_search,
+        _test_simple_evoformer_search,
         msa_len=msa_len,
         pair_len=pair_len,
         max_memory=max_memory,
@@ -99,4 +105,4 @@ def test_autochunk_search(msa_len, pair_len, max_memory):
 if __name__ == "__main__":
-    _test_autochunk_search(0, 32, 64, 20)
+    _test_simple_evoformer_search(0, 32, 64, 20)
tests/test_tensor/common_utils/_utils.py  View file @ 8208fd02
@@ -4,6 +4,7 @@ import random
 import numpy as np
 import torch
 import torch.distributed as dist
+from torch.testing import assert_close

 from colossalai.context import ParallelMode
 from colossalai.core import global_context as gpc
@@ -41,14 +42,20 @@ def broadcast_tensor_chunk(tensor, chunk_size=1, local_rank=0):
     return tensor_chunk.clone()


-def tensor_equal(A, B):
-    return torch.allclose(A, B, rtol=1e-3, atol=1e-1)
+def tensor_equal(t_a: torch.Tensor, t_b: torch.Tensor, rtol: float = 1e-3, atol: float = 1e-1):
+    assert_close(t_a, t_b, rtol=rtol, atol=atol)
+    return True


-def tensor_shard_equal(tensor: torch.Tensor, shard: torch.Tensor, rank, world_size):
+def tensor_shard_equal(tensor: torch.Tensor,
+                       shard: torch.Tensor,
+                       rank: int,
+                       world_size: int,
+                       rtol: float = 1e-3,
+                       atol: float = 1e-1):
     assert tensor.ndim == shard.ndim
     if tensor.shape == shard.shape:
-        return tensor_equal(tensor, shard)
+        return tensor_equal(tensor, shard, rtol, atol)
     else:
         dims_not_eq = torch.nonzero(torch.tensor(tensor.shape) != torch.tensor(shard.shape))
         if dims_not_eq.numel() == 1:
@@ -58,7 +65,7 @@ def tensor_shard_equal(tensor: torch.Tensor, shard: torch.Tensor, rank, world_size):
                 world_size = gpc.get_world_size(ParallelMode.PARALLEL_1D)
             if rank is None:
                 rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
-            return tensor_equal(tensor.chunk(world_size, dim)[rank], shard)
+            return tensor_equal(tensor.chunk(world_size, dim)[rank], shard, rtol, atol)
         else:
             raise NotImplementedError
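The updated helpers delegate to torch.testing.assert_close and now thread rtol/atol through to the shard comparison. A hedged, single-process sketch (toy values, not from the repository) of the check tensor_shard_equal performs when exactly one dimension differs:

import torch

full = torch.arange(8.0).reshape(2, 4)
world_size, rank, dim = 2, 1, 1          # shard along dim 1, this rank holds the second half
shard = full.chunk(world_size, dim=dim)[rank]
# tensor_shard_equal(full, shard, rank, world_size) reduces to this comparison:
torch.testing.assert_close(full.chunk(world_size, dim)[rank], shard)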
tests/test_zero/low_level_zero/test_zero_tp.py  0 → 100644  View file @ 8208fd02
from functools import partial

import pytest
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.testing import assert_close

import colossalai
from colossalai.tensor import ProcessGroup
from colossalai.testing import parameterize, rerun_if_address_is_in_use
from colossalai.utils import free_port, get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext
from colossalai.zero import LowLevelZeroOptimizer
from tests.test_tensor.common_utils import set_seed, split_param_col_tp1d, split_param_row_tp1d, tensor_shard_equal


def strict_shard_equal(tensor, shard, tp_pg, rtol=1e-3, atol=1e-4):
    return tensor_shard_equal(tensor, shard, tp_pg.tp_local_rank(), tp_pg.tp_world_size(), rtol, atol)


class TestModel(nn.Module):

    def __init__(self):
        super(TestModel, self).__init__()
        self.linear1 = nn.Linear(32, 128)
        self.act = nn.GELU()
        self.linear2 = nn.Linear(128, 32)

    def forward(self, x):
        y = self.linear1(x)
        y = self.act(y)
        y = self.linear2(y)
        return x + y


@parameterize("overlap_flag", [False, True])
@parameterize("partition_flag", [False, True])
def exam_zero_with_tp(overlap_flag, partition_flag):
    set_seed(233010)
    tp_pg = ProcessGroup(tp_degree=2)

    with ColoInitContext(device=get_current_device(), default_pg=tp_pg):
        hybrid_model = TestModel()
    torch_model = TestModel().cuda()
    for pt, ph in zip(torch_model.parameters(), hybrid_model.parameters()):
        pt.data.copy_(ph.data)

    for name, param in hybrid_model.named_parameters():
        if 'linear1' in name:
            split_param_row_tp1d(param, tp_pg)
            param.compute_spec.set_output_replicate(False)
        if 'linear2.weight' in name:
            split_param_col_tp1d(param, tp_pg)

    torch_model = DDP(torch_model, device_ids=[tp_pg.rank()], process_group=tp_pg.dp_process_group())
    torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1)
    hybrid_optim = torch.optim.Adam(hybrid_model.parameters(), lr=1)
    hybrid_optim = LowLevelZeroOptimizer(hybrid_optim,
                                         initial_scale=1,
                                         overlap_communication=overlap_flag,
                                         partition_grad=partition_flag)

    dp_local_rank = tp_pg.dp_local_rank()
    set_seed(255 + dp_local_rank)

    data = torch.randn(8, 32, device=get_current_device())
    torch_loss = torch_model(data).sum()
    hybrid_loss = hybrid_model(data).sum()
    assert_close(torch_loss, hybrid_loss)

    torch_loss.backward()
    hybrid_optim.backward(hybrid_loss)
    hybrid_optim.sync_grad()

    torch_optim.step()
    hybrid_optim.step()

    for (name, pt), ph in zip(torch_model.named_parameters(), hybrid_model.parameters()):
        assert strict_shard_equal(pt.data, ph.data, tp_pg)


def run_dist(rank, world_size, port):
    colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host='localhost')
    exam_zero_with_tp()


@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_zero_with_tp():
    world_size = 4
    run_func = partial(run_dist, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_zero_with_tp()