Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
c35718e8
"vscode:/vscode.git/clone" did not exist on "2edbef13cc2f08e3d74ea72a68d2299f3e7cdbb7"
Commit
c35718e8
authored
Nov 04, 2022
by
oahzxl
Browse files
basic chunk
parent
f8aeecef
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
95 additions
and
45 deletions
+95
-45
chunk_codegen.py
chunk_codegen.py
+94
-44
chunk_codegen_run.py
chunk_codegen_run.py
+1
-1
No files found.
chunk_codegen.py
View file @
c35718e8
...
...
@@ -18,16 +18,61 @@ else:
__all__
=
[
'python_code_with_activation_checkpoint'
]
def
_gen_loop_start
(
to_keep
,
chunk_size
=
2
):
context
=
"chunk_result = []; chunk_size = %d
\n
for gen_loop_idx in range(0, %s.shape[0], chunk_size):
\n
"
%
(
chunk_size
,
to_keep
[
0
])
context
+=
" chunk_tensor = "
+
to_keep
+
"[gen_loop_idx:gen_loop_idx + chunk_size, :]
\n
"
def
_gen_chunk_slice_dim
(
chunk_dim
,
chunk_idx_name
,
shape
):
new_shape
=
"["
for
idx
,
i
in
enumerate
(
shape
):
if
idx
==
chunk_dim
:
new_shape
+=
"%s:%s + chunk_size"
%
(
chunk_idx_name
,
chunk_idx_name
)
else
:
new_shape
+=
":"
new_shape
+=
", "
new_shape
=
new_shape
[:
-
2
]
+
"]"
return
new_shape
def
_get_first_non_single_dim
(
shape
):
for
idx
,
i
in
enumerate
(
shape
):
if
i
==
1
:
continue
else
:
return
idx
raise
RuntimeError
(
"can not get first non single dim for shape"
,
shape
)
def _gen_loop_start(chunk_input_meta, chunk_output, chunk_size=2):
    """Generate the source text that opens a chunked ``for`` loop.

    The returned text allocates ``chunk_result`` with ``torch.empty`` using the
    shape recorded in ``chunk_output.meta['tensor_meta']`` and the dtype/device
    of the input variable, assigns ``chunk_size``, opens a loop over
    ``gen_chunk_idx`` along the first non-1 dimension of the input, and binds
    ``chunk_tensor`` to the current slice of the input.

    Args:
        chunk_input_meta: sequence of input nodes; exactly one is supported.
            The node must expose ``name`` and ``meta['tensor_meta'].shape``.
        chunk_output: node whose ``meta['tensor_meta'].shape`` gives the shape
            of the pre-allocated result.
        chunk_size: step of the generated loop.

    Returns:
        The generated source text, ending with a newline.

    Raises:
        NotImplementedError: if ``chunk_input_meta`` does not hold exactly one
            entry.
    """
    input_count = len(chunk_input_meta)
    if input_count != 1:
        raise NotImplementedError("input with size %d not implemented" % input_count)

    src = chunk_input_meta[0]
    src_shape = src.meta['tensor_meta'].shape
    split_dim = _get_first_non_single_dim(src_shape)
    src_slice = _gen_chunk_slice_dim(split_dim, "gen_chunk_idx", src_shape)
    result_shape = str(list(chunk_output.meta['tensor_meta'].shape))

    pieces = [
        # allocation of the full result plus the chunk size, then the loop header
        "chunk_result = torch.empty(%s, dtype=%s.dtype, device=%s.device); chunk_size = %d\nfor gen_chunk_idx in range"
        % (result_shape, src.name, src.name, chunk_size),
        "(0, %s.shape[%d], chunk_size):\n" % (src.name, split_dim),
        # first statement of the loop body: slice the input for this iteration
        " chunk_tensor = %s%s\n" % (src.name, src_slice),
    ]
    return "".join(pieces)
def
_gen_loop_end
(
final_name
,
to_keep
):
context
=
" chunk_result.append("
+
final_name
+
")
\n
"
context
+=
"chunk_result = torch.cat(chunk_result, dim=0); "
+
to_keep
[
0
]
+
" = None
\n
"
context
+=
final_name
+
" = chunk_result; chunk_result = None
\n
"
def _gen_loop_end(chunk_outputs, chunk_inputs, node_list):
    """Generate the source text that closes a chunked ``for`` loop.

    The returned text first writes the per-iteration value of the output
    variable into the matching slice of ``chunk_result`` (last statement of the
    loop body), then rebinds the output variable to ``chunk_result`` and clears
    ``chunk_result`` and ``chunk_size``. If no user of the first chunk input
    sits after the output node in ``node_list``, the input variable is cleared
    as well.

    Args:
        chunk_outputs: the output node of the chunk region; must expose
            ``name`` and ``meta['tensor_meta'].shape``.
        chunk_inputs: sequence of input nodes; only the first one is used.
        node_list: ordered list of nodes used to resolve positions by name.

    Returns:
        The generated source text, ending with a newline.

    Raises:
        RuntimeError: propagated from the name lookup when the output node or
            any user of the first input is not present in ``node_list``.
    """
    source_node = chunk_inputs[0]
    source_name = source_node.name
    result_name = chunk_outputs.name
    result_pos = _find_idx_by_name(result_name, node_list)
    result_shape = chunk_outputs.meta['tensor_meta'].shape
    split_dim = _get_first_non_single_dim(result_shape)
    result_slice = _gen_chunk_slice_dim(split_dim, "gen_chunk_idx", result_shape)

    pieces = [
        " chunk_result%s = %s\n" % (result_slice, result_name),
        result_name + " = chunk_result; chunk_result = None; chunk_size = None",
    ]

    # Decide whether this is the last use of the chunk input. All positions are
    # resolved up front so that a user missing from node_list always raises.
    consumer_positions = [
        _find_idx_by_name(consumer.name, node_list) for consumer in source_node.users.keys()
    ]
    if all(pos <= result_pos for pos in consumer_positions):
        pieces.append("; %s = None" % source_name)

    pieces.append("\n")
    return "".join(pieces)
...
...
@@ -44,7 +89,7 @@ def _find_input_and_output_nodes(nodes: List[Node]):
for
input_node
in
node
.
_input_nodes
.
keys
():
node_repr
=
repr
(
input_node
)
if
input_node
not
in
nodes
and
node_repr
not
in
input_nodes
:
input_nodes
.
append
(
node_repr
)
input_nodes
.
append
(
input_node
)
# if a node has a user node which is not in the node list
# we treat that user node as the node receiving the current node output
...
...
@@ -52,11 +97,18 @@ def _find_input_and_output_nodes(nodes: List[Node]):
for
output_node
in
node
.
users
.
keys
():
node_repr
=
repr
(
node
)
if
output_node
not
in
nodes
and
node_repr
not
in
output_nodes
:
output_nodes
.
append
(
node_repr
)
output_nodes
.
append
(
output_node
)
return
input_nodes
,
output_nodes
def
_find_idx_by_name
(
name
,
nodes_list
):
for
idx
,
node
in
enumerate
(
nodes_list
):
if
node
.
name
==
name
:
return
idx
raise
RuntimeError
(
"name %s not found in node list"
%
name
)
def
_find_offload_regions
(
nodes
:
List
[
Node
]):
"""This function is to find the offload regions
In pofo algorithm, during annotation, we will annotate the offload region with the
...
...
@@ -290,7 +342,7 @@ def emit_ckpt_func(body,
body
.
append
(
usage
)
def
emit_code_with_chunk
(
body
,
ckpt_func
,
nodes
,
emit_node_func
,
delete_unused_value_func
):
def
emit_code_with_chunk
(
body
,
ckpt_func
,
nodes
,
emit_node_func
,
delete_unused_value_func
,
meta_nodes
):
"""Emit code with nested activation checkpoint
When we detect some of the node.activation_checkpoint is a List, we will use
this function to emit the activation checkpoint codes.
...
...
@@ -304,7 +356,7 @@ def emit_code_with_chunk(body, ckpt_func, nodes, emit_node_func, delete_unused_v
"""
# find the offload regions
chunk_regions
=
[(
1
,
4
)]
chunk_regions
=
[(
2
,
5
)]
chunk_starts
=
[
item
[
0
]
for
item
in
chunk_regions
]
chunk_ends
=
[
item
[
1
]
for
item
in
chunk_regions
]
chunk_inputs
=
[]
...
...
@@ -319,48 +371,46 @@ def emit_code_with_chunk(body, ckpt_func, nodes, emit_node_func, delete_unused_v
inputs
,
outputs
=
_find_input_and_output_nodes
(
offload_node_list
)
chunk_inputs
.
append
(
inputs
)
chunk_outputs
.
append
(
outputs
)
chunk_inputs_idx
=
[[
_find_idx_by_name
(
j
.
name
,
node_list
)
for
j
in
i
]
for
i
in
chunk_inputs
]
chunk_outputs_idx
=
[[
_find_idx_by_name
(
j
.
name
,
node_list
)
for
j
in
i
]
for
i
in
chunk_outputs
]
chunk_inputs_names
=
[]
for
i
in
chunk_inputs
:
for
j
in
i
:
chunk_inputs_names
.
append
(
j
.
name
)
# this flag is to prevent repeated insert of save tensors
# hooks definition in ckpt_func
node_idx
=
0
chunk_var
=
[]
region_idx
=
0
while
node_idx
<
len
(
node_list
):
# break if we finish the processing all the nodes
if
node_idx
>=
len
(
node_list
):
break
node
=
node_list
[
node_idx
]
# process node in forward function
else
:
node
=
node_list
[
node_idx
]
if
node_idx
in
chunk_starts
:
within_chunk_region
=
True
# add for loop
chunk_input_meta
=
[
meta_nodes
[
i
]
for
i
in
chunk_inputs_idx
[
region_idx
]]
body
.
append
(
_gen_loop_start
(
chunk_input_meta
,
node_list
[
chunk_ends
[
region_idx
]]))
if
within_chunk_region
:
emit_node_func
(
node
,
body
)
# replace input var with chunk var
if
node_idx
in
chunk_starts
:
within_chunk_region
=
True
# save chunk input var, dont delete it
chunk_var
.
append
(
node
.
args
[
0
].
name
)
# add for loop
body
.
append
(
_gen_loop_start
(
chunk_var
[
0
]))
if
within_chunk_region
:
emit_node_func
(
node
,
body
)
# replace input var with chunk var
if
node_idx
in
chunk_starts
:
body
[
-
1
]
=
body
[
-
1
].
replace
(
"("
+
chunk_var
[
0
]
+
")"
,
'(chunk_tensor)'
)
body
[
-
1
]
=
' '
+
body
[
-
1
]
delete_unused_value_func
(
node
,
body
,
chunk_var
)
body
[
-
1
]
=
body
[
-
1
].
replace
(
"("
+
chunk_inputs
[
region_idx
][
0
].
name
+
")"
,
'(chunk_tensor)'
)
body
[
-
1
]
=
' '
+
body
[
-
1
]
delete_unused_value_func
(
node
,
body
,
chunk_inputs_names
)
else
:
emit_node_func
(
node
,
body
)
if
node_idx
not
in
chunk_inputs
:
delete_unused_value_func
(
node
,
body
,
chunk_
var
)
else
:
emit_node_func
(
node
,
body
)
if
node_idx
not
in
chunk_inputs
:
delete_unused_value_func
(
node
,
body
,
chunk_
inputs_names
)
if
node_idx
in
chunk_ends
:
body
.
append
(
_gen_loop_end
(
node
.
name
,
chunk_
var
))
chunk_var
=
[]
within_chunk_region
=
False
if
node_idx
in
chunk_ends
:
body
.
append
(
_gen_loop_end
(
node
,
chunk_
inputs
[
region_idx
],
node_list
))
within_chunk_region
=
False
region_idx
+=
1
node_idx
+=
1
node_idx
+=
1
if
CODEGEN_AVAILABLE
:
...
...
@@ -562,7 +612,7 @@ if CODEGEN_AVAILABLE:
# if any node has a list of labels for activation_checkpoint, we
# will use nested type of activation checkpoint codegen
emit_code_with_chunk
(
body
,
ckpt_func
,
nodes
,
emit_node
,
delete_unused_values
)
emit_code_with_chunk
(
body
,
ckpt_func
,
nodes
,
emit_node
,
delete_unused_values
,
self
.
meta_node
)
if
len
(
body
)
==
0
:
# If the Graph has no non-placeholder nodes, no lines for the body
...
...
chunk_codegen_run.py
View file @
c35718e8
...
...
@@ -70,7 +70,7 @@ def _run_offload_codegen(rank):
# setattr(node, "activation_offload", [0, True, False])
codegen
=
ChunkCodeGen
(
gm_prop
)
#
graph.set_codegen(codegen)
graph
.
set_codegen
(
codegen
)
gm
=
ColoGraphModule
(
model
,
graph
)
gm
.
recompile
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment