OpenDAS / ColossalAI · Commit 820ea4d0

align evoformer

Authored Nov 02, 2022 by oahzxl
Parent 86f2a314

Showing 6 changed files with 67 additions and 192 deletions (+67 -192)
chunk_codegen.py        +21  -122
chunk_codegen_run.py    +34  -63
evoformer/evoformer.py  +6   -1
evoformer/kernel.py     +1   -1
evoformer/msa.py        +1   -1
evoformer/triangle.py   +4   -4
chunk_codegen.py

import colossalai
import torch
import copy
from typing import List, Callable, Any, Tuple, Dict, Iterable

try:
...
@@ -17,74 +18,18 @@ else:
 __all__ = ['python_code_with_activation_checkpoint']
 
-def _gen_saved_tensors_hooks():
-    """
-    Generate saved tensors hooks
-    """
-    pack_hook = """def pack_hook_input(self, x):
-    if getattr(x, "offload", False):
-        return (x.device, x.cpu())
-    else:
-        return x
-
-def pack_hook_no_input(self, x):
-    if getattr(x, "offload", True):
-        return (x.device, x.cpu())
-    else:
-        return x
-"""
-
-    unpack_hook = """def unpack_hook(self, packed):
-    if isinstance(packed, tuple):
-        device, tensor = packed
-        return tensor.to(device)
-    else:
-        return packed
-"""
-
-    return pack_hook, unpack_hook
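Note: the removed helper emitted source text that wraps forward code in torch.autograd.graph.saved_tensors_hooks, so tensors saved for backward are offloaded to CPU and restored on demand. For context, a minimal standalone sketch of that mechanism (the layer and tensor names are illustrative, not from this repo):

import torch

def pack_hook(x):
    # called when autograd saves a tensor for backward:
    # move it to CPU and remember its original device
    return (x.device, x.cpu())

def unpack_hook(packed):
    # called when backward needs the tensor: move it back
    device, tensor = packed
    return tensor.to(device)

layer = torch.nn.Linear(4, 4)
inp = torch.rand(2, 4, requires_grad=True)
with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook):
    out = layer(inp).relu()
out.sum().backward()  # saved activations round-trip through CPU here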
-def _gen_loop_5(to_keep):
-    context = "chunk_result = []\nfor gen_loop_idx in range(4):\n"
-    context += "    chunk_tensor = " + to_keep + "[gen_loop_idx, :]\n"
+def _gen_loop_start(to_keep, chunk_size=2):
+    context = "chunk_result = []; chunk_size = %d\nfor gen_loop_idx in range(0, %s.shape[0], chunk_size):\n" % (chunk_size, to_keep[0])
+    context += "    chunk_tensor = " + to_keep + "[gen_loop_idx:gen_loop_idx + chunk_size, :]\n"
+    return context
-def _gen_loop_5_final(final_name, to_keep):
+def _gen_loop_end(final_name, to_keep):
     context = "    chunk_result.append(" + final_name + ")\n"
     context += "chunk_result = torch.cat(chunk_result, dim=0); " + to_keep[0] + " = None\n"
     context += final_name + " = chunk_result; chunk_result = None\n"
     return context
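Composed together, _gen_loop_start and _gen_loop_end bracket the emitted node code with a Python-level chunk loop. (As written, _gen_loop_start formats to_keep[0] into the range expression but concatenates to_keep into the slice; with the plain-string argument passed at the call site below, the two differ, so this sketch shows the evidently intended output.) Assuming a graph input named tensor_1 and a final node named add_5 (hypothetical names), the generated source would look roughly like:

chunk_result = []; chunk_size = 2
for gen_loop_idx in range(0, tensor_1.shape[0], chunk_size):
    chunk_tensor = tensor_1[gen_loop_idx:gen_loop_idx + chunk_size, :]
    ...  # emitted node code, indented, with tensor_1 replaced by chunk_tensor
    chunk_result.append(add_5)
chunk_result = torch.cat(chunk_result, dim=0); tensor_1 = None
add_5 = chunk_result; chunk_result = None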
-def _gen_save_tensors_hooks_context(offload_input=True) -> str:
-    """Generate customized saved_tensors_hooks
-
-    Args:
-        offload_input (bool, optional): whether we need offload input, if offload_input=False,
-        we will use self.pack_hook_no_input instead. Defaults to True.
-
-    Returns:
-        str: generated context
-    """
-    if offload_input:
-        context = "with torch.autograd.graph.saved_tensors_hooks(self.pack_hook_input, self.unpack_hook):\n"
-    else:
-        context = "with torch.autograd.graph.saved_tensors_hooks(self.pack_hook_no_input, self.unpack_hook):\n"
-    return context
-
-def _gen_save_on_cpu_context():
-    """
-    Generate save on cpu context
-    """
-    context = "with torch.autograd.graph.save_on_cpu(pin_memory=True):\n"
-    return context
 
 def _find_input_and_output_nodes(nodes: List[Node]):
     """
...
@@ -112,49 +57,6 @@ def _find_input_and_output_nodes(nodes: List[Node]):
     return input_nodes, output_nodes
 
-def _find_ckpt_regions(nodes: List[Node]):
-    """
-    Find the checkpoint regions given a list of consecutive nodes. The outputs will be list
-    of tuples, each tuple is in the form of (start_index, end_index).
-    """
-    ckpt_nodes = []
-    ckpt_regions = []
-    start = -1
-    end = -1
-    current_region = None
-
-    for idx, node in enumerate(nodes):
-        if hasattr(node, 'activation_checkpoint'):
-            act_ckpt_label = node.activation_checkpoint
-
-            # this activation checkpoint label is not set yet
-            # meaning this is the first node of the activation ckpt region
-            if current_region is None:
-                current_region = act_ckpt_label
-                start = idx
-
-            # if activation checkpoint has changed
-            # we restart the tracking
-            # e.g. node ckpt states = [ckpt1, ckpt2, ckpt2, ckpt2]
-            if act_ckpt_label != current_region:
-                assert start != -1
-                ckpt_regions.append((start, idx - 1))
-                current_region = act_ckpt_label
-                start = idx
-                end = -1
-        elif current_region is not None and not hasattr(node, 'activation_checkpoint'):
-            # used to check the case below
-            # node ckpt states = [ckpt, ckpt, non-ckpt]
-            end = idx - 1
-            assert start != -1 and end != -1
-            ckpt_regions.append((start, end))
-            start = end = -1
-            current_region = None
-        else:
-            pass
-    return ckpt_regions
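For reference, the removed _find_ckpt_regions turned per-node activation_checkpoint labels into (start, end) index pairs. A toy trace of its behavior on a hypothetical label sequence:

# labels per node: [None, 0, 0, 1, 1, None]   (node indices 0..5)
# _find_ckpt_regions returns [(1, 2), (3, 4)]:
# region (1, 2) covers the nodes labelled 0, region (3, 4) those labelled 1.
# Note the loop as shown has no flush after iteration, so a region still
# open at the last node would not be appended.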
 
 def _find_offload_regions(nodes: List[Node]):
     """This function is to find the offload regions
     In pofo algorithm, during annotation, we will annotate the offload region with the
...
@@ -400,12 +302,9 @@ def emit_code_with_chunk(body, ckpt_func, nodes, emit_node_func, delete_unused_value_func):
         emit_node_func: function to emit node
         delete_unused_value_func: function to remove the unused value
     """
     ckpt_regions = _find_nested_ckpt_regions(nodes, 0)
     start_idx = [item[0] for item in ckpt_regions]
     end_idx = [item[1] for item in ckpt_regions]
 
     # find the offload regions
-    chunk_regions, chunk_labels = _find_offload_regions(nodes)
+    chunk_regions = [(1, 4)]
     chunk_starts = [item[0] for item in chunk_regions]
     chunk_ends = [item[1] for item in chunk_regions]
     chunk_inputs = []
...
@@ -424,7 +323,7 @@ def emit_code_with_chunk(body, ckpt_func, nodes, emit_node_func, delete_unused_value_func):
     # this flag is to prevent repeated insert of save tensors
     # hooks definition in ckpt_func
     node_idx = 0
-    to_keep = []
+    chunk_var = []
     while node_idx < len(node_list):
         # break if we finish the processing all the nodes
         if node_idx >= len(node_list):
...
@@ -435,28 +334,30 @@ def emit_code_with_chunk(body, ckpt_func, nodes, emit_node_func, delete_unused_value_func):
         node = node_list[node_idx]
 
         if node_idx in chunk_starts:
-            # save chunk input var, dont delete it
-            to_keep.extend(node.args[0].name)
             within_chunk_region = True
-            # add for loop
-            body.append(_gen_loop_5(to_keep[0]))
-            # change first node's input to new chunked var
-            node_args = list(node.args)
-            node_args[0] = 'chunk_tensor'
+            # save chunk input var, dont delete it
+            chunk_var.append(node.args[0].name)
+            # add for loop
+            body.append(_gen_loop_start(chunk_var[0]))
 
         if within_chunk_region:
             emit_node_func(node, body)
+            # replace input var with chunk var
+            if node_idx in chunk_starts:
+                body[-1] = body[-1].replace("(" + chunk_var[0] + ")", '(chunk_tensor)')
             body[-1] = '    ' + body[-1]
-            delete_unused_value_func(node, body, to_keep)
+            delete_unused_value_func(node, body, chunk_var)
         else:
             emit_node_func(node, body)
             if node_idx not in chunk_inputs:
-                delete_unused_value_func(node, body, to_keep)
+                delete_unused_value_func(node, body, chunk_var)
 
         if node_idx in chunk_ends:
-            body.append(_gen_loop_5_final(node.name, to_keep))
-            to_keep = []
+            body.append(_gen_loop_end(node.name, chunk_var))
+            chunk_var = []
             within_chunk_region = False
 
         node_idx += 1
...
@@ -580,9 +481,7 @@ if CODEGEN_AVAILABLE:
                 body.append('\n')
                 return
 
             nodes_to_delete = user_to_last_uses.get(user, [])
-            for n in nodes_to_delete:
-                if n.name in to_keep:
-                    nodes_to_delete.remove(n)
+            nodes_to_delete = [i for i in nodes_to_delete if i.name not in to_keep]
             if len(nodes_to_delete):
                 to_delete_str = ' = '.join([repr(n) for n in nodes_to_delete] + ['None'])
                 body.append(f'; {to_delete_str}\n')
...
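Note on this last hunk: the removed loop called remove() on the very list it was iterating, which can skip elements when consecutive entries match; the list comprehension filters without mutating. A minimal illustration of the pitfall:

xs = ['a', 'b', 'c', 'd']
for x in xs:      # removing during iteration shifts indices past survivors
    xs.remove(x)
print(xs)         # ['b', 'd'] -- not empty, as one might expect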
chunk_codegen_run.py

...
@@ -9,60 +9,39 @@ import colossalai
 from colossalai.utils import free_port
 from colossalai.core import global_context as gpc
 from colossalai.fx.graph_module import ColoGraphModule
 
-try:
-    from chunk_codegen import ChunkCodeGen
-    with_codegen = True
-except:
-    # fall back to older pytorch version
-    from chunk_codegen import python_code_with_activation_checkpoint
-    with_codegen = False
-
-
-class MyNet(torch.nn.Module):
-
-    def __init__(self) -> None:
-        super().__init__()
-        self.linear0 = torch.nn.Linear(4, 4)
-        self.linear1 = torch.nn.Linear(4, 4)
-        self.linear2 = torch.nn.Linear(4, 4)
-        self.linear3 = torch.nn.Linear(4, 4)
-        self.linear4 = torch.nn.Linear(4, 4)
-        self.linear5 = torch.nn.Linear(4, 4)
-        self.linear6 = torch.nn.Linear(4, 4)
-
-    def forward(self, x):
-        x = self.linear0(x)
-        x = self.linear1(x)
-        x = self.linear2(x)
-        x = self.linear3(x)
-        x = self.linear4(x)
-        x = self.linear5(x)
-        x = self.linear6(x)
-        return x
+from evoformer.evoformer import evoformer_base
+from chunk_codegen import ChunkCodeGen
+
+with_codegen = True
 
 
 def _is_all_gradient_close(m: torch.nn.Module, gm: GraphModule) -> bool:
     for m_p, gm_p in zip(m.parameters(), gm.parameters()):
-        if not torch.allclose(m_p.grad, gm_p.grad):
+        if m_p.grad is not None and not torch.allclose(m_p.grad, gm_p.grad):
             return False
     return True
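The added "is not None" guard matters here: a parameter that receives no gradient in one of the two runs would otherwise make torch.allclose raise a TypeError on a None argument rather than failing the comparison cleanly.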
-def _test_fwd_and_bwd(model: torch.nn.Module, gm: ColoGraphModule, data: torch.Tensor):
+def _is_all_param_close(m: torch.nn.Module, gm: GraphModule) -> bool:
+    for m_p, gm_p in zip(m.parameters(), gm.parameters()):
+        if m_p.grad is not None and not torch.allclose(m_p.data, gm_p.data):
+            return False
+    return True
+
+
+def _test_fwd_and_bwd(model: torch.nn.Module, gm: ColoGraphModule, node, pair):
     # test forward
-    non_fx_out = model(data)
-    fx_out = gm(data)
-    print(non_fx_out.shape, fx_out.shape)
-    assert torch.equal(non_fx_out, fx_out), "fx_out doesn't comply with original output"
+    non_fx_out = model(node.clone(), pair.clone())
+    fx_out = gm(node.clone(), pair.clone())
+    assert torch.equal(non_fx_out[0], fx_out[0]), "fx_out doesn't comply with original output"
+    assert torch.equal(non_fx_out[1], fx_out[1]), "fx_out doesn't comply with original output"
 
     # test backward
-    loss0 = non_fx_out.sum()
-    loss0.backward()
-    loss1 = fx_out.sum()
-    loss1.backward()
-    assert _is_all_gradient_close(model, gm), "gm doesn't have the same gradient as original one"
+    # loss0 = non_fx_out[0].sum() + non_fx_out[1].sum()
+    # loss0.backward()
+    # loss1 = fx_out[0].sum() + fx_out[1].sum()
+    # loss1.backward()
+    # assert _is_all_param_close(model, gm)
+    # assert _is_all_gradient_close(model, gm), "gm doesn't have the same gradient as original one"
 
 
def _run_offload_codegen(rank):
...
@@ -70,30 +49,22 @@ def _run_offload_codegen(rank):
     colossalai.launch(config={}, rank=rank, world_size=1, host='localhost', port=free_port(), backend='nccl')
 
     # build model and input
-    model = MyNet().cuda()
-    data = torch.rand(4, 4).cuda()
+    model = evoformer_base().cuda()
+    node = torch.randn(1, 16, 32, 256).cuda()
+    pair = torch.randn(1, 32, 32, 128).cuda()
 
     # trace the module and replace codegen
     tracer = ColoTracer(trace_act_ckpt=True)
     graph = tracer.trace(model)
     codegen = ChunkCodeGen()
     graph.set_codegen(codegen)
 
-    # annotate the activation offload part
-    # also annotate the activation_checkpoint so we could test both types
-    # of input offload
-    for node in graph.nodes:
-        if node.name == "linear0":
-            setattr(node, "activation_offload", [0, True, False])
-        if node.name == "linear1":
-            setattr(node, "activation_offload", [0, True, False])
-        # if node.name == "linear2":
-        #     setattr(node, "activation_offload", [1, True, True])
-        # if node.name == "linear4":
-        #     setattr(node, "activation_offload", [2, False, True])
-        # if node.name == "linear5":
-        #     setattr(node, "activation_checkpoint", [0])
-        #     setattr(node, "activation_offload", True)
+    # codegen = ChunkCodeGen()
+    # graph.set_codegen(codegen)
+    # annotate the chunk part
+    # for node in graph.nodes:
+    #     if node.name == "linear0":
+    #         setattr(node, "activation_offload", [0, True, False])
+    #     if node.name == "linear1":
+    #         setattr(node, "activation_offload", [0, True, False])
 
     gm = ColoGraphModule(copy.deepcopy(model), graph)
     gm.recompile()
...
@@ -102,7 +73,7 @@ def _run_offload_codegen(rank):
     code = graph.python_code("self").src
     print(code)
 
-    _test_fwd_and_bwd(model, gm, data)
+    _test_fwd_and_bwd(model, gm, node, pair)
     gpc.destroy()
...
evoformer/evoformer.py

...
@@ -28,7 +28,7 @@ class Evoformer(nn.Module):
         super(Evoformer, self).__init__()
         self.blocks = nn.ModuleList()
-        for _ in range(3):
+        for _ in range(1):
             self.blocks.append(EvoformerBlock(d_node, d_pair))
 
     def forward(self, node, pair):
...
@@ -36,6 +36,11 @@ class Evoformer(nn.Module):
             node, pair = b(node, pair)
         return node, pair
 
+
+def evoformer_tiny():
+    return Evoformer(d_node=64, d_pair=32)
+
+def evoformer_base():
+    return Evoformer(d_node=256, d_pair=128)
...
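The new factory functions pair with the shapes used in chunk_codegen_run.py: node is (batch, n_seq, n_res, d_node) and pair is (batch, n_res, n_res, d_pair); these dimension names are inferred from the test inputs, not documented in this file. A minimal CPU smoke test of evoformer_base under those assumptions:

import torch
from evoformer.evoformer import evoformer_base

model = evoformer_base()            # d_node=256, d_pair=128
node = torch.randn(1, 16, 32, 256)  # (batch, n_seq, n_res, d_node)
pair = torch.randn(1, 32, 32, 128)  # (batch, n_res, n_res, d_pair)
node_out, pair_out = model(node, pair)
# the blocks are residual, so shapes should be preserved
assert node_out.shape == node.shape and pair_out.shape == pair.shape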
evoformer/kernel.py

...
@@ -8,7 +8,7 @@ def bias_sigmod_ele(y, bias, z):
 def bias_dropout_add(x: torch.Tensor, bias: torch.Tensor, dropmask: torch.Tensor,
                      residual: torch.Tensor, prob: float) -> torch.Tensor:
-    out = (x + bias) * F.dropout(dropmask, p=prob, training=True)
+    out = (x + bias) * F.dropout(dropmask, p=prob, training=False)
     out = residual + out
     return out
...
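With training=False, F.dropout is the identity, so this change effectively disables the random mask; presumably that is what lets the eager model match the traced/chunked version bit-for-bit in the tests above, though the commit does not say so. A quick check of the identity property:

import torch
import torch.nn.functional as F

mask = torch.ones(2, 3)
# training=False returns the input unchanged
assert torch.equal(F.dropout(mask, p=0.25, training=False), mask)
# with training=True, surviving entries would instead be scaled by 1/(1-p)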
evoformer/msa.py

...
@@ -45,7 +45,7 @@ class MSARowAttentionWithPairBias(nn.Module):
         # b = rearrange(b, 'b q k h -> b h q k')
         M = self.attention(M, b)
-        dropout_mask = torch.ones_like(M[:, 0:1, :, :], device=M.device, dtype=M.dtype)
+        dropout_mask = torch.ones_like(M[:, 0:1, :, :]).to(M.device).to(M.dtype)
         return bias_dropout_add(M, self.out_bias, dropout_mask, M_raw, prob=self.p_drop)
...
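The two constructions are numerically equivalent, here and in the four parallel changes in evoformer/triangle.py below; the commit does not state the motivation, but plausibly the .to() chain keeps the call uniform with how the fx codegen handles these sites. A sketch verifying the equivalence:

import torch

M = torch.randn(1, 4, 8, 16)
a = torch.ones_like(M[:, 0:1, :, :], device=M.device, dtype=M.dtype)
b = torch.ones_like(M[:, 0:1, :, :]).to(M.device).to(M.dtype)
assert torch.equal(a, b)  # same values, device, and dtype either way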
evoformer/triangle.py

...
@@ -51,7 +51,7 @@ class TriangleMultiplicationOutgoing(nn.Module):
         ab = torch.einsum('bikd,bjkd->bijd', left_proj_act, right_proj_act)
         ab = self.output_projection(self.layernorm2(ab))
-        dropout_mask = torch.ones_like(Z[:, 0:1, :, :], device=Z.device, dtype=Z.dtype)
+        dropout_mask = torch.ones_like(Z[:, 0:1, :, :]).to(Z.device).to(Z.dtype)
         return bias_ele_dropout_residual(ab, self.output_bias, g,
...
@@ -97,7 +97,7 @@ class TriangleMultiplicationIncoming(nn.Module):
         ab = torch.einsum('bkid,bkjd->bijd', left_proj_act, right_proj_act)
         ab = self.output_projection(self.layernorm2(ab))
-        dropout_mask = torch.ones_like(Z[:, 0:1, :, :], device=Z.device, dtype=Z.dtype)
+        dropout_mask = torch.ones_like(Z[:, 0:1, :, :]).to(Z.device).to(Z.dtype)
         return bias_ele_dropout_residual(ab, self.output_bias, g,
...
@@ -134,7 +134,7 @@ class TriangleAttentionStartingNode(nn.Module):
         Z = self.attention(Z, b)
-        dropout_mask = torch.ones_like(Z[:, 0:1, :, :], device=Z.device, dtype=Z.dtype)
+        dropout_mask = torch.ones_like(Z[:, 0:1, :, :]).to(Z.device).to(Z.dtype)
         return bias_dropout_add(Z, self.out_bias, dropout_mask, Z_raw, prob=self.p_drop)
...
@@ -168,7 +168,7 @@ class TriangleAttentionEndingNode(nn.Module):
         Z = self.attention(Z, b)
         Z = Z.transpose(-2, -3)
-        dropout_mask = torch.ones_like(Z[:, :, 0:1, :], device=Z.device, dtype=Z.dtype)
+        dropout_mask = torch.ones_like(Z[:, :, 0:1, :]).to(Z.device).to(Z.dtype)
         return bias_dropout_add(Z, self.out_bias, dropout_mask, Z_raw, prob=self.p_drop)
...