Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
c35718e8
Commit
c35718e8
authored
Nov 04, 2022
by
oahzxl
Browse files
basic chunk
parent
f8aeecef
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
95 additions
and
45 deletions
+95
-45
chunk_codegen.py
chunk_codegen.py
+94
-44
chunk_codegen_run.py
chunk_codegen_run.py
+1
-1
No files found.
chunk_codegen.py
View file @
c35718e8
...
@@ -18,16 +18,61 @@ else:
...
@@ -18,16 +18,61 @@ else:
__all__
=
[
'python_code_with_activation_checkpoint'
]
__all__
=
[
'python_code_with_activation_checkpoint'
]
def
_gen_loop_start
(
to_keep
,
chunk_size
=
2
):
def
_gen_chunk_slice_dim
(
chunk_dim
,
chunk_idx_name
,
shape
):
context
=
"chunk_result = []; chunk_size = %d
\n
for gen_loop_idx in range(0, %s.shape[0], chunk_size):
\n
"
%
(
chunk_size
,
to_keep
[
0
])
new_shape
=
"["
context
+=
" chunk_tensor = "
+
to_keep
+
"[gen_loop_idx:gen_loop_idx + chunk_size, :]
\n
"
for
idx
,
i
in
enumerate
(
shape
):
if
idx
==
chunk_dim
:
new_shape
+=
"%s:%s + chunk_size"
%
(
chunk_idx_name
,
chunk_idx_name
)
else
:
new_shape
+=
":"
new_shape
+=
", "
new_shape
=
new_shape
[:
-
2
]
+
"]"
return
new_shape
def
_get_first_non_single_dim
(
shape
):
for
idx
,
i
in
enumerate
(
shape
):
if
i
==
1
:
continue
else
:
return
idx
raise
RuntimeError
(
"can not get first non single dim for shape"
,
shape
)
def _gen_loop_start(chunk_input_meta, chunk_output, chunk_size=2):
    """Emit the source-code prologue for a chunked region.

    Generates (as a string of Python source) the allocation of the result
    buffer and the ``for`` loop that iterates over the chunked dimension of
    the single input node, plus the slicing of the per-iteration
    ``chunk_tensor``.

    Args:
        chunk_input_meta: list of meta nodes feeding the region; only a
            single input is currently supported.
        chunk_output: the node produced by the region; its
            ``meta['tensor_meta'].shape`` sizes the preallocated result.
        chunk_size: number of rows processed per loop iteration.

    Raises:
        NotImplementedError: when the region has more than one input.
    """
    if len(chunk_input_meta) != 1:
        raise NotImplementedError("input with size %d not implemented" % len(chunk_input_meta))
    input_node = chunk_input_meta[0]
    input_shape = input_node.meta['tensor_meta'].shape
    # Chunk along the first axis whose extent is > 1.
    chunk_dim = _get_first_non_single_dim(input_shape)
    chunk_slice = _gen_chunk_slice_dim(chunk_dim, "gen_chunk_idx", input_shape)
    out_shape = str(list(chunk_output.meta['tensor_meta'].shape))
    # NOTE(review): internal spacing of these generated-code strings is
    # reconstructed (4-space loop-body indent) — confirm against upstream.
    context = "chunk_result = torch.empty(%s, dtype=%s.dtype, device=%s.device); chunk_size = %d\nfor gen_chunk_idx in range" % (
        out_shape, input_node.name, input_node.name, chunk_size)
    context += "(0, %s.shape[%d], chunk_size):\n" % (input_node.name, chunk_dim)
    context += "    chunk_tensor = %s%s\n" % (input_node.name, chunk_slice)
    return context
def _gen_loop_end(chunk_outputs, chunk_inputs, node_list):
    """Emit the source-code epilogue that closes a chunked region.

    Generates (as a string of Python source) the write-back of the loop
    result into ``chunk_result`` at the current chunk slice, the rebinding
    of the region's output name to the assembled buffer, and the release of
    the loop bookkeeping variables.  If the region's input has no user
    after the region's output node, the input name is released as well.

    Args:
        chunk_outputs: the node that the region produces.
        chunk_inputs: list of nodes feeding the region; only element 0 is
            inspected.
        node_list: full ordered node list, used to resolve node positions.
    """
    region_input_name = chunk_inputs[0].name
    region_output_name = chunk_outputs.name
    region_output_idx = _find_idx_by_name(region_output_name, node_list)
    output_shape = chunk_outputs.meta['tensor_meta'].shape
    chunk_dim = _get_first_non_single_dim(output_shape)
    chunk_slice = _gen_chunk_slice_dim(chunk_dim, "gen_chunk_idx", output_shape)
    # NOTE(review): leading spaces of the generated loop-body line are
    # reconstructed (4-space indent) — confirm against upstream.
    context = "    chunk_result%s = %s\n" % (chunk_slice, region_output_name)
    context += region_output_name + " = chunk_result; chunk_result = None; chunk_size = None"
    # Release the region input too when its last use is at or before the
    # region output, i.e. nothing after the loop still needs it.
    input_users = list(chunk_inputs[0].users.keys())
    if all(_find_idx_by_name(user.name, node_list) <= region_output_idx for user in input_users):
        context += "; %s = None" % region_input_name
    context += "\n"
    return context
...
@@ -44,7 +89,7 @@ def _find_input_and_output_nodes(nodes: List[Node]):
...
@@ -44,7 +89,7 @@ def _find_input_and_output_nodes(nodes: List[Node]):
for
input_node
in
node
.
_input_nodes
.
keys
():
for
input_node
in
node
.
_input_nodes
.
keys
():
node_repr
=
repr
(
input_node
)
node_repr
=
repr
(
input_node
)
if
input_node
not
in
nodes
and
node_repr
not
in
input_nodes
:
if
input_node
not
in
nodes
and
node_repr
not
in
input_nodes
:
input_nodes
.
append
(
node_repr
)
input_nodes
.
append
(
input_node
)
# if a node has a user node which is not in the node list
# if a node has a user node which is not in the node list
# we treat that user node as the node receiving the current node output
# we treat that user node as the node receiving the current node output
...
@@ -52,11 +97,18 @@ def _find_input_and_output_nodes(nodes: List[Node]):
...
@@ -52,11 +97,18 @@ def _find_input_and_output_nodes(nodes: List[Node]):
for
output_node
in
node
.
users
.
keys
():
for
output_node
in
node
.
users
.
keys
():
node_repr
=
repr
(
node
)
node_repr
=
repr
(
node
)
if
output_node
not
in
nodes
and
node_repr
not
in
output_nodes
:
if
output_node
not
in
nodes
and
node_repr
not
in
output_nodes
:
output_nodes
.
append
(
node_repr
)
output_nodes
.
append
(
output_node
)
return
input_nodes
,
output_nodes
return
input_nodes
,
output_nodes
def
_find_idx_by_name
(
name
,
nodes_list
):
for
idx
,
node
in
enumerate
(
nodes_list
):
if
node
.
name
==
name
:
return
idx
raise
RuntimeError
(
"name %s not found in node list"
%
name
)
def
_find_offload_regions
(
nodes
:
List
[
Node
]):
def
_find_offload_regions
(
nodes
:
List
[
Node
]):
"""This function is to find the offload regions
"""This function is to find the offload regions
In pofo algorithm, during annotation, we will annotate the offload region with the
In pofo algorithm, during annotation, we will annotate the offload region with the
...
@@ -290,7 +342,7 @@ def emit_ckpt_func(body,
...
@@ -290,7 +342,7 @@ def emit_ckpt_func(body,
body
.
append
(
usage
)
body
.
append
(
usage
)
def
emit_code_with_chunk
(
body
,
ckpt_func
,
nodes
,
emit_node_func
,
delete_unused_value_func
):
def
emit_code_with_chunk
(
body
,
ckpt_func
,
nodes
,
emit_node_func
,
delete_unused_value_func
,
meta_nodes
):
"""Emit code with nested activation checkpoint
"""Emit code with nested activation checkpoint
When we detect some of the node.activation_checkpoint is a List, we will use
When we detect some of the node.activation_checkpoint is a List, we will use
this function to emit the activation checkpoint codes.
this function to emit the activation checkpoint codes.
...
@@ -304,7 +356,7 @@ def emit_code_with_chunk(body, ckpt_func, nodes, emit_node_func, delete_unused_v
...
@@ -304,7 +356,7 @@ def emit_code_with_chunk(body, ckpt_func, nodes, emit_node_func, delete_unused_v
"""
"""
# find the offload regions
# find the offload regions
chunk_regions
=
[(
1
,
4
)]
chunk_regions
=
[(
2
,
5
)]
chunk_starts
=
[
item
[
0
]
for
item
in
chunk_regions
]
chunk_starts
=
[
item
[
0
]
for
item
in
chunk_regions
]
chunk_ends
=
[
item
[
1
]
for
item
in
chunk_regions
]
chunk_ends
=
[
item
[
1
]
for
item
in
chunk_regions
]
chunk_inputs
=
[]
chunk_inputs
=
[]
...
@@ -319,46 +371,44 @@ def emit_code_with_chunk(body, ckpt_func, nodes, emit_node_func, delete_unused_v
...
@@ -319,46 +371,44 @@ def emit_code_with_chunk(body, ckpt_func, nodes, emit_node_func, delete_unused_v
inputs
,
outputs
=
_find_input_and_output_nodes
(
offload_node_list
)
inputs
,
outputs
=
_find_input_and_output_nodes
(
offload_node_list
)
chunk_inputs
.
append
(
inputs
)
chunk_inputs
.
append
(
inputs
)
chunk_outputs
.
append
(
outputs
)
chunk_outputs
.
append
(
outputs
)
chunk_inputs_idx
=
[[
_find_idx_by_name
(
j
.
name
,
node_list
)
for
j
in
i
]
for
i
in
chunk_inputs
]
chunk_outputs_idx
=
[[
_find_idx_by_name
(
j
.
name
,
node_list
)
for
j
in
i
]
for
i
in
chunk_outputs
]
chunk_inputs_names
=
[]
for
i
in
chunk_inputs
:
for
j
in
i
:
chunk_inputs_names
.
append
(
j
.
name
)
# this flag is to prevent repeated insert of save tensors
# this flag is to prevent repeated insert of save tensors
# hooks definition in ckpt_func
# hooks definition in ckpt_func
node_idx
=
0
node_idx
=
0
chunk_var
=
[]
region_idx
=
0
while
node_idx
<
len
(
node_list
):
while
node_idx
<
len
(
node_list
):
# break if we finish the processing all the nodes
if
node_idx
>=
len
(
node_list
):
break
# process node in forward function
else
:
node
=
node_list
[
node_idx
]
node
=
node_list
[
node_idx
]
if
node_idx
in
chunk_starts
:
if
node_idx
in
chunk_starts
:
within_chunk_region
=
True
within_chunk_region
=
True
# save chunk input var, dont delete it
chunk_var
.
append
(
node
.
args
[
0
].
name
)
# add for loop
# add for loop
body
.
append
(
_gen_loop_start
(
chunk_var
[
0
]))
chunk_input_meta
=
[
meta_nodes
[
i
]
for
i
in
chunk_inputs_idx
[
region_idx
]]
body
.
append
(
_gen_loop_start
(
chunk_input_meta
,
node_list
[
chunk_ends
[
region_idx
]]))
if
within_chunk_region
:
if
within_chunk_region
:
emit_node_func
(
node
,
body
)
emit_node_func
(
node
,
body
)
# replace input var with chunk var
# replace input var with chunk var
if
node_idx
in
chunk_starts
:
if
node_idx
in
chunk_starts
:
body
[
-
1
]
=
body
[
-
1
].
replace
(
"("
+
chunk_
var
[
0
]
+
")"
,
'(chunk_tensor)'
)
body
[
-
1
]
=
body
[
-
1
].
replace
(
"("
+
chunk_
inputs
[
region_idx
][
0
].
name
+
")"
,
'(chunk_tensor)'
)
body
[
-
1
]
=
' '
+
body
[
-
1
]
body
[
-
1
]
=
' '
+
body
[
-
1
]
delete_unused_value_func
(
node
,
body
,
chunk_
var
)
delete_unused_value_func
(
node
,
body
,
chunk_
inputs_names
)
else
:
else
:
emit_node_func
(
node
,
body
)
emit_node_func
(
node
,
body
)
if
node_idx
not
in
chunk_inputs
:
if
node_idx
not
in
chunk_inputs
:
delete_unused_value_func
(
node
,
body
,
chunk_
var
)
delete_unused_value_func
(
node
,
body
,
chunk_
inputs_names
)
if
node_idx
in
chunk_ends
:
if
node_idx
in
chunk_ends
:
body
.
append
(
_gen_loop_end
(
node
.
name
,
chunk_var
))
body
.
append
(
_gen_loop_end
(
node
,
chunk_inputs
[
region_idx
],
node_list
))
chunk_var
=
[]
within_chunk_region
=
False
within_chunk_region
=
False
region_idx
+=
1
node_idx
+=
1
node_idx
+=
1
...
@@ -562,7 +612,7 @@ if CODEGEN_AVAILABLE:
...
@@ -562,7 +612,7 @@ if CODEGEN_AVAILABLE:
# if any node has a list of labels for activation_checkpoint, we
# if any node has a list of labels for activation_checkpoint, we
# will use nested type of activation checkpoint codegen
# will use nested type of activation checkpoint codegen
emit_code_with_chunk
(
body
,
ckpt_func
,
nodes
,
emit_node
,
delete_unused_values
)
emit_code_with_chunk
(
body
,
ckpt_func
,
nodes
,
emit_node
,
delete_unused_values
,
self
.
meta_node
)
if
len
(
body
)
==
0
:
if
len
(
body
)
==
0
:
# If the Graph has no non-placeholder nodes, no lines for the body
# If the Graph has no non-placeholder nodes, no lines for the body
...
...
chunk_codegen_run.py
View file @
c35718e8
...
@@ -70,7 +70,7 @@ def _run_offload_codegen(rank):
...
@@ -70,7 +70,7 @@ def _run_offload_codegen(rank):
# setattr(node, "activation_offload", [0, True, False])
# setattr(node, "activation_offload", [0, True, False])
codegen
=
ChunkCodeGen
(
gm_prop
)
codegen
=
ChunkCodeGen
(
gm_prop
)
#
graph.set_codegen(codegen)
graph
.
set_codegen
(
codegen
)
gm
=
ColoGraphModule
(
model
,
graph
)
gm
=
ColoGraphModule
(
model
,
graph
)
gm
.
recompile
()
gm
.
recompile
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment