Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
78cfe436
Commit
78cfe436
authored
Nov 02, 2022
by
oahzxl
Browse files
basic chunk
parent
87cddf7e
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
41 additions
and
40 deletions
+41
-40
chunk_codegen.py
chunk_codegen.py
+33
-33
chunk_codegen_run.py
chunk_codegen_run.py
+8
-7
No files found.
chunk_codegen.py
View file @
78cfe436
...
...
@@ -46,6 +46,19 @@ def pack_hook_no_input(self, x):
return
pack_hook
,
unpack_hook
def
_gen_loop_5
(
to_keep
):
context
=
"chunk_result = []
\n
for gen_loop_idx in range(4):
\n
"
context
+=
" chunk_tensor = "
+
to_keep
+
"[gen_loop_idx, :]
\n
"
return
context
def
_gen_loop_5_final
(
final_name
,
to_keep
):
context
=
" chunk_result.append("
+
final_name
+
")
\n
"
context
+=
"chunk_result = torch.cat(chunk_result, dim=0); "
+
to_keep
[
0
]
+
" = None
\n
"
context
+=
final_name
+
" = chunk_result; chunk_result = None
\n
"
return
context
def
_gen_save_tensors_hooks_context
(
offload_input
=
True
)
->
str
:
"""Generate customized saved_tensors_hooks
...
...
@@ -410,57 +423,40 @@ def emit_code_with_chunk(body, ckpt_func, nodes, emit_node_func, delete_unused_v
# this flag is to prevent repeated insert of save tensors
# hooks definition in ckpt_func
is_hook_inserted
=
False
node_idx
=
0
while
1
:
to_keep
=
[]
while
node_idx
<
len
(
node_list
):
# break once we have finished processing all the nodes
if
node_idx
>=
len
(
node_list
):
break
# process ckpt_regions
if
node_idx
in
start_idx
:
ckpt_node_list
=
node_list
[
node_idx
:
end_idx
[
start_idx
.
index
(
node_idx
)]
+
1
]
emit_ckpt_func
(
body
,
ckpt_func
,
ckpt_node_list
,
emit_node_func
,
delete_unused_value_func
)
node_idx
+=
len
(
ckpt_node_list
)
# process node in forward function
else
:
node
=
node_list
[
node_idx
]
if
node_idx
in
chunk_starts
:
chunk_label
=
chunk_labels
[
chunk_starts
.
index
(
node_idx
)]
_
,
chunk_input
,
chunk_bar
=
chunk_label
# save chunk input var, dont delete it
to_keep
.
extend
(
node
.
args
[
0
].
name
)
within_chunk_region
=
True
# insert hook functions if needed
if
not
is_hook_inserted
:
pack_hook
,
unpack_hook
=
_gen_saved_tensors_hooks
()
ckpt_func
.
insert
(
0
,
"
\n
"
.
join
([
pack_hook
,
unpack_hook
])
+
"
\n
"
)
is_hook_inserted
=
True
if
chunk_input
and
chunk_bar
:
body
.
append
(
_gen_save_on_cpu_context
())
elif
chunk_input
:
for
par
in
chunk_inputs
[
chunk_label
[
0
]]:
body
.
append
(
f
"setattr(
{
par
}
, 'offload', True)
\n
"
)
body
.
append
(
_gen_save_tensors_hooks_context
(
offload_input
=
True
))
else
:
for
par
in
chunk_inputs
[
chunk_label
[
0
]]:
body
.
append
(
f
"setattr(
{
par
}
, 'offload', False)
\n
"
)
body
.
append
(
_gen_save_tensors_hooks_context
(
offload_input
=
False
))
# add for loop
body
.
append
(
_gen_loop_5
(
to_keep
[
0
]))
# change first node's input to new chunked var
node_args
=
list
(
node
.
args
)
node_args
[
0
]
=
'chunk_tensor'
if
within_chunk_region
:
emit_node_func
(
node
,
body
)
body
[
-
1
]
=
' '
+
body
[
-
1
]
delete_unused_value_func
(
node
,
body
)
delete_unused_value_func
(
node
,
body
,
to_keep
)
else
:
emit_node_func
(
node
,
body
)
delete_unused_value_func
(
node
,
body
)
if
node_idx
not
in
chunk_inputs
:
delete_unused_value_func
(
node
,
body
,
to_keep
)
if
node_idx
in
chunk_ends
:
body
.
append
(
_gen_loop_5_final
(
node
.
name
,
to_keep
))
to_keep
=
[]
within_chunk_region
=
False
node_idx
+=
1
...
...
@@ -572,7 +568,7 @@ if CODEGEN_AVAILABLE:
map_arg
(
node
.
kwargs
,
lambda
n
:
register_last_uses
(
n
,
node
))
# NOTE: we add a variable to distinguish body and ckpt_func
def
delete_unused_values
(
user
:
Node
,
body
):
def
delete_unused_values
(
user
:
Node
,
body
,
to_keep
=
[]
):
"""
Delete values after their last use. This ensures that values that are
not used in the remainder of the code are freed and the memory usage
...
...
@@ -584,6 +580,9 @@ if CODEGEN_AVAILABLE:
body
.
append
(
'
\n
'
)
return
nodes_to_delete
=
user_to_last_uses
.
get
(
user
,
[])
for
n
in
nodes_to_delete
:
if
n
.
name
in
to_keep
:
nodes_to_delete
.
remove
(
n
)
if
len
(
nodes_to_delete
):
to_delete_str
=
' = '
.
join
([
repr
(
n
)
for
n
in
nodes_to_delete
]
+
[
'None'
])
body
.
append
(
f
';
{
to_delete_str
}
\n
'
)
...
...
@@ -693,5 +692,6 @@ if CODEGEN_AVAILABLE:
{
wrap_stmts
}
{
prologue
}
{
code
}
"""
{
code
}
"""
print
(
fn_code
)
return
PythonCode
(
fn_code
,
globals_
)
chunk_codegen_run.py
View file @
78cfe436
...
...
@@ -54,6 +54,7 @@ def _test_fwd_and_bwd(model: torch.nn.Module, gm: ColoGraphModule, data: torch.T
# test forward
non_fx_out
=
model
(
data
)
fx_out
=
gm
(
data
)
print
(
non_fx_out
.
shape
,
fx_out
.
shape
)
assert
torch
.
equal
(
non_fx_out
,
fx_out
),
"fx_out doesn't comply with original output"
# test backward
...
...
@@ -86,13 +87,13 @@ def _run_offload_codegen(rank):
setattr
(
node
,
"activation_offload"
,
[
0
,
True
,
False
])
if
node
.
name
==
"linear1"
:
setattr
(
node
,
"activation_offload"
,
[
0
,
True
,
False
])
if
node
.
name
==
"linear2"
:
setattr
(
node
,
"activation_offload"
,
[
1
,
True
,
True
])
if
node
.
name
==
"linear4"
:
setattr
(
node
,
"activation_offload"
,
[
2
,
False
,
True
])
if
node
.
name
==
"linear5"
:
setattr
(
node
,
"activation_checkpoint"
,
[
0
])
setattr
(
node
,
"activation_offload"
,
True
)
#
if node.name == "linear2":
#
setattr(node, "activation_offload", [1, True, True])
#
if node.name == "linear4":
#
setattr(node, "activation_offload", [2, False, True])
#
if node.name == "linear5":
#
setattr(node, "activation_checkpoint", [0])
#
setattr(node, "activation_offload", True)
gm
=
ColoGraphModule
(
copy
.
deepcopy
(
model
),
graph
)
gm
.
recompile
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment