Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
1a6d2a74
Commit
1a6d2a74
authored
Jan 06, 2023
by
oahzxl
Browse files
take apart chunk code gen
parent
d1f07731
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
2408 additions
and
6 deletions
+2408
-6
colossalai/autochunk/autochunk_codegen.py
colossalai/autochunk/autochunk_codegen.py
+497
-0
colossalai/autochunk/chunk_region_search.py
colossalai/autochunk/chunk_region_search.py
+211
-0
colossalai/autochunk/chunk_selector.py
colossalai/autochunk/chunk_selector.py
+221
-0
colossalai/autochunk/index_tracer.py
colossalai/autochunk/index_tracer.py
+1056
-0
colossalai/autochunk/memory_estiamtor.py
colossalai/autochunk/memory_estiamtor.py
+318
-0
colossalai/autochunk/utils.py
colossalai/autochunk/utils.py
+95
-0
tests/test_autochunk/benchmark_autochunk.py
tests/test_autochunk/benchmark_autochunk.py
+8
-4
tests/test_autochunk/test_autochunk.py
tests/test_autochunk/test_autochunk.py
+2
-2
No files found.
colossalai/autochunk/autochunk_codegen.py
0 → 100644
View file @
1a6d2a74
from
typing
import
Any
,
Callable
,
Dict
,
Iterable
,
List
,
Tuple
import
torch
from
torch.fx.graph
import
(
CodeGen
,
PythonCode
,
_custom_builtins
,
_CustomBuiltin
,
_format_target
,
_is_from_torch
,
_Namespace
,
_origin_type_map
,
inplace_methods
,
magic_methods
,
)
from
torch.fx.node
import
Argument
,
Node
,
_get_qualified_name
,
_type_repr
,
map_arg
import
colossalai
from
.chunk_region_search
import
ChunkRegionSearch
from
.utils
import
delete_free_var_from_last_use
,
find_idx_by_name
,
get_node_shape
CODEGEN_AVAILABLE
=
True
__all__
=
[
"AutoChunkCodeGen"
]
def
_gen_chunk_slice_dim
(
chunk_dim
,
chunk_idx_name
,
shape
):
new_shape
=
"["
for
idx
,
i
in
enumerate
(
shape
):
if
idx
==
chunk_dim
:
new_shape
+=
"%s:%s + chunk_size"
%
(
chunk_idx_name
,
chunk_idx_name
)
else
:
new_shape
+=
":"
new_shape
+=
", "
new_shape
=
new_shape
[:
-
2
]
+
"]"
return
new_shape
def _gen_loop_start(chunk_input, chunk_output, chunk_ouput_dim, chunk_size=2):
    """Generate the source lines that open a chunked loop.

    Emits (as a string) the allocation of the full-size result buffer
    ``chunk_result`` plus the ``for chunk_idx in range(...)`` header that
    iterates over the chunked output dimension in steps of ``chunk_size``.

    Args:
        chunk_input: list of input Nodes of the chunk region; only the first
            one is used, to borrow its dtype/device for the result buffer.
        chunk_output: output Node of the chunk region; its shape sizes the
            result buffer and bounds the loop.
        chunk_ouput_dim: index of the output dimension being chunked.
            # NOTE(review): "ouput" is a typo for "output", but renaming the
            # parameter would break keyword callers; kept as-is.
        chunk_size: step of the generated loop (default 2).

    Returns:
        The generated source as a single string ending after the loop header.
    """
    input_node = chunk_input[0]
    # assumes get_node_shape(chunk_output) is a concrete, iterable shape —
    # the buffer is sized with its literal values
    out_shape = get_node_shape(chunk_output)
    out_str = str(list(out_shape))
    context = (
        "chunk_result = torch.empty(%s, dtype=%s.dtype, device=%s.device); chunk_size = %d\nfor chunk_idx in range"
        % (out_str, input_node.name, input_node.name, chunk_size)
    )
    # loop runs over the full extent of the chunked output dimension
    context += "(0, %d, chunk_size):\n" % (out_shape[chunk_ouput_dim])
    return context
def _gen_loop_end(
    chunk_inputs, chunk_non_compute_inputs, chunk_outputs, chunk_outputs_dim, node_list
):
    """Generate the source lines that close a chunked loop.

    Writes the per-iteration slice of the loop's output into ``chunk_result``,
    frees the per-iteration temporary, then (outside the loop) renames
    ``chunk_result`` back to the output's name and frees any chunk input whose
    last use lies inside the region.

    Args:
        chunk_inputs: chunked input Nodes of the region.
        chunk_non_compute_inputs: non-chunked (pass-through) input Nodes.
        chunk_outputs: the single output Node of the region.
        chunk_outputs_dim: index of the chunked output dimension.
        node_list: full ordered node list, used to resolve node positions.

    Returns:
        The generated source as a single string.
    """
    chunk_outputs_name = chunk_outputs.name
    chunk_outputs_idx = find_idx_by_name(chunk_outputs_name, node_list)
    chunk_output_shape = chunk_outputs.meta["tensor_meta"].shape
    chunk_slice = _gen_chunk_slice_dim(chunk_outputs_dim, "chunk_idx", chunk_output_shape)
    # the write-back line is indented to sit inside the generated for-loop
    # NOTE(review): indent reconstructed as 4 spaces — confirm against the
    # indent prepended to loop-body lines in emit_code_with_chunk
    context = "    chunk_result%s = %s; %s = None\n" % (
        chunk_slice,
        chunk_outputs_name,
        chunk_outputs_name,
    )
    # after the loop: expose the assembled buffer under the output's name
    context += (
        chunk_outputs_name + " = chunk_result; chunk_result = None; chunk_size = None"
    )

    # determine if its the last use for chunk input
    for chunk_input in chunk_inputs + chunk_non_compute_inputs:
        if all(
            [
                find_idx_by_name(user.name, node_list) <= chunk_outputs_idx
                for user in chunk_input.users.keys()
            ]
        ):
            # every user of this input sits at or before the region's output,
            # so it can be freed here
            context += "; %s = None" % chunk_input.name
    context += "\n"
    return context
def
_replace_name
(
context
,
name_from
,
name_to
):
patterns
=
[(
" "
,
" "
),
(
" "
,
"."
),
(
" "
,
","
),
(
"("
,
")"
),
(
"("
,
","
),
(
" "
,
")"
)]
for
p
in
patterns
:
source
=
p
[
0
]
+
name_from
+
p
[
1
]
target
=
p
[
0
]
+
name_to
+
p
[
1
]
if
source
in
context
:
context
=
context
.
replace
(
source
,
target
)
return
context
def
_replace_reshape_size
(
context
,
node_name
,
reshape_size_dict
):
if
node_name
not
in
reshape_size_dict
:
return
context
for
size_name
,
size_value
in
reshape_size_dict
[
node_name
].
items
():
context
=
context
.
replace
(
size_name
,
size_value
)
return
context
def emit_code_with_chunk(
    body,
    nodes,
    emit_node_func,
    delete_unused_value_func,
    chunk_region_search,
    chunk_infos,
):
    """Emit forward code, wrapping selected node ranges in chunked loops.

    Walks the (reordered) node list; nodes inside a chunk region are emitted
    inside a generated ``for chunk_idx`` loop with their chunked inputs
    sliced per iteration, all other nodes are emitted as-is.

    Args:
        body: list of source lines to append to (forward code).
        nodes: graph.nodes.
        emit_node_func: function that emits one node into ``body``.
        delete_unused_value_func: function that frees values after last use.
        chunk_region_search: ChunkRegionSearch owning the index tracer.
        chunk_infos: list of dicts describing each chunk region
            (keys used here: "region", "inputs", "inputs_non_chunk",
            "inputs_dim", "outputs", "outputs_dim", "chunk_size",
            "node_chunk_dim", "reshape_size").
    """
    node_list = list(nodes)

    # per-region metadata, unpacked into parallel lists indexed by region
    chunk_regions = [i["region"] for i in chunk_infos]
    chunk_starts = [i[0] for i in chunk_regions]
    chunk_ends = [i[1] for i in chunk_regions]

    chunk_inputs = [i["inputs"] for i in chunk_infos]
    chunk_inputs_non_chunk = [i["inputs_non_chunk"] for i in chunk_infos]
    chunk_inputs_dim = [i["inputs_dim"] for i in chunk_infos]
    chunk_inputs_names = [j.name for i in chunk_inputs for j in i] + [
        j.name for i in chunk_inputs_non_chunk for j in i
    ]
    chunk_outputs = [i["outputs"][0] for i in chunk_infos]
    chunk_outputs_dim = [i["outputs_dim"] for i in chunk_infos]

    # reorder nodes so each chunk region is contiguous
    node_list = chunk_region_search.index_tracer.reorder_node_list(node_list)
    node_idx = 0
    region_idx = 0
    within_chunk_region = False

    while node_idx < len(node_list):
        node = node_list[node_idx]

        if node_idx in chunk_starts:
            # open a chunk loop: allocate the result buffer + loop header
            within_chunk_region = True
            region_idx = chunk_starts.index(node_idx)
            body.append(
                _gen_loop_start(
                    chunk_inputs[region_idx],
                    chunk_outputs[region_idx],
                    chunk_outputs_dim[region_idx],
                    chunk_infos[region_idx]["chunk_size"],
                )
            )

        if within_chunk_region:
            emit_node_func(node, body)
            # replace input var with chunk var
            for input_node_idx, input_node in enumerate(chunk_inputs[region_idx]):
                for idx, dim in chunk_inputs_dim[region_idx][input_node_idx].items():
                    if idx == node_idx:
                        chunk_slice = _gen_chunk_slice_dim(
                            dim[0], "chunk_idx", get_node_shape(input_node)
                        )
                        body[-1] = _replace_name(
                            body[-1], input_node.name, input_node.name + chunk_slice
                        )
            # ones like
            # special-case ones_like-style nodes: if their chunk dim is not
            # broadcastable (size != 1) and their source arg carries no chunk
            # dim of its own, slice the argument as well
            if "ones_like" in node.name:
                meta_node = chunk_region_search.index_tracer.node_list[node_idx]
                chunk_dim = chunk_infos[region_idx]["node_chunk_dim"][meta_node][
                    "chunk_dim"
                ]
                if get_node_shape(meta_node)[chunk_dim] != 1:
                    source_node = meta_node.args[0].args[0]
                    if (
                        source_node not in chunk_infos[region_idx]["node_chunk_dim"]
                        or chunk_infos[region_idx]["node_chunk_dim"][source_node][
                            "chunk_dim"
                        ]
                        is None
                    ):
                        chunk_slice = _gen_chunk_slice_dim(
                            chunk_dim, "chunk_idx", get_node_shape(node)
                        )
                        body[-1] = _replace_name(
                            body[-1],
                            node.args[0].name,
                            node.args[0].name + chunk_slice,
                        )
            # rewrite recorded reshape sizes to their chunk-aware form
            body[-1] = _replace_reshape_size(
                body[-1], node.name, chunk_infos[region_idx]["reshape_size"]
            )
            # indent the emitted line so it sits inside the generated loop
            body[-1] = "    " + body[-1]
            delete_unused_value_func(node, body, chunk_inputs_names)
        else:
            emit_node_func(node, body)
            # NOTE(review): node_idx is an int but chunk_inputs is a list of
            # lists of Nodes, so this membership test is always True — looks
            # like it was meant to test against input indices; confirm intent
            if node_idx not in chunk_inputs:
                delete_unused_value_func(node, body, chunk_inputs_names)

        if node_idx in chunk_ends:
            # close the chunk loop: write back the slice, restore the output
            # name, free inputs whose last use is inside the region
            body.append(
                _gen_loop_end(
                    chunk_inputs[region_idx],
                    chunk_inputs_non_chunk[region_idx],
                    chunk_outputs[region_idx],
                    chunk_outputs_dim[region_idx],
                    node_list,
                )
            )
            within_chunk_region = False

        node_idx += 1
if
CODEGEN_AVAILABLE
:
class
AutoChunkCodeGen
(
CodeGen
):
def
__init__
(
self
,
meta_graph
,
max_memory
=
None
):
super
().
__init__
()
self
.
meta_graph
=
meta_graph
self
.
max_memory
=
max_memory
self
.
meta_node
=
list
(
meta_graph
.
graph
.
nodes
)
# find the chunk regions
self
.
chunk_region_search
=
ChunkRegionSearch
(
meta_graph
,
max_memory
)
self
.
chunk_infos
=
self
.
chunk_region_search
.
search_region
()
def
_gen_python_code
(
self
,
nodes
,
root_module
:
str
,
namespace
:
_Namespace
)
->
PythonCode
:
free_vars
:
List
[
str
]
=
[]
body
:
List
[
str
]
=
[]
globals_
:
Dict
[
str
,
Any
]
=
{}
wrapped_fns
:
Dict
[
str
,
None
]
=
{}
# Wrap string in list to pass by reference
maybe_return_annotation
:
List
[
str
]
=
[
""
]
def
add_global
(
name_hint
:
str
,
obj
:
Any
):
"""Add an obj to be tracked as a global.
We call this for names that reference objects external to the
Graph, like functions or types.
Returns: the global name that should be used to reference 'obj' in generated source.
"""
if
(
_is_from_torch
(
obj
)
and
obj
!=
torch
.
device
):
# to support registering torch.device
# HACK: workaround for how torch custom ops are registered. We
# can't import them like normal modules so they must retain their
# fully qualified name.
return
_get_qualified_name
(
obj
)
# normalize the name hint to get a proper identifier
global_name
=
namespace
.
create_name
(
name_hint
,
obj
)
if
global_name
in
globals_
:
assert
globals_
[
global_name
]
is
obj
return
global_name
globals_
[
global_name
]
=
obj
return
global_name
# set _custom_builtins here so that we needn't import colossalai in forward
_custom_builtins
[
"colossalai"
]
=
_CustomBuiltin
(
"import colossalai"
,
colossalai
)
# Pre-fill the globals table with registered builtins.
for
name
,
(
_
,
obj
)
in
_custom_builtins
.
items
():
add_global
(
name
,
obj
)
def
type_repr
(
o
:
Any
):
if
o
==
():
# Empty tuple is used for empty tuple type annotation Tuple[()]
return
"()"
typename
=
_type_repr
(
o
)
if
hasattr
(
o
,
"__origin__"
):
# This is a generic type, e.g. typing.List[torch.Tensor]
origin_type
=
_origin_type_map
.
get
(
o
.
__origin__
,
o
.
__origin__
)
origin_typename
=
add_global
(
_type_repr
(
origin_type
),
origin_type
)
if
hasattr
(
o
,
"__args__"
):
# Assign global names for each of the inner type variables.
args
=
[
type_repr
(
arg
)
for
arg
in
o
.
__args__
]
if
len
(
args
)
==
0
:
# Bare type, such as `typing.Tuple` with no subscript
# This code-path used in Python < 3.9
return
origin_typename
return
f
'
{
origin_typename
}
[
{
","
.
join
(
args
)
}
]'
else
:
# Bare type, such as `typing.Tuple` with no subscript
# This code-path used in Python 3.9+
return
origin_typename
# Common case: this is a regular module name like 'foo.bar.baz'
return
add_global
(
typename
,
o
)
def
_format_args
(
args
:
Tuple
[
Argument
,
...],
kwargs
:
Dict
[
str
,
Argument
]
)
->
str
:
def
_get_repr
(
arg
):
# Handle NamedTuples (if it has `_fields`) via add_global.
if
isinstance
(
arg
,
tuple
)
and
hasattr
(
arg
,
"_fields"
):
qualified_name
=
_get_qualified_name
(
type
(
arg
))
global_name
=
add_global
(
qualified_name
,
type
(
arg
))
return
f
"
{
global_name
}{
repr
(
tuple
(
arg
))
}
"
return
repr
(
arg
)
args_s
=
", "
.
join
(
_get_repr
(
a
)
for
a
in
args
)
kwargs_s
=
", "
.
join
(
f
"
{
k
}
=
{
_get_repr
(
v
)
}
"
for
k
,
v
in
kwargs
.
items
())
if
args_s
and
kwargs_s
:
return
f
"
{
args_s
}
,
{
kwargs_s
}
"
return
args_s
or
kwargs_s
# Run through reverse nodes and record the first instance of a use
# of a given node. This represents the *last* use of the node in the
# execution order of the program, which we will use to free unused
# values
node_to_last_use
:
Dict
[
Node
,
Node
]
=
{}
user_to_last_uses
:
Dict
[
Node
,
List
[
Node
]]
=
{}
def
register_last_uses
(
n
:
Node
,
user
:
Node
):
if
n
not
in
node_to_last_use
:
node_to_last_use
[
n
]
=
user
user_to_last_uses
.
setdefault
(
user
,
[]).
append
(
n
)
for
node
in
reversed
(
nodes
):
map_arg
(
node
.
args
,
lambda
n
:
register_last_uses
(
n
,
node
))
map_arg
(
node
.
kwargs
,
lambda
n
:
register_last_uses
(
n
,
node
))
delete_free_var_from_last_use
(
user_to_last_uses
)
# NOTE: we add a variable to distinguish body and ckpt_func
def
delete_unused_values
(
user
:
Node
,
body
,
to_keep
=
[]):
"""
Delete values after their last use. This ensures that values that are
not used in the remainder of the code are freed and the memory usage
of the code is optimal.
"""
if
user
.
op
==
"placeholder"
:
return
if
user
.
op
==
"output"
:
body
.
append
(
"
\n
"
)
return
nodes_to_delete
=
user_to_last_uses
.
get
(
user
,
[])
nodes_to_delete
=
[
i
for
i
in
nodes_to_delete
if
i
.
name
not
in
to_keep
]
if
len
(
nodes_to_delete
):
to_delete_str
=
" = "
.
join
(
[
repr
(
n
)
for
n
in
nodes_to_delete
]
+
[
"None"
]
)
body
.
append
(
f
";
{
to_delete_str
}
\n
"
)
else
:
body
.
append
(
"
\n
"
)
# NOTE: we add a variable to distinguish body and ckpt_func
def
emit_node
(
node
:
Node
,
body
):
maybe_type_annotation
=
(
""
if
node
.
type
is
None
else
f
" :
{
type_repr
(
node
.
type
)
}
"
)
if
node
.
op
==
"placeholder"
:
assert
isinstance
(
node
.
target
,
str
)
maybe_default_arg
=
(
""
if
not
node
.
args
else
f
" =
{
repr
(
node
.
args
[
0
])
}
"
)
free_vars
.
append
(
f
"
{
node
.
target
}{
maybe_type_annotation
}{
maybe_default_arg
}
"
)
raw_name
=
node
.
target
.
replace
(
"*"
,
""
)
if
raw_name
!=
repr
(
node
):
body
.
append
(
f
"
{
repr
(
node
)
}
=
{
raw_name
}
\n
"
)
return
elif
node
.
op
==
"call_method"
:
assert
isinstance
(
node
.
target
,
str
)
body
.
append
(
f
"
{
repr
(
node
)
}{
maybe_type_annotation
}
=
{
_format_target
(
repr
(
node
.
args
[
0
]),
node
.
target
)
}
"
f
"(
{
_format_args
(
node
.
args
[
1
:],
node
.
kwargs
)
}
)"
)
return
elif
node
.
op
==
"call_function"
:
assert
callable
(
node
.
target
)
# pretty print operators
if
(
node
.
target
.
__module__
==
"_operator"
and
node
.
target
.
__name__
in
magic_methods
):
assert
isinstance
(
node
.
args
,
tuple
)
body
.
append
(
f
"
{
repr
(
node
)
}{
maybe_type_annotation
}
= "
f
"
{
magic_methods
[
node
.
target
.
__name__
].
format
(
*
(
repr
(
a
)
for
a
in
node
.
args
))
}
"
)
return
# pretty print inplace operators; required for jit.script to work properly
# not currently supported in normal FX graphs, but generated by torchdynamo
if
(
node
.
target
.
__module__
==
"_operator"
and
node
.
target
.
__name__
in
inplace_methods
):
body
.
append
(
f
"
{
inplace_methods
[
node
.
target
.
__name__
].
format
(
*
(
repr
(
a
)
for
a
in
node
.
args
))
}
; "
f
"
{
repr
(
node
)
}{
maybe_type_annotation
}
=
{
repr
(
node
.
args
[
0
])
}
"
)
return
qualified_name
=
_get_qualified_name
(
node
.
target
)
global_name
=
add_global
(
qualified_name
,
node
.
target
)
# special case for getattr: node.args could be 2-argument or 3-argument
# 2-argument: attribute access; 3-argument: fall through to attrib function call with default value
if
(
global_name
==
"getattr"
and
isinstance
(
node
.
args
,
tuple
)
and
isinstance
(
node
.
args
[
1
],
str
)
and
node
.
args
[
1
].
isidentifier
()
and
len
(
node
.
args
)
==
2
):
body
.
append
(
f
"
{
repr
(
node
)
}{
maybe_type_annotation
}
=
{
_format_target
(
repr
(
node
.
args
[
0
]),
node
.
args
[
1
])
}
"
)
return
body
.
append
(
f
"
{
repr
(
node
)
}{
maybe_type_annotation
}
=
{
global_name
}
(
{
_format_args
(
node
.
args
,
node
.
kwargs
)
}
)"
)
if
node
.
meta
.
get
(
"is_wrapped"
,
False
):
wrapped_fns
.
setdefault
(
global_name
)
return
elif
node
.
op
==
"call_module"
:
assert
isinstance
(
node
.
target
,
str
)
body
.
append
(
f
"
{
repr
(
node
)
}{
maybe_type_annotation
}
= "
f
"
{
_format_target
(
root_module
,
node
.
target
)
}
(
{
_format_args
(
node
.
args
,
node
.
kwargs
)
}
)"
)
return
elif
node
.
op
==
"get_attr"
:
assert
isinstance
(
node
.
target
,
str
)
body
.
append
(
f
"
{
repr
(
node
)
}{
maybe_type_annotation
}
=
{
_format_target
(
root_module
,
node
.
target
)
}
"
)
return
elif
node
.
op
==
"output"
:
if
node
.
type
is
not
None
:
maybe_return_annotation
[
0
]
=
f
" ->
{
type_repr
(
node
.
type
)
}
"
body
.
append
(
self
.
generate_output
(
node
.
args
[
0
]))
return
raise
NotImplementedError
(
f
"node:
{
node
.
op
}
{
node
.
target
}
"
)
# Modified for activation checkpointing
ckpt_func
=
[]
# if any node has a list of labels for activation_checkpoint, we
# will use nested type of activation checkpoint codegen
emit_code_with_chunk
(
body
,
nodes
,
emit_node
,
delete_unused_values
,
self
.
chunk_region_search
,
self
.
chunk_infos
,
)
if
len
(
body
)
==
0
:
# If the Graph has no non-placeholder nodes, no lines for the body
# have been emitted. To continue to have valid Python code, emit a
# single pass statement
body
.
append
(
"pass
\n
"
)
if
len
(
wrapped_fns
)
>
0
:
wrap_name
=
add_global
(
"wrap"
,
torch
.
fx
.
wrap
)
wrap_stmts
=
"
\n
"
.
join
(
[
f
'
{
wrap_name
}
("
{
name
}
")'
for
name
in
wrapped_fns
]
)
else
:
wrap_stmts
=
""
if
self
.
_body_transformer
:
body
=
self
.
_body_transformer
(
body
)
for
name
,
value
in
self
.
additional_globals
():
add_global
(
name
,
value
)
# as we need colossalai.utils.checkpoint, we need to import colossalai
# in forward function
prologue
=
self
.
gen_fn_def
(
free_vars
,
maybe_return_annotation
[
0
])
prologue
=
""
.
join
(
ckpt_func
)
+
prologue
prologue
=
prologue
code
=
""
.
join
(
body
)
code
=
"
\n
"
.
join
(
" "
+
line
for
line
in
code
.
split
(
"
\n
"
))
fn_code
=
f
"""
{
wrap_stmts
}
{
prologue
}
{
code
}
"""
# print(fn_code)
return
PythonCode
(
fn_code
,
globals_
)
colossalai/autochunk/chunk_region_search.py
0 → 100644
View file @
1a6d2a74
from
.index_tracer
import
IndexTracer
from
.memory_estiamtor
import
MemoryEstimator
from
.chunk_selector
import
ChunkSelector
import
copy
from
.utils
import
is_non_compute_node
,
is_non_compute_node_except_placeholder
,
get_node_shape
class ChunkRegionSearch(object):
    """Iteratively find graph regions whose recomputation-by-chunks lowers peak memory.

    Repeatedly locates the memory-peak node, searches the largest region around
    it that can be chunked, enumerates candidate (start, end, dim) chunkings,
    and lets a ChunkSelector pick the best one, until no further improvement
    is found.
    """

    def __init__(self, gm, max_memory=None) -> None:
        """
        Args:
            gm: an fx GraphModule to search.
            max_memory: optional memory budget forwarded to ChunkSelector
                (presumably MB — see ChunkSelector; confirm).
        """
        self.gm = gm
        self.index_tracer = IndexTracer(list(gm.graph.nodes))
        self.index_tracer.trace_index()
        self.memory_estimator = MemoryEstimator(self.index_tracer)
        self.chunk_selector = ChunkSelector(
            self.index_tracer, self.memory_estimator, max_memory=max_memory
        )

    def _find_peak_node(self, mem_peak):
        # index of the first occurrence of the maximum memory value
        max_value = max(mem_peak)
        max_idx = mem_peak.index(max_value)
        return max_idx

    def _get_free_var(self):
        # indices of all placeholder (free-variable) nodes
        free_var_idx = []
        for idx, n in enumerate(self.index_tracer.node_list):
            if n.op == "placeholder":
                free_var_idx.append(idx)
        return free_var_idx

    def _get_min_free_var(self, active_node_list, free_vars):
        # smallest active-node count over non-free-var positions
        # NOTE(review): 999 acts as a sentinel upper bound — assumes active
        # counts never reach it; unused by the rest of this class
        min_len = 999
        for idx, n in enumerate(active_node_list):
            if idx in free_vars:
                continue
            if len(n) < min_len:
                min_len = len(n)
        return min_len

    def _search_max_chunk_region(self, active_node, peak_node, chunk_regions):
        """Expand outward from the peak node to the widest chunkable region.

        The region is bounded where the active-node count rises back above a
        threshold, then clipped against already-chosen regions.  Returns a
        (start, end) index pair, or None if the region is fully contained in
        an existing one.
        """
        free_vars = self._get_free_var()
        free_var_num = len(free_vars)
        active_node_num = [len(i) for i in active_node]
        min_active_node_num = min(active_node_num[free_var_num:])
        threshold = max(free_var_num, min_active_node_num)

        # from peak_node to free_var
        inside_flag = False
        chunk_region_start = free_var_num
        for i in range(peak_node, -1, -1):
            if active_node_num[i] <= threshold:
                inside_flag = True
            if inside_flag and active_node_num[i] > threshold:
                chunk_region_start = i + 1
                break

        # from peak_node to len-2
        inside_flag = False
        chunk_region_end = len(active_node) - 1
        for i in range(peak_node, len(active_node)):
            if active_node_num[i] <= threshold:
                inside_flag = True
            if inside_flag and active_node_num[i] > threshold:
                chunk_region_end = i
                break

        # clip against regions already chosen in earlier steps
        for i in chunk_regions:
            region = i["region"]
            if chunk_region_start >= region[0] and chunk_region_end <= region[1]:
                # fully inside an existing region: nothing new to chunk
                return None
            elif (
                region[0] <= chunk_region_start <= region[1]
                and chunk_region_end > region[1]
            ):
                chunk_region_start = region[1] + 1
            elif (
                region[0] <= chunk_region_end <= region[1]
                and chunk_region_start < region[0]
            ):
                chunk_region_end = region[0] - 1
        return chunk_region_start, chunk_region_end

    def _is_not_compute(self, trace, chunk_range, dim_idx):
        # a dim is "not compute" if it has no compute record at all, or all
        # of its compute sites fall outside the candidate chunk range
        if trace["idx"][dim_idx] not in trace["compute"]:
            return True
        if trace["idx"][dim_idx] in trace["compute"] and all(
            i < chunk_range[0] or i > chunk_range[1]
            for i in trace["compute"][trace["idx"][dim_idx]]
        ):
            return True
        return False

    def _find_free_dim(self, input_trace, output_trace, start_idx, end_idx):
        """Enumerate legal (start_dim, end_dim) chunkings for one (start, end) pair.

        A candidate must have non-unit sizes on both dims, aligned index
        sources, chunk-independent computation, a successful flow search, and
        no duplicated indices.  Returns a list of chunk_info dicts.
        """
        start_traces = input_trace[start_idx]
        end_trace = output_trace[end_idx]
        end_node = self.index_tracer.node_list[end_idx]
        chunk_infos = []
        for end_dim, _ in enumerate(end_trace["idx"]):
            # only single-input starts are considered
            if len(start_traces) > 1:
                continue
            for start_node, start_trace in start_traces.items():
                for start_dim, _ in enumerate(start_trace["idx"]):
                    # dim size cannot be 1
                    if (
                        get_node_shape(end_node)[end_dim] == 1
                        or get_node_shape(start_node)[start_dim] == 1
                    ):
                        continue
                    # check index source align
                    if not self.index_tracer.check_index_source(
                        start_dim, start_node, start_idx, end_dim, end_node
                    ):
                        continue
                    # check index compute
                    if not self.index_tracer.check_index_compute(
                        start_idx, end_dim, end_node, end_idx
                    ):
                        continue
                    # flow search
                    chunk_info = self.index_tracer.flow_search(
                        start_idx, start_dim, end_idx, end_dim
                    )
                    if chunk_info is None:
                        continue
                    # check index duplication
                    if not self.index_tracer.check_index_duplicate(chunk_info):
                        continue
                    chunk_infos.append(chunk_info)
        return chunk_infos

    def _search_possible_chunk_regions(self, max_chunk_region, peak_node):
        """Collect every legal chunking whose region spans the peak node."""
        possible_chunk_region = []
        output_trace = copy.deepcopy(self.index_tracer.idx_trace_list)
        input_trace = []  # trace of a node's input nodes
        for _, n in enumerate(self.index_tracer.node_list):
            cur_trace = {}
            for arg in n.args:
                # only Node-typed, compute-relevant args carry a trace
                if type(arg) == type(n) and not is_non_compute_node_except_placeholder(
                    arg
                ):
                    cur_trace[arg] = self.index_tracer._find_trace_from_node(arg)
            input_trace.append(cur_trace)

        # every (start <= peak <= end) pair inside the max region
        for start_idx in range(max_chunk_region[0], peak_node + 1):
            for end_idx in range(peak_node, max_chunk_region[1] + 1):
                # skip non compute nodes
                if is_non_compute_node(
                    self.index_tracer.node_list[start_idx]
                ) or is_non_compute_node(self.index_tracer.node_list[end_idx]):
                    continue
                # select free dim
                chunk_info = self._find_free_dim(
                    input_trace, output_trace, start_idx, end_idx
                )
                if len(chunk_info) > 0:
                    possible_chunk_region.extend(chunk_info)
        return possible_chunk_region

    def _step_search(self, mem_peak, active_node, chunk_regions):
        """One search iteration: pick the best new chunk region, or None."""
        peak_node = self._find_peak_node(mem_peak)
        max_chunk_region = self._search_max_chunk_region(
            active_node, peak_node, chunk_regions
        )
        if max_chunk_region == None:
            return None
        possible_chunk_regions = self._search_possible_chunk_regions(
            max_chunk_region, peak_node
        )
        best_chunk_region = self.chunk_selector._select_best_chunk_region(
            possible_chunk_regions, chunk_regions, peak_node, max_chunk_region, mem_peak
        )
        best_chunk_region = self.index_tracer.reorder_all(best_chunk_region)
        return best_chunk_region

    def _stop_search(self, init_mem_peak, mem_peak):
        # stop once the current peak drops below the median of the initial
        # per-node peaks
        sorted_init_mem_peak = sorted(init_mem_peak)
        if max(mem_peak) < sorted_init_mem_peak[int(len(sorted_init_mem_peak) * 0.5)]:
            return True
        return False

    def search_region(self):
        """Run the full search loop and return the list of chunk_info dicts."""
        chunk_infos = []
        (
            init_mem_peak,
            _,
            active_node,
        ) = self.memory_estimator.estimate_chunk_inference_mem(
            self.index_tracer.node_list
        )
        mem_peak = init_mem_peak

        while True:
            chunk_info = self._step_search(mem_peak, active_node, chunk_infos)
            if chunk_info is None:
                break

            chunk_infos.append(chunk_info)
            # re-estimate memory with the new region applied
            (
                mem_peak,
                _,
                active_node,
            ) = self.memory_estimator.estimate_chunk_inference_mem(
                self.index_tracer.node_list, chunk_infos
            )
            if self._stop_search(init_mem_peak, mem_peak):
                break

        # final estimate with printing enabled, for inspection
        self.memory_estimator.estimate_chunk_inference_mem(
            self.index_tracer.node_list, chunk_infos, print_mem=True
        )
        return chunk_infos
colossalai/autochunk/chunk_selector.py
0 → 100644
View file @
1a6d2a74
from
.index_tracer
import
IndexTracer
from
.memory_estiamtor
import
MemoryEstimator
from
.utils
import
is_non_compute_node
class ChunkSelector(object):
    """Pick the best chunk region (and chunk size) from a set of candidates.

    Two strategies: "fit_memory" (a max-memory budget was given — choose the
    shortest region that fits the budget, then maximize chunk size within it)
    and "min_memory" (no budget — choose the region with the lowest peak and
    chunk size 1).
    """

    def __init__(
        self,
        index_tracer: IndexTracer,
        memory_estimator: MemoryEstimator,
        max_memory=None,
    ):
        self.index_tracer = index_tracer
        self.memory_estimator = memory_estimator
        if max_memory is not None:
            # NOTE(review): "stratge" is a typo for "strategy", but it is read
            # by _select_best_chunk_region below; kept for compatibility
            self.stratge = "fit_memory"
            self.max_memory = max_memory  # MB
        else:
            self.stratge = "min_memory"

    def _select_best_chunk_region(
        self, possible_chunk_regions, chunk_infos, peak_node, max_chunk_region, mem_peak
    ):
        """Dispatch to the configured selection strategy."""
        if self.stratge == "min_memory":
            best_region = self._select_min_memory_chunk_region(
                possible_chunk_regions,
                chunk_infos,
                peak_node,
                max_chunk_region,
                mem_peak,
            )
        elif self.stratge == "fit_memory":
            best_region = self._select_fit_memory_chunk_region(
                possible_chunk_regions,
                chunk_infos,
                peak_node,
                max_chunk_region,
                mem_peak,
            )
        else:
            raise RuntimeError()
        return best_region

    def _select_fit_memory_chunk_region(
        self, possible_chunk_regions, chunk_infos, peak_node, max_chunk_region, mem_peak
    ):
        """Choose the shortest candidate that keeps peak memory under budget."""
        # stop chunk if max memory satisfy memory limit
        if max(mem_peak) < self.max_memory:
            return None
        # remove illegal regions
        illegal_regions = []
        for i in possible_chunk_regions:
            if not self._is_legal_region(i, chunk_infos):
                illegal_regions.append(i)
        for i in illegal_regions:
            if i in possible_chunk_regions:
                possible_chunk_regions.remove(i)
        if len(possible_chunk_regions) == 0:
            return None
        # get mem for chunk region
        regions_dict = []
        for region in possible_chunk_regions:
            cur_region = region.copy()
            cur_node_list, cur_region = self.index_tracer.tmp_reorder(
                self.index_tracer.node_list, cur_region
            )
            cur_chunk_infos = chunk_infos + [cur_region]
            cur_mem_peak = self.memory_estimator.estimate_chunk_inference_mem(
                cur_node_list, cur_chunk_infos
            )[0]
            cur_chunk_region_peak = cur_mem_peak[
                max_chunk_region[0] : max_chunk_region[1] + 1
            ]
            cur_chunk_region_max_peak = max(cur_chunk_region_peak)
            # keep only candidates that fit the budget
            if cur_chunk_region_max_peak < self.max_memory:
                regions_dict.append(
                    {
                        "chunk_info": region,
                        "chunk_max_mem": cur_chunk_region_max_peak,
                        "chunk_len": self._get_compute_node_num(
                            region["region"][0], region["region"][1]
                        ),
                        "reorder_chunk_info": cur_region,
                        "reorder_node_list": cur_node_list,
                    }
                )
        # no region found
        if len(regions_dict) == 0:
            raise RuntimeError("Search failed. Try a larger memory threshold.")
        # select the min chunk len
        chunk_len = [i["chunk_len"] for i in regions_dict]
        best_region_idx = chunk_len.index(min(chunk_len))
        best_region = regions_dict[best_region_idx]
        # get max chunk size
        best_region = self._get_fit_chunk_size(best_region, chunk_infos)
        return best_region

    def _get_fit_chunk_size(self, chunk_region_dict, chunk_infos):
        """Grow chunk size geometrically past the budget, then binary-search back."""
        chunk_size = 1
        reorder_chunk_info = chunk_region_dict["reorder_chunk_info"]
        reorder_chunk_info["chunk_size"] = chunk_size
        cur_chunk_max_mem = 0
        # search a region: double until the budget is exceeded
        while cur_chunk_max_mem < self.max_memory:
            chunk_size *= 2
            reorder_chunk_info["chunk_size"] = chunk_size
            cur_chunk_infos = chunk_infos + [reorder_chunk_info]
            cur_mem_peak = self.memory_estimator.estimate_chunk_inference_mem(
                chunk_region_dict["reorder_node_list"], cur_chunk_infos
            )[0]
            cur_chunk_max_mem = max(
                cur_mem_peak[
                    reorder_chunk_info["region"][0] : reorder_chunk_info["region"][1]
                    + 1
                ]
            )
        # search exact size between the last fitting and the first failing size
        chunk_info = chunk_region_dict["chunk_info"]
        chunk_info["chunk_size"] = self._chunk_size_binary_search(
            chunk_size // 2, chunk_size, chunk_region_dict, chunk_infos
        )
        return chunk_info

    def _chunk_size_binary_search(self, l, r, chunk_region_dict, chunk_infos):
        """Binary-search the largest chunk size within [l, r] that fits the budget."""
        # coarser step for larger ranges
        if l >= 16:
            gap = 4
        else:
            gap = 1
        chunk_info = chunk_region_dict["reorder_chunk_info"]
        while r >= l + gap:
            mid = int((l + r) / 2 + 0.5)
            chunk_info["chunk_size"] = mid
            cur_chunk_infos = chunk_infos + [chunk_info]
            cur_mem_peak = self.memory_estimator.estimate_chunk_inference_mem(
                chunk_region_dict["reorder_node_list"], cur_chunk_infos
            )[0]
            cur_chunk_max_mem = max(
                cur_mem_peak[chunk_info["region"][0] : chunk_info["region"][1] + 1]
            )
            if cur_chunk_max_mem >= self.max_memory:
                r = mid - gap
            else:
                l = mid + gap
        return l

    def _get_compute_node_num(self, start, end):
        # number of compute nodes in [start, end], inclusive
        count = 0
        for i in self.index_tracer.node_list[start : end + 1]:
            if not is_non_compute_node(i):
                count += 1
        return count

    def _select_min_memory_chunk_region(
        self, possible_chunk_regions, chunk_infos, peak_node, max_chunk_region, mem_peak
    ):
        """Choose the candidate with the lowest simulated peak memory."""
        # remove illegal regions
        illegal_regions = []
        for i in possible_chunk_regions:
            if not self._is_legal_region(i, chunk_infos):
                illegal_regions.append(i)
        for i in illegal_regions:
            if i in possible_chunk_regions:
                possible_chunk_regions.remove(i)
        if len(possible_chunk_regions) == 0:
            return None
        # get mem for chunk region
        regions_dict = []
        for region in possible_chunk_regions:
            cur_region = region.copy()
            cur_node_list, cur_region = self.index_tracer.tmp_reorder(
                self.index_tracer.node_list, cur_region
            )
            cur_chunk_infos = chunk_infos + [cur_region]
            cur_mem_peak = self.memory_estimator.estimate_chunk_inference_mem(
                cur_node_list, cur_chunk_infos
            )[0]
            cur_chunk_region_peak = cur_mem_peak[
                max_chunk_region[0] : max_chunk_region[1] + 1
            ]
            cur_chunk_region_max_peak = max(cur_chunk_region_peak)
            regions_dict.append(
                {
                    "chunk_info": region,
                    "chunk_max_mem": cur_chunk_region_max_peak,
                    "chunk_len": self._get_compute_node_num(
                        region["region"][0], region["region"][1]
                    ),
                    "reorder_chunk_info": cur_region,
                    "reorder_node_list": cur_node_list,
                }
            )
        # select the min mem
        chunk_max_mem = [i["chunk_max_mem"] for i in regions_dict]
        best_region_idx = chunk_max_mem.index(min(chunk_max_mem))
        best_region = regions_dict[best_region_idx]["chunk_info"]
        if best_region is not None:
            # min-memory strategy always chunks one step at a time
            best_region["chunk_size"] = 1
        return best_region

    def _is_legal_region(self, cur_chunk_info, chunk_infos):
        """A candidate is legal if new, non-empty, and disjoint from chosen regions."""
        (chunk_region_start, chunk_region_end) = cur_chunk_info["region"]
        if cur_chunk_info in chunk_infos:
            return False
        if chunk_region_end < chunk_region_start:
            return False
        for i in chunk_infos:
            region = i["region"]
            # must lie entirely before or entirely after every existing region
            if not (
                (chunk_region_start > region[1] and chunk_region_end > region[1])
                or (chunk_region_start < region[0] and chunk_region_end < region[0])
            ):
                return False
        return True
colossalai/autochunk/
chunk_codegen
.py
→
colossalai/autochunk/
index_tracer
.py
View file @
1a6d2a74
import
colossalai
import
torch
import
copy
import
copy
from
typing
import
List
,
Callable
,
Any
,
Tuple
,
Dict
,
Iterable
from
torch.fx.node
import
Node
,
Argument
,
map_arg
,
_type_repr
,
_get_qualified_name
from
torch.fx.graph
import
(
_Namespace
,
PythonCode
,
_custom_builtins
,
_is_from_torch
,
_format_target
,
magic_methods
,
CodeGen
,
_origin_type_map
,
inplace_methods
,
_CustomBuiltin
,
)
from
colossalai.fx.profiler
import
(
calculate_fwd_out
,
calculate_fwd_tmp
,
parameter_size
,
activation_size
,
)
CODEGEN_AVAILABLE
=
True
__all__
=
[
"ChunkCodeGen"
]
def
_delete_free_var_from_last_use
(
user_to_last_uses
):
for
key
,
value
in
user_to_last_uses
.
items
():
for
n
in
value
:
if
n
.
op
==
"placeholder"
:
user_to_last_uses
[
key
].
remove
(
n
)
def
_get_node_shape
(
node
):
if
hasattr
(
node
.
meta
[
"tensor_meta"
],
"shape"
):
return
node
.
meta
[
"tensor_meta"
].
shape
return
None
def
_is_non_compute_node
(
node
):
if
any
(
i
in
node
.
op
for
i
in
[
"placeholder"
,
"get_attr"
,
"output"
])
or
any
(
i
in
node
.
name
for
i
in
[
"getitem"
,
"getattr"
]
):
return
True
return
False
def
_is_non_compute_node_except_placeholder
(
node
):
if
any
(
i
in
node
.
op
for
i
in
[
"get_attr"
,
"output"
])
or
any
(
i
in
node
.
name
for
i
in
[
"getitem"
,
"getattr"
]
):
return
True
return
False
from
.utils
import
(
def
_is_non_compute_node_except_placeholder_output
(
node
):
find_chunk_all_input_nodes
,
if
any
(
i
in
node
.
op
for
i
in
[
"get_attr"
])
or
any
(
find_chunk_compute_input_and_output_nodes
,
i
in
node
.
name
for
i
in
[
"getitem"
,
"getattr"
]
find_idx_by_name
,
):
get_node_shape
,
return
True
is_non_compute_node
,
return
False
is_non_compute_node_except_placeholder
,
)
class
IndexTracer
(
object
):
class
IndexTracer
(
object
):
...
@@ -76,11 +22,11 @@ class IndexTracer(object):
...
@@ -76,11 +22,11 @@ class IndexTracer(object):
def
_init_idx_trace_list
(
self
):
def
_init_idx_trace_list
(
self
):
idx_trace_list
=
[]
idx_trace_list
=
[]
for
n
in
self
.
node_list
:
for
n
in
self
.
node_list
:
if
_
get_node_shape
(
n
)
!=
None
:
if
get_node_shape
(
n
)
!=
None
:
cur_trace
=
{
cur_trace
=
{
"idx"
:
[
None
for
_
in
range
(
len
(
_
get_node_shape
(
n
)))],
"idx"
:
[
None
for
_
in
range
(
len
(
get_node_shape
(
n
)))],
"compute"
:
[[]
for
_
in
range
(
len
(
_
get_node_shape
(
n
)))],
"compute"
:
[[]
for
_
in
range
(
len
(
get_node_shape
(
n
)))],
"source"
:
[{}
for
_
in
range
(
len
(
_
get_node_shape
(
n
)))],
"source"
:
[{}
for
_
in
range
(
len
(
get_node_shape
(
n
)))],
}
}
else
:
else
:
cur_trace
=
{
"idx"
:
[],
"compute"
:
[],
"source"
:
[]}
cur_trace
=
{
"idx"
:
[],
"compute"
:
[],
"source"
:
[]}
...
@@ -136,7 +82,7 @@ class IndexTracer(object):
...
@@ -136,7 +82,7 @@ class IndexTracer(object):
node_from_trace_source
=
self
.
_find_source_trace_from_node
(
node_from
)
node_from_trace_source
=
self
.
_find_source_trace_from_node
(
node_from
)
node_to_dim
=
self
.
_transform_index
(
node_to
,
node_to_dim
)
node_to_dim
=
self
.
_transform_index
(
node_to
,
node_to_dim
)
node_to_trace_source
=
self
.
_find_source_trace_from_node
(
node_to
)
node_to_trace_source
=
self
.
_find_source_trace_from_node
(
node_to
)
node_from_idx
=
_
find_idx_by_name
(
node_from
.
name
,
self
.
node_list
)
node_from_idx
=
find_idx_by_name
(
node_from
.
name
,
self
.
node_list
)
if
init
:
if
init
:
node_to_trace_source
[
node_to_dim
]
=
{}
node_to_trace_source
[
node_to_dim
]
=
{}
# add dim to cur new source
# add dim to cur new source
...
@@ -196,7 +142,7 @@ class IndexTracer(object):
...
@@ -196,7 +142,7 @@ class IndexTracer(object):
"""
"""
if
isinstance
(
dim
,
int
):
if
isinstance
(
dim
,
int
):
dim
=
[
dim
]
dim
=
[
dim
]
dims
=
list
(
range
(
len
(
_
get_node_shape
(
node
))))
dims
=
list
(
range
(
len
(
get_node_shape
(
node
))))
for
d
in
dim
:
for
d
in
dim
:
cur_dim
=
dims
[
d
]
cur_dim
=
dims
[
d
]
if
idx
not
in
self
.
idx_trace_list
[
idx
][
"compute"
][
cur_dim
]:
if
idx
not
in
self
.
idx_trace_list
[
idx
][
"compute"
][
cur_dim
]:
...
@@ -212,7 +158,7 @@ class IndexTracer(object):
...
@@ -212,7 +158,7 @@ class IndexTracer(object):
idx (list): idx of the node
idx (list): idx of the node
compute (list): computed idx of the node.
compute (list): computed idx of the node.
"""
"""
node_idx
=
_
find_idx_by_name
(
node
.
name
,
self
.
node_list
)
node_idx
=
find_idx_by_name
(
node
.
name
,
self
.
node_list
)
node_dict
=
self
.
idx_trace_list
[
node_idx
]
node_dict
=
self
.
idx_trace_list
[
node_idx
]
return
node_dict
return
node_dict
...
@@ -226,7 +172,7 @@ class IndexTracer(object):
...
@@ -226,7 +172,7 @@ class IndexTracer(object):
idx (list): idx of the node
idx (list): idx of the node
compute (list): computed idx of the node.
compute (list): computed idx of the node.
"""
"""
node_idx
=
_
find_idx_by_name
(
node
.
name
,
self
.
node_list
)
node_idx
=
find_idx_by_name
(
node
.
name
,
self
.
node_list
)
node_dict
=
self
.
idx_trace_list
[
node_idx
]
node_dict
=
self
.
idx_trace_list
[
node_idx
]
return
node_dict
[
"source"
]
return
node_dict
[
"source"
]
...
@@ -239,7 +185,7 @@ class IndexTracer(object):
...
@@ -239,7 +185,7 @@ class IndexTracer(object):
Returns:
Returns:
idx (list): idx of the node
idx (list): idx of the node
"""
"""
node_idx
=
_
find_idx_by_name
(
node
.
name
,
self
.
node_list
)
node_idx
=
find_idx_by_name
(
node
.
name
,
self
.
node_list
)
return
self
.
idx_trace_list
[
node_idx
][
"idx"
]
return
self
.
idx_trace_list
[
node_idx
][
"idx"
]
def
_find_compute_trace_from_node
(
self
,
node
):
def
_find_compute_trace_from_node
(
self
,
node
):
...
@@ -251,7 +197,7 @@ class IndexTracer(object):
...
@@ -251,7 +197,7 @@ class IndexTracer(object):
Returns:
Returns:
compute (list): computed idx of the node.
compute (list): computed idx of the node.
"""
"""
node_idx
=
_
find_idx_by_name
(
node
.
name
,
self
.
node_list
)
node_idx
=
find_idx_by_name
(
node
.
name
,
self
.
node_list
)
return
self
.
idx_trace_list
[
node_idx
][
"compute"
]
return
self
.
idx_trace_list
[
node_idx
][
"compute"
]
def
_assign_index_as_input
(
self
,
node
,
node_idx
,
input_node
=
None
):
def
_assign_index_as_input
(
self
,
node
,
node_idx
,
input_node
=
None
):
...
@@ -264,7 +210,7 @@ class IndexTracer(object):
...
@@ -264,7 +210,7 @@ class IndexTracer(object):
"""
"""
if
input_node
==
None
:
if
input_node
==
None
:
input_node
=
node
.
args
[
0
]
input_node
=
node
.
args
[
0
]
input_node_idx
=
_
find_idx_by_name
(
input_node
.
name
,
self
.
node_list
)
input_node_idx
=
find_idx_by_name
(
input_node
.
name
,
self
.
node_list
)
input_node_idx_trace
=
self
.
idx_trace_list
[
input_node_idx
][
"idx"
]
input_node_idx_trace
=
self
.
idx_trace_list
[
input_node_idx
][
"idx"
]
new_idx_trace
=
copy
.
deepcopy
(
input_node_idx_trace
)
new_idx_trace
=
copy
.
deepcopy
(
input_node_idx_trace
)
...
@@ -359,7 +305,7 @@ class IndexTracer(object):
...
@@ -359,7 +305,7 @@ class IndexTracer(object):
"""
"""
matmul_left
,
matmul_right
=
node
.
args
matmul_left
,
matmul_right
=
node
.
args
assert
len
(
_
get_node_shape
(
matmul_left
))
==
len
(
_
get_node_shape
(
matmul_right
))
assert
len
(
get_node_shape
(
matmul_left
))
==
len
(
get_node_shape
(
matmul_right
))
self
.
_assign_index_as_input
(
node
,
node_idx
,
matmul_left
)
self
.
_assign_index_as_input
(
node
,
node_idx
,
matmul_left
)
self
.
_inherit_index
(
matmul_right
,
-
1
,
node
,
-
1
)
self
.
_inherit_index
(
matmul_right
,
-
1
,
node
,
-
1
)
...
@@ -398,8 +344,8 @@ class IndexTracer(object):
...
@@ -398,8 +344,8 @@ class IndexTracer(object):
self
.
_mark_computation_from_node
(
node_in
,
node
)
self
.
_mark_computation_from_node
(
node_in
,
node
)
assert
len
(
nodes_in
)
<=
2
assert
len
(
nodes_in
)
<=
2
if
len
(
nodes_in
)
==
2
:
if
len
(
nodes_in
)
==
2
:
node_in0_shape
=
_
get_node_shape
(
nodes_in
[
0
])
node_in0_shape
=
get_node_shape
(
nodes_in
[
0
])
node_in1_shape
=
_
get_node_shape
(
nodes_in
[
1
])
node_in1_shape
=
get_node_shape
(
nodes_in
[
1
])
for
i
in
range
(
-
1
,
-
min
(
len
(
node_in0_shape
),
len
(
node_in1_shape
))
-
1
,
-
1
):
for
i
in
range
(
-
1
,
-
min
(
len
(
node_in0_shape
),
len
(
node_in1_shape
))
-
1
,
-
1
):
if
node_in0_shape
[
i
]
==
node_in1_shape
[
i
]:
if
node_in0_shape
[
i
]
==
node_in1_shape
[
i
]:
self
.
_mark_idx_equal
(
nodes_in
[
0
],
i
,
nodes_in
[
1
],
i
)
self
.
_mark_idx_equal
(
nodes_in
[
0
],
i
,
nodes_in
[
1
],
i
)
...
@@ -657,7 +603,7 @@ class IndexTracer(object):
...
@@ -657,7 +603,7 @@ class IndexTracer(object):
Returns:
Returns:
bool: True if check pass
bool: True if check pass
"""
"""
start_node_idx
=
_
find_idx_by_name
(
start_node
.
name
,
self
.
node_list
)
start_node_idx
=
find_idx_by_name
(
start_node
.
name
,
self
.
node_list
)
end_node_trace
=
self
.
_find_trace_from_node
(
end_node
)
end_node_trace
=
self
.
_find_trace_from_node
(
end_node
)
end_node_trace_source
=
end_node_trace
[
"source"
][
end_dim
]
end_node_trace_source
=
end_node_trace
[
"source"
][
end_dim
]
sorted_source
=
sorted
(
sorted_source
=
sorted
(
...
@@ -692,16 +638,16 @@ class IndexTracer(object):
...
@@ -692,16 +638,16 @@ class IndexTracer(object):
def
get_node_chunk_dim
(
self
,
node_from
,
node_from_dim
,
node_to
):
def
get_node_chunk_dim
(
self
,
node_from
,
node_from_dim
,
node_to
):
node_from_source
=
self
.
_find_source_trace_from_node
(
node_from
)
node_from_source
=
self
.
_find_source_trace_from_node
(
node_from
)
dim_source
=
node_from_source
[
node_from_dim
]
dim_source
=
node_from_source
[
node_from_dim
]
node_to_idx
=
_
find_idx_by_name
(
node_to
.
name
,
self
.
node_list
)
node_to_idx
=
find_idx_by_name
(
node_to
.
name
,
self
.
node_list
)
for
k
,
v
in
dim_source
.
items
():
for
k
,
v
in
dim_source
.
items
():
if
k
==
node_to_idx
:
if
k
==
node_to_idx
:
return
v
return
v
return
None
return
None
def
_find_inherit_dim
(
self
,
input_node
,
input_dim
,
node
):
def
_find_inherit_dim
(
self
,
input_node
,
input_dim
,
node
):
input_node_idx
=
_
find_idx_by_name
(
input_node
.
name
,
self
.
node_list
)
input_node_idx
=
find_idx_by_name
(
input_node
.
name
,
self
.
node_list
)
node_trace_source
=
self
.
_find_source_trace_from_node
(
node
)
node_trace_source
=
self
.
_find_source_trace_from_node
(
node
)
for
node_dim
in
range
(
len
(
_
get_node_shape
(
node
))):
for
node_dim
in
range
(
len
(
get_node_shape
(
node
))):
if
(
if
(
input_node_idx
in
node_trace_source
[
node_dim
]
input_node_idx
in
node_trace_source
[
node_dim
]
and
input_dim
[
0
]
in
node_trace_source
[
node_dim
][
input_node_idx
]
and
input_dim
[
0
]
in
node_trace_source
[
node_dim
][
input_node_idx
]
...
@@ -720,12 +666,12 @@ class IndexTracer(object):
...
@@ -720,12 +666,12 @@ class IndexTracer(object):
for
node
in
self
.
node_list
[
for
node
in
self
.
node_list
[
chunk_infos
[
"region"
][
0
]
:
chunk_infos
[
"region"
][
1
]
+
1
chunk_infos
[
"region"
][
0
]
:
chunk_infos
[
"region"
][
1
]
+
1
]:
]:
if
_
is_non_compute_node_except_placeholder
(
node
):
if
is_non_compute_node_except_placeholder
(
node
):
continue
continue
count
=
0
count
=
0
duplicate_dims
=
[]
duplicate_dims
=
[]
node_trace_source
=
self
.
_find_source_trace_from_node
(
node
)
node_trace_source
=
self
.
_find_source_trace_from_node
(
node
)
for
node_dim
in
range
(
len
(
_
get_node_shape
(
node
))):
for
node_dim
in
range
(
len
(
get_node_shape
(
node
))):
duplicate_dim
=
[]
duplicate_dim
=
[]
duplicate_flag
=
False
duplicate_flag
=
False
dim_source
=
node_trace_source
[
node_dim
]
dim_source
=
node_trace_source
[
node_dim
]
...
@@ -760,7 +706,7 @@ class IndexTracer(object):
...
@@ -760,7 +706,7 @@ class IndexTracer(object):
all_node_info
,
all_node_info
,
next_node_list
,
next_node_list
,
):
):
arg_idx
=
_
find_idx_by_name
(
arg_node
.
name
,
self
.
node_list
)
arg_idx
=
find_idx_by_name
(
arg_node
.
name
,
self
.
node_list
)
# arg in chunk range or be inputs
# arg in chunk range or be inputs
if
not
(
start_idx
<=
arg_idx
<
end_idx
):
if
not
(
start_idx
<=
arg_idx
<
end_idx
):
return
True
return
True
...
@@ -800,7 +746,7 @@ class IndexTracer(object):
...
@@ -800,7 +746,7 @@ class IndexTracer(object):
return
True
return
True
def
flow_search
(
self
,
start_idx
,
start_dim
,
end_idx
,
end_dim
):
def
flow_search
(
self
,
start_idx
,
start_dim
,
end_idx
,
end_dim
):
inputs
,
outputs
=
_
find_chunk_compute_input_and_output_nodes
(
inputs
,
outputs
=
find_chunk_compute_input_and_output_nodes
(
self
.
node_list
[
start_idx
:
end_idx
+
1
]
self
.
node_list
[
start_idx
:
end_idx
+
1
]
)
)
# only single ouput
# only single ouput
...
@@ -817,7 +763,7 @@ class IndexTracer(object):
...
@@ -817,7 +763,7 @@ class IndexTracer(object):
# get cur node info
# get cur node info
cur_node_chunk_dim
=
all_node_info
[
cur_node
][
"chunk_dim"
]
cur_node_chunk_dim
=
all_node_info
[
cur_node
][
"chunk_dim"
]
cur_node_fix_dim
=
all_node_info
[
cur_node
][
"fix_dim"
]
cur_node_fix_dim
=
all_node_info
[
cur_node
][
"fix_dim"
]
cur_node_idx
=
_
find_idx_by_name
(
cur_node
.
name
,
self
.
node_list
)
cur_node_idx
=
find_idx_by_name
(
cur_node
.
name
,
self
.
node_list
)
if
cur_node_chunk_dim
:
if
cur_node_chunk_dim
:
cur_node_compute
=
self
.
_find_compute_trace_from_node
(
cur_node
)
cur_node_compute
=
self
.
_find_compute_trace_from_node
(
cur_node
)
cur_node_source
=
self
.
_find_source_trace_from_node
(
cur_node
)
cur_node_source
=
self
.
_find_source_trace_from_node
(
cur_node
)
...
@@ -829,7 +775,7 @@ class IndexTracer(object):
...
@@ -829,7 +775,7 @@ class IndexTracer(object):
for
arg
in
cur_node
.
args
:
for
arg
in
cur_node
.
args
:
if
type
(
arg
)
!=
type
(
cur_node
):
if
type
(
arg
)
!=
type
(
cur_node
):
continue
continue
if
_
is_non_compute_node
(
arg
):
if
is_non_compute_node
(
arg
):
continue
continue
arg_list
.
append
(
arg
)
arg_list
.
append
(
arg
)
flow_flag
=
self
.
_assgin_single_node_flow
(
flow_flag
=
self
.
_assgin_single_node_flow
(
...
@@ -851,13 +797,13 @@ class IndexTracer(object):
...
@@ -851,13 +797,13 @@ class IndexTracer(object):
for
arg
in
arg_list
:
for
arg
in
arg_list
:
if
not
(
if
not
(
start_idx
start_idx
<=
_
find_idx_by_name
(
arg
.
name
,
self
.
node_list
)
<=
find_idx_by_name
(
arg
.
name
,
self
.
node_list
)
<
end_idx
<
end_idx
):
):
continue
continue
arg_chunk_dim
=
all_node_info
[
arg
][
"chunk_dim"
]
arg_chunk_dim
=
all_node_info
[
arg
][
"chunk_dim"
]
arg_fix_dim
=
all_node_info
[
arg
][
"fix_dim"
]
arg_fix_dim
=
all_node_info
[
arg
][
"fix_dim"
]
arg_shape
=
_
get_node_shape
(
arg
)
arg_shape
=
get_node_shape
(
arg
)
# add all dim as fix dim except chunk dim
# add all dim as fix dim except chunk dim
for
i
,
shape
in
enumerate
(
arg_shape
):
for
i
,
shape
in
enumerate
(
arg_shape
):
if
shape
!=
1
and
i
!=
cur_node_chunk_dim
:
if
shape
!=
1
and
i
!=
cur_node_chunk_dim
:
...
@@ -877,11 +823,11 @@ class IndexTracer(object):
...
@@ -877,11 +823,11 @@ class IndexTracer(object):
remove_inputs
=
[]
remove_inputs
=
[]
for
input_node
in
inputs
:
for
input_node
in
inputs
:
input_dict
=
{}
input_dict
=
{}
input_node_idx
=
_
find_idx_by_name
(
input_node
.
name
,
self
.
node_list
)
input_node_idx
=
find_idx_by_name
(
input_node
.
name
,
self
.
node_list
)
for
user
in
input_node
.
users
.
keys
():
for
user
in
input_node
.
users
.
keys
():
if
_
is_non_compute_node
(
user
):
if
is_non_compute_node
(
user
):
continue
continue
user_idx
=
_
find_idx_by_name
(
user
.
name
,
self
.
node_list
)
user_idx
=
find_idx_by_name
(
user
.
name
,
self
.
node_list
)
if
start_idx
<=
user_idx
<=
end_idx
:
if
start_idx
<=
user_idx
<=
end_idx
:
chunk_dim
=
all_node_info
[
user
][
"chunk_dim"
]
chunk_dim
=
all_node_info
[
user
][
"chunk_dim"
]
if
chunk_dim
is
not
None
:
if
chunk_dim
is
not
None
:
...
@@ -916,7 +862,7 @@ class IndexTracer(object):
...
@@ -916,7 +862,7 @@ class IndexTracer(object):
if
node_info
[
"chunk_dim"
]
is
None
:
if
node_info
[
"chunk_dim"
]
is
None
:
maybe_prepose_nodes
.
append
(
node
)
maybe_prepose_nodes
.
append
(
node
)
maybe_prepose_nodes
.
sort
(
maybe_prepose_nodes
.
sort
(
key
=
lambda
x
:
_
find_idx_by_name
(
x
.
name
,
self
.
node_list
),
key
=
lambda
x
:
find_idx_by_name
(
x
.
name
,
self
.
node_list
),
reverse
=
True
,
reverse
=
True
,
)
# from last node to first node
)
# from last node to first node
prepose_nodes
=
[]
prepose_nodes
=
[]
...
@@ -941,7 +887,7 @@ class IndexTracer(object):
...
@@ -941,7 +887,7 @@ class IndexTracer(object):
# out of loop
# out of loop
if
not
(
if
not
(
start_idx
start_idx
<=
_
find_idx_by_name
(
<=
find_idx_by_name
(
cur_prepose_node_arg
.
name
,
self
.
node_list
cur_prepose_node_arg
.
name
,
self
.
node_list
)
)
<
end_idx
<
end_idx
...
@@ -969,7 +915,7 @@ class IndexTracer(object):
...
@@ -969,7 +915,7 @@ class IndexTracer(object):
if
n
in
maybe_prepose_nodes
:
if
n
in
maybe_prepose_nodes
:
maybe_prepose_nodes
.
remove
(
n
)
maybe_prepose_nodes
.
remove
(
n
)
# sort by index
# sort by index
prepose_nodes
.
sort
(
key
=
lambda
x
:
_
find_idx_by_name
(
x
.
name
,
self
.
node_list
))
prepose_nodes
.
sort
(
key
=
lambda
x
:
find_idx_by_name
(
x
.
name
,
self
.
node_list
))
chunk_info
[
"args"
][
"prepose_nodes"
]
=
prepose_nodes
chunk_info
[
"args"
][
"prepose_nodes"
]
=
prepose_nodes
# we need to log input nodes to avoid deleteing them in the loop
# we need to log input nodes to avoid deleteing them in the loop
...
@@ -977,7 +923,7 @@ class IndexTracer(object):
...
@@ -977,7 +923,7 @@ class IndexTracer(object):
# also need to get some prepose node's arg out of non_chunk_inputs
# also need to get some prepose node's arg out of non_chunk_inputs
for
n
in
prepose_nodes
:
for
n
in
prepose_nodes
:
chunk_node_list
.
remove
(
n
)
chunk_node_list
.
remove
(
n
)
non_chunk_inputs
=
_
find_chunk_all_input_nodes
(
chunk_node_list
)
non_chunk_inputs
=
find_chunk_all_input_nodes
(
chunk_node_list
)
for
i
in
non_chunk_inputs
:
for
i
in
non_chunk_inputs
:
if
i
not
in
chunk_info
[
"inputs"
]:
if
i
not
in
chunk_info
[
"inputs"
]:
chunk_info
[
"inputs_non_chunk"
].
append
(
i
)
chunk_info
[
"inputs_non_chunk"
].
append
(
i
)
...
@@ -990,7 +936,7 @@ class IndexTracer(object):
...
@@ -990,7 +936,7 @@ class IndexTracer(object):
def
_reassgin_reshape_size
(
self
,
chunk_info
):
def
_reassgin_reshape_size
(
self
,
chunk_info
):
chunk_region
=
chunk_info
[
"region"
]
chunk_region
=
chunk_info
[
"region"
]
reshape_size
=
{}
reshape_size
=
{}
chunk_shape
=
_
get_node_shape
(
chunk_info
[
"outputs"
][
0
])[
chunk_shape
=
get_node_shape
(
chunk_info
[
"outputs"
][
0
])[
chunk_info
[
"outputs_dim"
]
chunk_info
[
"outputs_dim"
]
]
]
for
node
in
self
.
node_list
[
chunk_region
[
0
]
:
chunk_region
[
1
]
+
1
]:
for
node
in
self
.
node_list
[
chunk_region
[
0
]
:
chunk_region
[
1
]
+
1
]:
...
@@ -1016,7 +962,7 @@ class IndexTracer(object):
...
@@ -1016,7 +962,7 @@ class IndexTracer(object):
chunk_region_end
=
chunk_info
[
"region"
][
1
]
chunk_region_end
=
chunk_info
[
"region"
][
1
]
chunk_prepose_nodes
=
chunk_info
[
"args"
][
"prepose_nodes"
]
chunk_prepose_nodes
=
chunk_info
[
"args"
][
"prepose_nodes"
]
chunk_prepose_nodes_idx
=
[
chunk_prepose_nodes_idx
=
[
_
find_idx_by_name
(
i
.
name
,
self
.
node_list
)
for
i
in
chunk_prepose_nodes
find_idx_by_name
(
i
.
name
,
self
.
node_list
)
for
i
in
chunk_prepose_nodes
]
]
# put prepose nodes ahead
# put prepose nodes ahead
for
idx
,
n
in
enumerate
(
chunk_prepose_nodes
):
for
idx
,
n
in
enumerate
(
chunk_prepose_nodes
):
...
@@ -1026,7 +972,7 @@ class IndexTracer(object):
...
@@ -1026,7 +972,7 @@ class IndexTracer(object):
for
n
in
self
.
node_list
[
chunk_region_start
:
chunk_region_end
+
1
]:
for
n
in
self
.
node_list
[
chunk_region_start
:
chunk_region_end
+
1
]:
if
n
in
chunk_prepose_nodes
:
if
n
in
chunk_prepose_nodes
:
continue
continue
n_idx
=
_
find_idx_by_name
(
n
.
name
,
self
.
node_list
)
n_idx
=
find_idx_by_name
(
n
.
name
,
self
.
node_list
)
pos
=
sum
([
n_idx
<
i
for
i
in
chunk_prepose_nodes_idx
])
pos
=
sum
([
n_idx
<
i
for
i
in
chunk_prepose_nodes_idx
])
reorder_map
[
n_idx
]
=
n_idx
+
pos
reorder_map
[
n_idx
]
=
n_idx
+
pos
...
@@ -1108,1257 +1054,3 @@ class IndexTracer(object):
...
@@ -1108,1257 +1054,3 @@ class IndexTracer(object):
chunk_info
=
self
.
_reorder_chunk_info
(
chunk_info
,
reorder_map
)
chunk_info
=
self
.
_reorder_chunk_info
(
chunk_info
,
reorder_map
)
return
new_node_list
,
chunk_info
return
new_node_list
,
chunk_info
class MemoryEstimator(object):
    """Estimate activation memory of an FX node list, with or without chunking.

    Sizes are derived from profiler metadata attached to each node
    (``node.meta["tensor_meta"]`` / ``node.meta["fwd_out"]``); the main
    entry point :meth:`estimate_chunk_inference_mem` reports memory in MB.
    """

    def __init__(self, index_tracer: IndexTracer) -> None:
        # NOTE(review): the tracer argument is currently unused.
        pass

    def _get_meta_node_size(self, x):
        """Return the byte size of node ``x``'s traced output tensor."""
        x = x.meta["tensor_meta"]
        # element_size() of an empty tensor of the same dtype gives bytes/elem
        x = x.numel * torch.tensor([], dtype=x.dtype).element_size()
        return x

    def _get_output_node(self, n):
        """Return ``(out_size_bytes, out_node_names)`` for node ``n``.

        Only real tensors carrying a profiler ``uuid`` in ``meta["fwd_out"]``
        are counted; ``out_node`` holds the node name iff the size is > 0.
        """
        fwd_out = {
            x.uuid: x
            for x in n.meta["fwd_out"]
            if isinstance(x, torch.Tensor) and hasattr(x, "uuid")
        }
        out_size = activation_size(fwd_out)
        out_node = [n.name] if out_size > 0 else []
        # if any(i in n.name for i in ['transpose', 'permute', 'view']):
        # out_size = 0
        return out_size, out_node

    def _get_output_node_size(self, n):
        """Return only the output byte size of node ``n``."""
        return self._get_output_node(n)[0]

    def _add_active_node(self, n, active_list):
        """Append ``n``'s live output names (and placeholder names) to ``active_list``.

        ``active_list`` is mutated in place; duplicates are skipped.
        """
        new_active = self._get_output_node(n)[1]
        if n.op == "placeholder":
            new_active.append(n.name)
        for i in new_active:
            if i not in active_list:
                active_list.append(i)

    def _get_delete_node(self, user, user_to_last_uses, to_keep=None):
        """Return ``(freed_bytes, freed_node_names)`` when ``user`` executes.

        Nodes whose last use is ``user`` can be freed, except names listed
        in ``to_keep``. Output nodes are never freed.

        Args:
            user: node currently being executed.
            user_to_last_uses (dict): node -> nodes last used by it.
                NOTE(review): the kept-back nodes are removed from this
                dict's lists in place.
            to_keep (list | None): node names that must stay alive.
        """
        delete_size = 0
        delete_node = []
        if user.op not in ("output",):
            nodes_to_delete = user_to_last_uses.get(user, [])
            if to_keep is not None:
                keep_list = []
                for n in nodes_to_delete:
                    if n.name in to_keep:
                        keep_list.append(n)
                for n in keep_list:
                    if n in nodes_to_delete:
                        nodes_to_delete.remove(n)
            if len(nodes_to_delete):
                out_node = [self._get_output_node(i) for i in nodes_to_delete]
                delete_size = sum([i[0] for i in out_node])
                for i in range(len(out_node)):
                    if out_node[i][0] > 0:
                        delete_node.append(out_node[i][1][0])
                    elif nodes_to_delete[i].op == "placeholder":
                        delete_node.append(nodes_to_delete[i].name)
                    # elif any(j in nodes_to_delete[i].name for j in ['transpose', 'permute', 'view']):
                    # delete_node.append(nodes_to_delete[i].name)
        return delete_size, delete_node

    def _get_delete_node_size(self, user, user_to_last_uses, to_keep):
        """Return only the freed byte size when ``user`` executes."""
        return self._get_delete_node(user, user_to_last_uses, to_keep)[0]

    def _remove_deactive_node(self, user, user_to_last_uses, active_list):
        """Drop names freed by ``user`` from ``active_list`` (in place)."""
        delete_node = self._get_delete_node(user, user_to_last_uses)[1]
        for i in delete_node:
            if i in active_list:
                active_list.remove(i)

    def _get_chunk_inputs_size(
        self, chunk_inputs, chunk_inputs_non_chunk, node_list, chunk_end_idx
    ):
        """Return bytes freeable at chunk end: inputs with no user past the chunk."""
        nodes_to_delete = []
        for chunk_input in chunk_inputs + chunk_inputs_non_chunk:
            chunk_input_users = chunk_input.users.keys()
            chunk_input_users_idx = [
                _find_idx_by_name(i.name, node_list) for i in chunk_input_users
            ]
            # input can be freed only if every user lies inside the chunk
            if all(i <= chunk_end_idx for i in chunk_input_users_idx):
                if chunk_input not in nodes_to_delete:
                    nodes_to_delete.append(chunk_input)
        out_node = [self._get_output_node(i) for i in nodes_to_delete]
        delete_size = sum([i[0] for i in out_node])
        return delete_size

    def _get_last_usr(self, nodes):
        """Map each node to the list of argument nodes whose last use it is.

        Walks ``nodes`` in reverse so the first user seen is the last user
        in execution order (same scheme as ``torch.fx`` codegen).
        """
        node_to_last_use: Dict[Node, Node] = {}
        user_to_last_uses: Dict[Node, List[Node]] = {}

        def register_last_uses(n: Node, user: Node):
            if n not in node_to_last_use:
                node_to_last_use[n] = user
                user_to_last_uses.setdefault(user, []).append(n)

        for node in reversed(nodes):
            map_arg(node.args, lambda n: register_last_uses(n, node))
            map_arg(node.kwargs, lambda n: register_last_uses(n, node))
        return user_to_last_uses

    def _get_contiguous_memory(self, node, not_contiguous_list, delete=False):
        """Return extra bytes caused by non-contiguous inputs to ``node``.

        Tracks nodes made non-contiguous by ``permute`` in
        ``not_contiguous_list`` (mutated in place). With ``delete=True``
        entries consumed by a module call are removed from the list.
        """
        mem = 0
        not_contiguous_ops = ["permute"]
        # NOTE(review): inherit_contiguous_ops is currently unused
        inherit_contiguous_ops = ["transpose", "view"]
        if node.op == "call_function" and any(
            n in node.name for n in ["matmul", "reshape"]
        ):
            for n in node.args:
                if n in not_contiguous_list:
                    # matmul won't change origin tensor, but create a tmp copy
                    mem += self._get_output_node_size(n)
        elif node.op == "call_module":
            for n in node.args:
                if n in not_contiguous_list:
                    # module will just make origin tensor to contiguous
                    if delete:
                        not_contiguous_list.remove(n)
        elif node.op == "call_method" and any(
            i in node.name for i in not_contiguous_ops
        ):
            if node not in not_contiguous_list:
                not_contiguous_list.append(node)
        return mem

    def _get_chunk_ratio(self, node, chunk_node_dim, chunk_size):
        """Return ``chunk_size / full_dim_size`` for ``node``, or 1.0 if unchunked."""
        if node not in chunk_node_dim:
            return 1.0
        node_shape = _get_node_shape(node)
        chunk_dim = chunk_node_dim[node]["chunk_dim"]
        if chunk_dim is None:
            return 1.0
        else:
            return float(chunk_size) / node_shape[chunk_dim]

    def _get_chunk_delete_node_size(
        self, user, user_to_last_uses, chunk_ratio, chunk_inputs_names
    ):
        """Return bytes freeable inside a chunk when ``user`` executes.

        Chunk inputs (by name) are skipped — they must stay alive until the
        chunk region ends; freed sizes are scaled by ``chunk_ratio``.
        """
        # if any(j in user.name for j in ['transpose', 'permute', 'view']):
        # return 0
        if user.op in ("placeholder", "output"):
            return 0
        nodes_to_delete = user_to_last_uses.get(user, [])
        delete_size = 0
        for n in nodes_to_delete:
            if n.name in chunk_inputs_names:
                continue
            delete_size += self._get_output_node_size(n) * chunk_ratio
        return delete_size

    def _print_mem_log(self, log, nodes, title=None):
        """Print one ``name:value`` entry per node, three per row."""
        if title:
            print(title)
        for idx, (l, n) in enumerate(zip(log, nodes)):
            print("%s:%.2f \t" % (n.name, l), end="")
            if (idx + 1) % 3 == 0:
                print("")
        print("\n")

    def _print_compute_op_mem_log(self, log, nodes, title=None):
        """Like ``_print_mem_log`` but skips non-compute nodes."""
        if title:
            print(title)
        for idx, (l, n) in enumerate(zip(log, nodes)):
            if n.op in ["placeholder", "get_attr", "output"]:
                continue
            if any(i in n.name for i in ["getitem", "getattr"]):
                continue
            print("%s:%.2f \t" % (n.name, l), end="")
            if (idx + 1) % 3 == 0:
                print("")
        print("\n")

    def estimate_chunk_inference_mem(
        self,
        node_list,
        chunk_infos=None,
        print_mem=False,
    ):
        """Simulate inference over ``node_list`` and log activation memory (MB).

        Args:
            node_list (list): topologically ordered FX nodes.
            chunk_infos (list | None): chunk region descriptors; when given,
                memory inside each region is scaled by the chunk ratio.
            print_mem (bool): print a per-node peak-memory table.

        Returns:
            tuple: (peak-memory log, after-node memory log, active-node log),
            one entry per processed node.
        """
        act_memory = 0.0
        act_memory_peak_log = []
        act_memory_after_node_log = []
        active_node_list = []
        active_node_list_log = []
        not_contiguous_list = []
        user_to_last_uses = self._get_last_usr(node_list)
        user_to_last_uses_no_free_var = self._get_last_usr(node_list)
        _delete_free_var_from_last_use(user_to_last_uses_no_free_var)

        use_chunk = True if chunk_infos is not None else False
        chunk_within = False
        chunk_region_idx = None
        chunk_ratio = 1  # use it to estimate chunk mem
        chunk_inputs_names = []

        if use_chunk:
            chunk_regions = [i["region"] for i in chunk_infos]
            chunk_starts = [i[0] for i in chunk_regions]
            chunk_ends = [i[1] for i in chunk_regions]
            chunk_inputs = [i["inputs"] for i in chunk_infos]
            chunk_inputs_non_chunk = [i["inputs_non_chunk"] for i in chunk_infos]
            chunk_inputs_names = [j.name for i in chunk_inputs for j in i] + [
                j.name for i in chunk_inputs_non_chunk for j in i
            ]
            chunk_outputs = [i["outputs"][0] for i in chunk_infos]
            chunk_node_dim = [i["node_chunk_dim"] for i in chunk_infos]
            chunk_sizes = [
                i["chunk_size"] if "chunk_size" in i else 1 for i in chunk_infos
            ]

        for idx, node in enumerate(node_list):
            # if node in chunk start nodes, change chunk ratio and add chunk_tensor
            if use_chunk and idx in chunk_starts:
                chunk_within = True
                chunk_region_idx = chunk_starts.index(idx)
                act_memory += self._get_output_node_size(
                    chunk_outputs[chunk_region_idx]
                ) / (1024**2)

            # determine chunk ratio for current node
            if chunk_within:
                chunk_ratio = self._get_chunk_ratio(
                    node,
                    chunk_node_dim[chunk_region_idx],
                    chunk_sizes[chunk_region_idx],
                )

            # if node is placeholder, just add the size of the node
            if node.op == "placeholder":
                act_memory += self._get_meta_node_size(node) * chunk_ratio / (1024**2)
                act_memory_peak_log.append(act_memory)
            # skip output
            elif node.op == "output":
                continue
            # no change for non compute node
            elif _is_non_compute_node_except_placeholder(node):
                act_memory_peak_log.append(act_memory)
            # node is a compute op
            # calculate tmp, output node and delete node memory
            else:
                # forward memory
                # TODO: contiguous_memory still not accurate for matmul, view, reshape and transpose
                act_memory += (
                    self._get_contiguous_memory(node, not_contiguous_list)
                    * chunk_ratio
                    / (1024**2)
                )
                act_memory += (
                    self._get_output_node_size(node) * chunk_ratio / (1024**2)
                )
                # record max act memory
                act_memory_peak_log.append(act_memory)
                # delete useless memory
                act_memory -= (
                    self._get_contiguous_memory(node, not_contiguous_list, delete=True)
                    * chunk_ratio
                    / (1024**2)
                )
                # delete unused vars not in chunk_input_list
                # we can't delete input nodes until chunk ends
                if chunk_within:
                    act_memory -= self._get_chunk_delete_node_size(
                        node,
                        user_to_last_uses_no_free_var,
                        chunk_ratio,
                        chunk_inputs_names,
                    ) / (1024**2)
                else:
                    act_memory -= self._get_delete_node_size(
                        node, user_to_last_uses_no_free_var, chunk_inputs_names
                    ) / (1024**2)

            # log active node, only effective without chunk
            self._add_active_node(node, active_node_list)
            self._remove_deactive_node(node, user_to_last_uses, active_node_list)

            # if node in chunk end nodes, restore chunk settings
            if use_chunk and idx in chunk_ends:
                act_memory -= (
                    self._get_output_node_size(node) * chunk_ratio / (1024**2)
                )
                act_memory -= self._get_chunk_inputs_size(
                    chunk_inputs[chunk_region_idx],
                    chunk_inputs_non_chunk[chunk_region_idx],
                    node_list,
                    chunk_regions[chunk_region_idx][1],
                ) / (1024**2)
                chunk_within = False
                chunk_ratio = 1
                chunk_region_idx = None

            act_memory_after_node_log.append(act_memory)
            active_node_list_log.append(copy.deepcopy(active_node_list))

        if print_mem:
            print("with chunk" if use_chunk else "without chunk")
            # self._print_mem_log(act_memory_peak_log, node_list, "peak")
            # self._print_mem_log(act_memory_after_node_log, node_list, "after")
            self._print_compute_op_mem_log(act_memory_peak_log, node_list, "peak")
            # self._print_compute_op_mem_log(
            #     act_memory_after_node_log, node_list, "after"
            # )

        # param_memory = parameter_size(gm)
        # all_memory = act_memory + param_memory
        return act_memory_peak_log, act_memory_after_node_log, active_node_list_log
class
ChunkSelector
(
object
):
def __init__(
    self,
    index_tracer: IndexTracer,
    memory_estimator: MemoryEstimator,
    max_memory=None,
):
    """Configure the selection strategy.

    Args:
        index_tracer: tracer providing ``node_list`` and ``tmp_reorder``.
        memory_estimator: used to evaluate candidate chunk regions.
        max_memory (int | None): memory budget in MB. When given, the
            selector fits chunks under this budget; otherwise it simply
            minimizes memory and no ``max_memory`` attribute is set.
    """
    self.index_tracer = index_tracer
    self.memory_estimator = memory_estimator
    if max_memory is None:
        # no budget: always pick the region minimizing memory
        # (attribute name "stratge" kept as-is; sibling methods read it)
        self.stratge = "min_memory"
    else:
        self.stratge = "fit_memory"
        self.max_memory = max_memory  # MB
def _select_best_chunk_region(
    self, possible_chunk_regions, chunk_infos, peak_node, max_chunk_region, mem_peak
):
    """Dispatch to the region-selection strategy chosen in ``__init__``.

    Returns the best chunk-region dict according to the configured
    strategy, or ``None`` when no legal region is found. Raises
    ``RuntimeError`` on an unknown strategy name.
    """
    selectors = {
        "min_memory": self._select_min_memory_chunk_region,
        "fit_memory": self._select_fit_memory_chunk_region,
    }
    selector = selectors.get(self.stratge)
    if selector is None:
        raise RuntimeError()
    return selector(
        possible_chunk_regions,
        chunk_infos,
        peak_node,
        max_chunk_region,
        mem_peak,
    )
def _select_fit_memory_chunk_region(
    self, possible_chunk_regions, chunk_infos, peak_node, max_chunk_region, mem_peak
):
    """Pick the shortest legal region whose peak memory fits ``self.max_memory``.

    Each candidate is temporarily reordered and re-estimated with the
    existing ``chunk_infos`` applied; candidates whose peak inside
    ``max_chunk_region`` stays under the budget are ranked by compute-node
    count (fewest wins), then sized via ``_get_fit_chunk_size``.

    Returns:
        dict | None: chosen chunk info with ``"chunk_size"`` set; ``None``
        when memory already fits or no legal candidate exists.

    Raises:
        RuntimeError: when candidates exist but none fits the budget.
    """
    # stop chunk if max memory satisfy memory limit
    if max(mem_peak) < self.max_memory:
        return None

    # remove illegal regions
    illegal_regions = []
    for i in possible_chunk_regions:
        if not self._is_legal_region(i, chunk_infos):
            illegal_regions.append(i)
    for i in illegal_regions:
        if i in possible_chunk_regions:
            possible_chunk_regions.remove(i)

    if len(possible_chunk_regions) == 0:
        return None

    # get mem for chunk region
    regions_dict = []
    for region in possible_chunk_regions:
        cur_region = region.copy()
        cur_node_list, cur_region = self.index_tracer.tmp_reorder(
            self.index_tracer.node_list, cur_region
        )
        cur_chunk_infos = chunk_infos + [cur_region]
        cur_mem_peak = self.memory_estimator.estimate_chunk_inference_mem(
            cur_node_list, cur_chunk_infos
        )[0]
        cur_chunk_region_peak = cur_mem_peak[
            max_chunk_region[0] : max_chunk_region[1] + 1
        ]
        cur_chunk_region_max_peak = max(cur_chunk_region_peak)
        if cur_chunk_region_max_peak < self.max_memory:
            regions_dict.append(
                {
                    "chunk_info": region,
                    "chunk_max_mem": cur_chunk_region_max_peak,
                    "chunk_len": self._get_compute_node_num(
                        region["region"][0], region["region"][1]
                    ),
                    "reorder_chunk_info": cur_region,
                    "reorder_node_list": cur_node_list,
                }
            )
    # no region found
    if len(regions_dict) == 0:
        raise RuntimeError("Search failed. Try a larger memory threshold.")

    # select the min chunk len
    chunk_len = [i["chunk_len"] for i in regions_dict]
    best_region_idx = chunk_len.index(min(chunk_len))
    best_region = regions_dict[best_region_idx]

    # get max chunk size
    best_region = self._get_fit_chunk_size(best_region, chunk_infos)
    return best_region
def _get_fit_chunk_size(self, chunk_region_dict, chunk_infos):
    """Find the largest chunk size that keeps peak memory under the budget.

    Doubles the chunk size until the estimated peak inside the region
    exceeds ``self.max_memory``, then binary-searches the exact size in
    ``[chunk_size // 2, chunk_size]``.

    Args:
        chunk_region_dict (dict): candidate entry built by the fit-memory
            selector (contains ``"chunk_info"``, ``"reorder_chunk_info"``,
            ``"reorder_node_list"``). The nested chunk-info dicts are
            mutated in place while probing sizes.
        chunk_infos (list): previously committed chunk regions.

    Returns:
        dict: ``chunk_region_dict["chunk_info"]`` with ``"chunk_size"`` set.
    """
    chunk_size = 1
    reorder_chunk_info = chunk_region_dict["reorder_chunk_info"]
    reorder_chunk_info["chunk_size"] = chunk_size
    cur_chunk_max_mem = 0
    # search a region
    while cur_chunk_max_mem < self.max_memory:
        chunk_size *= 2
        reorder_chunk_info["chunk_size"] = chunk_size
        cur_chunk_infos = chunk_infos + [reorder_chunk_info]
        cur_mem_peak = self.memory_estimator.estimate_chunk_inference_mem(
            chunk_region_dict["reorder_node_list"], cur_chunk_infos
        )[0]
        cur_chunk_max_mem = max(
            cur_mem_peak[
                reorder_chunk_info["region"][0] : reorder_chunk_info["region"][1] + 1
            ]
        )
    # search exact size
    chunk_info = chunk_region_dict["chunk_info"]
    chunk_info["chunk_size"] = self._chunk_size_binary_search(
        chunk_size // 2, chunk_size, chunk_region_dict, chunk_infos
    )
    return chunk_info
def _chunk_size_binary_search(self, l, r, chunk_region_dict, chunk_infos):
    """Binary-search the largest chunk size in ``[l, r]`` whose estimated
    region peak memory stays below ``self.max_memory``.

    A coarser step (``gap = 4``) is used for large sizes to save
    estimation calls; the result is therefore approximate up to ``gap``.

    Returns:
        The selected chunk size (the final lower bound ``l``).
    """
    # coarser granularity for big sizes; exact for small ones
    if l >= 16:
        gap = 4
    else:
        gap = 1
    # NOTE: mutates the shared reorder_chunk_info dict's "chunk_size" while probing
    chunk_info = chunk_region_dict["reorder_chunk_info"]
    while r >= l + gap:
        mid = int((l + r) / 2 + 0.5)  # round-half-up midpoint
        chunk_info["chunk_size"] = mid
        cur_chunk_infos = chunk_infos + [chunk_info]
        cur_mem_peak = self.memory_estimator.estimate_chunk_inference_mem(
            chunk_region_dict["reorder_node_list"], cur_chunk_infos
        )[0]
        cur_chunk_max_mem = max(
            cur_mem_peak[chunk_info["region"][0] : chunk_info["region"][1] + 1]
        )
        if cur_chunk_max_mem >= self.max_memory:
            r = mid - gap  # too big: shrink upper bound
        else:
            l = mid + gap  # fits: grow lower bound
    return l
def _get_compute_node_num(self, start, end):
    """Count the compute nodes in ``node_list[start .. end]`` (inclusive)."""
    span = self.index_tracer.node_list[start : end + 1]
    return sum(1 for candidate in span if not _is_non_compute_node(candidate))
def _select_min_memory_chunk_region(
    self, possible_chunk_regions, chunk_infos, peak_node, max_chunk_region, mem_peak
):
    """Among legal candidate regions, pick the one whose estimated peak
    memory inside ``max_chunk_region`` is smallest.

    Returns:
        The chosen chunk-info dict (with ``chunk_size`` set to 1), or
        ``None`` when no legal candidate remains.
    """
    # remove illegal regions (overlapping / already-used / inverted ranges)
    illegal_regions = []
    for i in possible_chunk_regions:
        if not self._is_legal_region(i, chunk_infos):
            illegal_regions.append(i)
    for i in illegal_regions:
        if i in possible_chunk_regions:
            possible_chunk_regions.remove(i)

    if len(possible_chunk_regions) == 0:
        return None

    # get mem for chunk region: estimate each candidate with a temporary
    # reorder of the node list, at implicit chunk_size == 1
    regions_dict = []
    for region in possible_chunk_regions:
        cur_region = region.copy()
        cur_node_list, cur_region = self.index_tracer.tmp_reorder(
            self.index_tracer.node_list, cur_region
        )
        cur_chunk_infos = chunk_infos + [cur_region]
        cur_mem_peak = self.memory_estimator.estimate_chunk_inference_mem(
            cur_node_list, cur_chunk_infos
        )[0]
        cur_chunk_region_peak = cur_mem_peak[max_chunk_region[0] : max_chunk_region[1] + 1]
        cur_chunk_region_max_peak = max(cur_chunk_region_peak)
        regions_dict.append(
            {
                "chunk_info": region,
                "chunk_max_mem": cur_chunk_region_max_peak,
                "chunk_len": self._get_compute_node_num(
                    region["region"][0], region["region"][1]
                ),
                "reorder_chunk_info": cur_region,
                "reorder_node_list": cur_node_list,
            }
        )
    # select the min mem
    chunk_max_mem = [i["chunk_max_mem"] for i in regions_dict]
    best_region_idx = chunk_max_mem.index(min(chunk_max_mem))
    best_region = regions_dict[best_region_idx]["chunk_info"]
    if best_region is not None:
        best_region["chunk_size"] = 1
    return best_region
def
_is_legal_region
(
self
,
cur_chunk_info
,
chunk_infos
):
(
chunk_region_start
,
chunk_region_end
)
=
cur_chunk_info
[
"region"
]
if
cur_chunk_info
in
chunk_infos
:
return
False
if
chunk_region_end
<
chunk_region_start
:
return
False
for
i
in
chunk_infos
:
region
=
i
[
"region"
]
if
not
(
(
chunk_region_start
>
region
[
1
]
and
chunk_region_end
>
region
[
1
])
or
(
chunk_region_start
<
region
[
0
]
and
chunk_region_end
<
region
[
0
])
):
return
False
return
True
class ChunkRegionSearch(object):
    """Iteratively searches an FX graph for regions that can be chunked to
    reduce peak activation memory.

    Each search step finds the current memory-peak node, grows the widest
    admissible window around it, enumerates (start, end, dim) candidates via
    the index tracer, and lets the chunk selector pick the best one.
    """

    def __init__(self, gm, max_memory=None) -> None:
        self.gm = gm
        self.index_tracer = IndexTracer(list(gm.graph.nodes))
        self.index_tracer.trace_index()
        self.memory_estimator = MemoryEstimator(self.index_tracer)
        self.chunk_selector = ChunkSelector(
            self.index_tracer, self.memory_estimator, max_memory=max_memory
        )

    def _find_peak_node(self, mem_peak):
        # index of the (first) maximum in the per-node peak-memory log
        max_value = max(mem_peak)
        max_idx = mem_peak.index(max_value)
        return max_idx

    def _get_free_var(self):
        # indices of graph placeholders (free variables / inputs)
        free_var_idx = []
        for idx, n in enumerate(self.index_tracer.node_list):
            if n.op == "placeholder":
                free_var_idx.append(idx)
        return free_var_idx

    def _get_min_free_var(self, active_node_list, free_vars):
        # smallest active-node count over non-placeholder positions
        min_len = 999
        for idx, n in enumerate(active_node_list):
            if idx in free_vars:
                continue
            if len(n) < min_len:
                min_len = len(n)
        return min_len

    def _search_max_chunk_region(self, active_node, peak_node, chunk_regions):
        """Widest [start, end] window around ``peak_node`` whose active-node
        count stays above a threshold, clipped against existing regions.

        Returns ``None`` when the window is fully contained in an already
        chosen region.
        """
        free_vars = self._get_free_var()
        free_var_num = len(free_vars)
        active_node_num = [len(i) for i in active_node]
        min_active_node_num = min(active_node_num[free_var_num:])
        threshold = max(free_var_num, min_active_node_num)
        # from peak_node to free_var: walk left until activity rises above
        # the threshold again after having dipped below it
        inside_flag = False
        chunk_region_start = free_var_num
        for i in range(peak_node, -1, -1):
            if active_node_num[i] <= threshold:
                inside_flag = True
            if inside_flag and active_node_num[i] > threshold:
                chunk_region_start = i + 1
                break
        # from peak_node to len-2: symmetric walk to the right
        inside_flag = False
        chunk_region_end = len(active_node) - 1
        for i in range(peak_node, len(active_node)):
            if active_node_num[i] <= threshold:
                inside_flag = True
            if inside_flag and active_node_num[i] > threshold:
                chunk_region_end = i
                break
        # clip against previously selected regions
        for i in chunk_regions:
            region = i["region"]
            if chunk_region_start >= region[0] and chunk_region_end <= region[1]:
                return None
            elif (
                region[0] <= chunk_region_start <= region[1]
                and chunk_region_end > region[1]
            ):
                chunk_region_start = region[1] + 1
            elif (
                region[0] <= chunk_region_end <= region[1]
                and chunk_region_start < region[0]
            ):
                chunk_region_end = region[0] - 1
        return chunk_region_start, chunk_region_end

    def _is_not_compute(self, trace, chunk_range, dim_idx):
        # a dim is "not compute" if it has no compute record, or all its
        # compute sites fall outside the candidate chunk range
        if trace["idx"][dim_idx] not in trace["compute"]:
            return True
        if trace["idx"][dim_idx] in trace["compute"] and all(
            i < chunk_range[0] or i > chunk_range[1]
            for i in trace["compute"][trace["idx"][dim_idx]]
        ):
            return True
        return False

    def _find_free_dim(self, input_trace, output_trace, start_idx, end_idx):
        """Enumerate (start_dim, end_dim) pairs that can be chunked for the
        region [start_idx, end_idx]; returns a list of chunk-info dicts."""
        start_traces = input_trace[start_idx]
        end_trace = output_trace[end_idx]
        end_node = self.index_tracer.node_list[end_idx]
        chunk_infos = []
        for end_dim, _ in enumerate(end_trace["idx"]):
            # only single-input start nodes are considered
            if len(start_traces) > 1:
                continue
            for start_node, start_trace in start_traces.items():
                for start_dim, _ in enumerate(start_trace["idx"]):
                    # dim size cannot be 1
                    if (
                        _get_node_shape(end_node)[end_dim] == 1
                        or _get_node_shape(start_node)[start_dim] == 1
                    ):
                        continue
                    # check index source align
                    if not self.index_tracer.check_index_source(
                        start_dim, start_node, start_idx, end_dim, end_node
                    ):
                        continue
                    # check index compute
                    if not self.index_tracer.check_index_compute(
                        start_idx, end_dim, end_node, end_idx
                    ):
                        continue
                    # flow search
                    chunk_info = self.index_tracer.flow_search(
                        start_idx, start_dim, end_idx, end_dim
                    )
                    if chunk_info is None:
                        continue
                    # check index duplicate
                    if not self.index_tracer.check_index_duplicate(chunk_info):
                        continue
                    chunk_infos.append(chunk_info)
        return chunk_infos

    def _search_possible_chunk_regions(self, max_chunk_region, peak_node):
        """All chunk candidates whose start is at/before the peak node and
        whose end is at/after it, within ``max_chunk_region``."""
        possible_chunk_region = []
        output_trace = copy.deepcopy(self.index_tracer.idx_trace_list)
        input_trace = []  # trace of a node's input nodes
        for _, n in enumerate(self.index_tracer.node_list):
            cur_trace = {}
            for arg in n.args:
                if type(arg) == type(n) and not _is_non_compute_node_except_placeholder(arg):
                    cur_trace[arg] = self.index_tracer._find_trace_from_node(arg)
            input_trace.append(cur_trace)
        for start_idx in range(max_chunk_region[0], peak_node + 1):
            for end_idx in range(peak_node, max_chunk_region[1] + 1):
                # skip non compute nodes
                if _is_non_compute_node(
                    self.index_tracer.node_list[start_idx]
                ) or _is_non_compute_node(self.index_tracer.node_list[end_idx]):
                    continue
                # select free dim
                chunk_info = self._find_free_dim(input_trace, output_trace, start_idx, end_idx)
                if len(chunk_info) > 0:
                    possible_chunk_region.extend(chunk_info)
        return possible_chunk_region

    def _step_search(self, mem_peak, active_node, chunk_regions):
        """One search iteration: locate peak, bound the window, enumerate
        candidates, select the best, and reorder the graph for it."""
        peak_node = self._find_peak_node(mem_peak)
        max_chunk_region = self._search_max_chunk_region(active_node, peak_node, chunk_regions)
        if max_chunk_region == None:
            return None
        possible_chunk_regions = self._search_possible_chunk_regions(max_chunk_region, peak_node)
        best_chunk_region = self.chunk_selector._select_best_chunk_region(
            possible_chunk_regions, chunk_regions, peak_node, max_chunk_region, mem_peak
        )
        best_chunk_region = self.index_tracer.reorder_all(best_chunk_region)
        return best_chunk_region

    def _stop_search(self, init_mem_peak, mem_peak):
        # stop once the current peak drops below the initial median peak
        sorted_init_mem_peak = sorted(init_mem_peak)
        if max(mem_peak) < sorted_init_mem_peak[int(len(sorted_init_mem_peak) * 0.5)]:
            return True
        return False

    def search_region(self):
        """Run the full search loop and return the list of chunk infos."""
        chunk_infos = []
        (
            init_mem_peak,
            _,
            active_node,
        ) = self.memory_estimator.estimate_chunk_inference_mem(self.index_tracer.node_list)
        mem_peak = init_mem_peak

        while True:
            chunk_info = self._step_search(mem_peak, active_node, chunk_infos)
            if chunk_info is None:
                break

            chunk_infos.append(chunk_info)
            (
                mem_peak,
                _,
                active_node,
            ) = self.memory_estimator.estimate_chunk_inference_mem(
                self.index_tracer.node_list, chunk_infos
            )
            if self._stop_search(init_mem_peak, mem_peak):
                break
        # final estimate with logging enabled
        self.memory_estimator.estimate_chunk_inference_mem(
            self.index_tracer.node_list, chunk_infos, print_mem=True
        )
        return chunk_infos
def
_gen_chunk_slice_dim
(
chunk_dim
,
chunk_idx_name
,
shape
):
new_shape
=
"["
for
idx
,
i
in
enumerate
(
shape
):
if
idx
==
chunk_dim
:
new_shape
+=
"%s:%s + chunk_size"
%
(
chunk_idx_name
,
chunk_idx_name
)
else
:
new_shape
+=
":"
new_shape
+=
", "
new_shape
=
new_shape
[:
-
2
]
+
"]"
return
new_shape
def _gen_loop_start(chunk_input, chunk_output, chunk_ouput_dim, chunk_size=2):
    """Emit the source lines that open a chunked loop: allocate the full
    ``chunk_result`` buffer and start ``for chunk_idx in range(...)``.

    Args:
        chunk_input: chunked input nodes; only the first is used to pick the
            dtype/device of the result buffer.
        chunk_output: the node whose full output the loop reconstructs.
        chunk_ouput_dim: output dimension being chunked (param name keeps
            the original spelling for call-site compatibility).
        chunk_size: step of the chunk loop.
    """
    input_node = chunk_input[0]
    out_shape = _get_node_shape(chunk_output)
    out_str = str(list(out_shape))
    context = (
        "chunk_result = torch.empty(%s, dtype=%s.dtype, device=%s.device); chunk_size = %d\nfor chunk_idx in range"
        % (out_str, input_node.name, input_node.name, chunk_size)
    )
    context += "(0, %d, chunk_size):\n" % (out_shape[chunk_ouput_dim])
    return context
def _gen_loop_end(
    chunk_inputs, chunk_non_compute_inputs, chunk_outputs, chunk_outputs_dim, node_list
):
    """Emit the source lines that close a chunked loop: write the chunk's
    slice into ``chunk_result``, rebind the output name to the full result,
    and free loop temporaries plus any input whose last use is inside the
    chunk region."""
    chunk_outputs_name = chunk_outputs.name
    chunk_outputs_idx = _find_idx_by_name(chunk_outputs_name, node_list)
    chunk_output_shape = chunk_outputs.meta["tensor_meta"].shape
    chunk_slice = _gen_chunk_slice_dim(chunk_outputs_dim, "chunk_idx", chunk_output_shape)
    context = "    chunk_result%s = %s;  %s = None\n" % (
        chunk_slice,
        chunk_outputs_name,
        chunk_outputs_name,
    )
    context += (
        chunk_outputs_name + " = chunk_result;  chunk_result = None;  chunk_size = None"
    )

    # determine if its the last use for chunk input: an input can be freed
    # when every one of its users appears at or before the chunk's end
    for chunk_input in chunk_inputs + chunk_non_compute_inputs:
        if all(
            [
                _find_idx_by_name(user.name, node_list) <= chunk_outputs_idx
                for user in chunk_input.users.keys()
            ]
        ):
            context += ";  %s = None" % chunk_input.name

    context += "\n"
    return context
def
_find_chunk_all_input_nodes
(
nodes
:
List
[
Node
]):
"""
Find non-compute input and output node names.
input nodes are nodes used in the list
output nodes are nodes will use nodes in the list
"""
input_nodes
=
[]
for
node
in
nodes
:
for
input_node
in
node
.
_input_nodes
.
keys
():
if
input_node
not
in
nodes
and
input_node
not
in
input_nodes
:
input_nodes
.
append
(
input_node
)
return
input_nodes
def _find_chunk_compute_input_and_output_nodes(nodes: List[Node]):
    """Find the compute inputs and outputs of a chunk region.

    Inputs are external producers consumed inside ``nodes`` (non-compute
    nodes other than placeholders are ignored); outputs are members of
    ``nodes`` that have at least one consumer outside the list.
    """
    input_nodes = []
    output_nodes = []

    # if a node has an input node which is not in the node list
    # we treat that input node as the input of the checkpoint function
    for member in nodes:
        for producer in member._input_nodes.keys():
            if producer in nodes or producer in input_nodes:
                continue
            if _is_non_compute_node_except_placeholder(producer):
                continue
            input_nodes.append(producer)

    # if a node has a user node which is not in the node list
    # we treat that user node as the node receiving the current node output
    for member in nodes:
        for consumer in member.users.keys():
            if consumer in nodes or member in output_nodes:
                continue
            if _is_non_compute_node_except_placeholder_output(consumer):
                continue
            output_nodes.append(member)

    return input_nodes, output_nodes
def
_find_idx_by_name
(
name
,
nodes_list
):
for
idx
,
node
in
enumerate
(
nodes_list
):
if
node
.
name
==
name
:
return
idx
raise
RuntimeError
(
"name %s not found in node list"
%
name
)
def
_replace_name
(
context
,
name_from
,
name_to
):
patterns
=
[(
" "
,
" "
),
(
" "
,
"."
),
(
" "
,
","
),
(
"("
,
")"
),
(
"("
,
","
),
(
" "
,
")"
)]
for
p
in
patterns
:
source
=
p
[
0
]
+
name_from
+
p
[
1
]
target
=
p
[
0
]
+
name_to
+
p
[
1
]
if
source
in
context
:
context
=
context
.
replace
(
source
,
target
)
return
context
def
_replace_reshape_size
(
context
,
node_name
,
reshape_size_dict
):
if
node_name
not
in
reshape_size_dict
:
return
context
for
size_name
,
size_value
in
reshape_size_dict
[
node_name
].
items
():
context
=
context
.
replace
(
size_name
,
size_value
)
return
context
def emit_code_with_chunk(
    body,
    nodes,
    emit_node_func,
    delete_unused_value_func,
    chunk_region_search,
    chunk_infos,
):
    """Emit forward code, wrapping each discovered chunk region in a
    ``for chunk_idx`` loop.

    Nodes inside a region are emitted with their chunked inputs sliced
    (via ``_gen_chunk_slice_dim``) and reshape sizes rewritten; loop
    prologue/epilogue lines come from ``_gen_loop_start``/``_gen_loop_end``.

    Args:
        body: list of generated source lines (mutated in place)
        nodes: graph.nodes
        emit_node_func: function to emit one node into ``body``
        delete_unused_value_func: function appending frees for dead values
        chunk_region_search: search object owning the index tracer
        chunk_infos: chunk regions selected by the search
    """
    node_list = list(nodes)

    # per-region metadata extracted from chunk_infos
    chunk_regions = [i["region"] for i in chunk_infos]
    chunk_starts = [i[0] for i in chunk_regions]
    chunk_ends = [i[1] for i in chunk_regions]

    chunk_inputs = [i["inputs"] for i in chunk_infos]
    chunk_inputs_non_chunk = [i["inputs_non_chunk"] for i in chunk_infos]
    chunk_inputs_dim = [i["inputs_dim"] for i in chunk_infos]
    chunk_inputs_names = [j.name for i in chunk_inputs for j in i] + [
        j.name for i in chunk_inputs_non_chunk for j in i
    ]

    chunk_outputs = [i["outputs"][0] for i in chunk_infos]
    chunk_outputs_dim = [i["outputs_dim"] for i in chunk_infos]

    # apply the region-induced reorder before walking the nodes
    node_list = chunk_region_search.index_tracer.reorder_node_list(node_list)

    node_idx = 0
    region_idx = 0
    within_chunk_region = False

    while node_idx < len(node_list):
        node = node_list[node_idx]

        if node_idx in chunk_starts:
            within_chunk_region = True
            region_idx = chunk_starts.index(node_idx)
            body.append(
                _gen_loop_start(
                    chunk_inputs[region_idx],
                    chunk_outputs[region_idx],
                    chunk_outputs_dim[region_idx],
                    chunk_infos[region_idx]["chunk_size"],
                )
            )

        if within_chunk_region:
            emit_node_func(node, body)
            # replace input var with chunk var
            for input_node_idx, input_node in enumerate(chunk_inputs[region_idx]):
                for idx, dim in chunk_inputs_dim[region_idx][input_node_idx].items():
                    if idx == node_idx:
                        chunk_slice = _gen_chunk_slice_dim(
                            dim[0], "chunk_idx", _get_node_shape(input_node)
                        )
                        body[-1] = _replace_name(
                            body[-1], input_node.name, input_node.name + chunk_slice
                        )
            # ones like: slice the source of ones_like when its own chunk dim
            # is not already chunked upstream
            if "ones_like" in node.name:
                meta_node = chunk_region_search.index_tracer.node_list[node_idx]
                chunk_dim = chunk_infos[region_idx]["node_chunk_dim"][meta_node]["chunk_dim"]
                if _get_node_shape(meta_node)[chunk_dim] != 1:
                    source_node = meta_node.args[0].args[0]
                    if (
                        source_node not in chunk_infos[region_idx]["node_chunk_dim"]
                        or chunk_infos[region_idx]["node_chunk_dim"][source_node]["chunk_dim"]
                        is None
                    ):
                        chunk_slice = _gen_chunk_slice_dim(
                            chunk_dim, "chunk_idx", _get_node_shape(node)
                        )
                        body[-1] = _replace_name(
                            body[-1], node.args[0].name, node.args[0].name + chunk_slice
                        )
            body[-1] = _replace_reshape_size(
                body[-1], node.name, chunk_infos[region_idx]["reshape_size"]
            )
            # indent the emitted line so it lives inside the chunk loop
            body[-1] = "    " + body[-1]
            delete_unused_value_func(node, body, chunk_inputs_names)
        else:
            emit_node_func(node, body)
            # NOTE(review): `chunk_inputs` is a list of node lists, so the int
            # `node_idx` is never a member and this condition is always True —
            # looks like it was meant to test against input indices; confirm.
            if node_idx not in chunk_inputs:
                delete_unused_value_func(node, body, chunk_inputs_names)

        if node_idx in chunk_ends:
            body.append(
                _gen_loop_end(
                    chunk_inputs[region_idx],
                    chunk_inputs_non_chunk[region_idx],
                    chunk_outputs[region_idx],
                    chunk_outputs_dim[region_idx],
                    node_list,
                )
            )
            within_chunk_region = False

        node_idx += 1
if CODEGEN_AVAILABLE:

    class ChunkCodeGen(CodeGen):
        """FX ``CodeGen`` subclass that emits the forward with chunk loops.

        Runs a ``ChunkRegionSearch`` on the meta graph at construction time
        and threads the resulting chunk infos through
        ``emit_code_with_chunk`` during python-code generation. Most of
        ``_gen_python_code`` mirrors ``torch.fx.graph.CodeGen``.
        """

        def __init__(self, meta_graph, max_memory=None):
            super().__init__()
            self.meta_graph = meta_graph
            self.max_memory = max_memory
            self.meta_node = list(meta_graph.graph.nodes)
            # find the chunk regions
            self.chunk_region_search = ChunkRegionSearch(meta_graph, max_memory)
            self.chunk_infos = self.chunk_region_search.search_region()

        def _gen_python_code(self, nodes, root_module: str, namespace: _Namespace) -> PythonCode:
            free_vars: List[str] = []
            body: List[str] = []
            globals_: Dict[str, Any] = {}
            wrapped_fns: Dict[str, None] = {}

            # Wrap string in list to pass by reference
            maybe_return_annotation: List[str] = [""]

            def add_global(name_hint: str, obj: Any):
                """Add an obj to be tracked as a global.

                We call this for names that reference objects external to the
                Graph, like functions or types.

                Returns: the global name that should be used to reference 'obj' in generated source.
                """
                if _is_from_torch(obj) and obj != torch.device:  # to support registering torch.device
                    # HACK: workaround for how torch custom ops are registered. We
                    # can't import them like normal modules so they must retain their
                    # fully qualified name.
                    return _get_qualified_name(obj)

                # normalize the name hint to get a proper identifier
                global_name = namespace.create_name(name_hint, obj)

                if global_name in globals_:
                    assert globals_[global_name] is obj
                    return global_name
                globals_[global_name] = obj
                return global_name

            # set _custom_builtins here so that we needn't import colossalai in forward
            _custom_builtins["colossalai"] = _CustomBuiltin("import colossalai", colossalai)

            # Pre-fill the globals table with registered builtins.
            for name, (_, obj) in _custom_builtins.items():
                add_global(name, obj)

            def type_repr(o: Any):
                if o == ():
                    # Empty tuple is used for empty tuple type annotation Tuple[()]
                    return "()"

                typename = _type_repr(o)

                if hasattr(o, "__origin__"):
                    # This is a generic type, e.g. typing.List[torch.Tensor]
                    origin_type = _origin_type_map.get(o.__origin__, o.__origin__)
                    origin_typename = add_global(_type_repr(origin_type), origin_type)

                    if hasattr(o, "__args__"):
                        # Assign global names for each of the inner type variables.
                        args = [type_repr(arg) for arg in o.__args__]

                        if len(args) == 0:
                            # Bare type, such as `typing.Tuple` with no subscript
                            # This code-path used in Python < 3.9
                            return origin_typename

                        return f'{origin_typename}[{",".join(args)}]'
                    else:
                        # Bare type, such as `typing.Tuple` with no subscript
                        # This code-path used in Python 3.9+
                        return origin_typename

                # Common case: this is a regular module name like 'foo.bar.baz'
                return add_global(typename, o)

            def _format_args(args: Tuple[Argument, ...], kwargs: Dict[str, Argument]) -> str:
                def _get_repr(arg):
                    # Handle NamedTuples (if it has `_fields`) via add_global.
                    if isinstance(arg, tuple) and hasattr(arg, "_fields"):
                        qualified_name = _get_qualified_name(type(arg))
                        global_name = add_global(qualified_name, type(arg))
                        return f"{global_name}{repr(tuple(arg))}"
                    return repr(arg)

                args_s = ", ".join(_get_repr(a) for a in args)
                kwargs_s = ", ".join(f"{k} = {_get_repr(v)}" for k, v in kwargs.items())
                if args_s and kwargs_s:
                    return f"{args_s}, {kwargs_s}"
                return args_s or kwargs_s

            # Run through reverse nodes and record the first instance of a use
            # of a given node. This represents the *last* use of the node in the
            # execution order of the program, which we will use to free unused
            # values
            node_to_last_use: Dict[Node, Node] = {}
            user_to_last_uses: Dict[Node, List[Node]] = {}

            def register_last_uses(n: Node, user: Node):
                if n not in node_to_last_use:
                    node_to_last_use[n] = user
                    user_to_last_uses.setdefault(user, []).append(n)

            for node in reversed(nodes):
                map_arg(node.args, lambda n: register_last_uses(n, node))
                map_arg(node.kwargs, lambda n: register_last_uses(n, node))

            # free variables (placeholders) must stay alive for chunk loops
            _delete_free_var_from_last_use(user_to_last_uses)

            # NOTE: we add a variable to distinguish body and ckpt_func
            def delete_unused_values(user: Node, body, to_keep=[]):
                """
                Delete values after their last use. This ensures that values that are
                not used in the remainder of the code are freed and the memory usage
                of the code is optimal.
                """
                if user.op == "placeholder":
                    return
                if user.op == "output":
                    body.append("\n")
                    return
                nodes_to_delete = user_to_last_uses.get(user, [])
                # names in to_keep (e.g. chunk inputs) are freed elsewhere
                nodes_to_delete = [i for i in nodes_to_delete if i.name not in to_keep]
                if len(nodes_to_delete):
                    to_delete_str = " = ".join(
                        [repr(n) for n in nodes_to_delete] + ["None"]
                    )
                    body.append(f";  {to_delete_str}\n")
                else:
                    body.append("\n")

            # NOTE: we add a variable to distinguish body and ckpt_func
            def emit_node(node: Node, body):
                maybe_type_annotation = (
                    "" if node.type is None else f" : {type_repr(node.type)}"
                )
                if node.op == "placeholder":
                    assert isinstance(node.target, str)
                    maybe_default_arg = (
                        "" if not node.args else f" = {repr(node.args[0])}"
                    )
                    free_vars.append(
                        f"{node.target}{maybe_type_annotation}{maybe_default_arg}"
                    )
                    raw_name = node.target.replace("*", "")
                    if raw_name != repr(node):
                        body.append(f"{repr(node)} = {raw_name}\n")
                    return
                elif node.op == "call_method":
                    assert isinstance(node.target, str)
                    body.append(
                        f"{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.target)}"
                        f"({_format_args(node.args[1:], node.kwargs)})"
                    )
                    return
                elif node.op == "call_function":
                    assert callable(node.target)
                    # pretty print operators
                    if (
                        node.target.__module__ == "_operator"
                        and node.target.__name__ in magic_methods
                    ):
                        assert isinstance(node.args, tuple)
                        body.append(
                            f"{repr(node)}{maybe_type_annotation} = "
                            f"{magic_methods[node.target.__name__].format(*(repr(a) for a in node.args))}"
                        )
                        return
                    # pretty print inplace operators; required for jit.script to work properly
                    # not currently supported in normal FX graphs, but generated by torchdynamo
                    if (
                        node.target.__module__ == "_operator"
                        and node.target.__name__ in inplace_methods
                    ):
                        body.append(
                            f"{inplace_methods[node.target.__name__].format(*(repr(a) for a in node.args))};  "
                            f"{repr(node)}{maybe_type_annotation} = {repr(node.args[0])}"
                        )
                        return
                    qualified_name = _get_qualified_name(node.target)
                    global_name = add_global(qualified_name, node.target)
                    # special case for getattr: node.args could be 2-argument or 3-argument
                    # 2-argument: attribute access; 3-argument: fall through to attrib function call with default value
                    if (
                        global_name == "getattr"
                        and isinstance(node.args, tuple)
                        and isinstance(node.args[1], str)
                        and node.args[1].isidentifier()
                        and len(node.args) == 2
                    ):
                        body.append(
                            f"{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.args[1])}"
                        )
                        return
                    body.append(
                        f"{repr(node)}{maybe_type_annotation} = {global_name}({_format_args(node.args, node.kwargs)})"
                    )
                    if node.meta.get("is_wrapped", False):
                        wrapped_fns.setdefault(global_name)
                    return
                elif node.op == "call_module":
                    assert isinstance(node.target, str)
                    body.append(
                        f"{repr(node)}{maybe_type_annotation} = "
                        f"{_format_target(root_module, node.target)}({_format_args(node.args, node.kwargs)})"
                    )
                    return
                elif node.op == "get_attr":
                    assert isinstance(node.target, str)
                    body.append(
                        f"{repr(node)}{maybe_type_annotation} = {_format_target(root_module, node.target)}"
                    )
                    return
                elif node.op == "output":
                    if node.type is not None:
                        maybe_return_annotation[0] = f" -> {type_repr(node.type)}"
                    body.append(self.generate_output(node.args[0]))
                    return
                raise NotImplementedError(f"node: {node.op} {node.target}")

            # Modified for activation checkpointing
            ckpt_func = []

            # if any node has a list of labels for activation_checkpoint, we
            # will use nested type of activation checkpoint codegen
            emit_code_with_chunk(
                body,
                nodes,
                emit_node,
                delete_unused_values,
                self.chunk_region_search,
                self.chunk_infos,
            )

            if len(body) == 0:
                # If the Graph has no non-placeholder nodes, no lines for the body
                # have been emitted. To continue to have valid Python code, emit a
                # single pass statement
                body.append("pass\n")

            if len(wrapped_fns) > 0:
                wrap_name = add_global("wrap", torch.fx.wrap)
                wrap_stmts = "\n".join(
                    [f'{wrap_name}("{name}")' for name in wrapped_fns]
                )
            else:
                wrap_stmts = ""

            if self._body_transformer:
                body = self._body_transformer(body)

            for name, value in self.additional_globals():
                add_global(name, value)

            # as we need colossalai.utils.checkpoint, we need to import colossalai
            # in forward function
            prologue = self.gen_fn_def(free_vars, maybe_return_annotation[0])
            prologue = "".join(ckpt_func) + prologue
            prologue = prologue

            code = "".join(body)
            code = "\n".join("    " + line for line in code.split("\n"))
            fn_code = f"""
{wrap_stmts}

{prologue}
{code}"""
            # print(fn_code)
            return PythonCode(fn_code, globals_)
colossalai/autochunk/memory_estiamtor.py
0 → 100644
View file @
1a6d2a74
import
copy
from
typing
import
Any
,
Callable
,
Dict
,
Iterable
,
List
,
Tuple
import
torch
from
torch.fx.node
import
Node
,
map_arg
from
colossalai.fx.profiler
import
activation_size
,
parameter_size
from
.index_tracer
import
IndexTracer
from
.utils
import
(
delete_free_var_from_last_use
,
find_idx_by_name
,
get_node_shape
,
is_non_compute_node_except_placeholder
,
)
class
MemoryEstimator
(
object
):
def __init__(self, index_tracer: IndexTracer) -> None:
    """Memory estimator for a traced FX graph.

    Args:
        index_tracer: the IndexTracer built over the graph's node list.
    """
    # Fix: the original accepted the tracer and discarded it (`pass`);
    # keep the dependency so it is available to future estimator logic
    # and so construction reflects the declared interface.
    self.index_tracer = index_tracer
def
_get_meta_node_size
(
self
,
x
):
x
=
x
.
meta
[
"tensor_meta"
]
x
=
x
.
numel
*
torch
.
tensor
([],
dtype
=
x
.
dtype
).
element_size
()
return
x
def _get_output_node(self, n):
    """Return ``(bytes, names)`` for node ``n``'s forward outputs.

    Only tensors carrying a ``uuid`` contribute; ``names`` is ``[n.name]``
    when the output occupies memory, otherwise empty.
    """
    fwd_out = {}
    for candidate in n.meta["fwd_out"]:
        if isinstance(candidate, torch.Tensor) and hasattr(candidate, "uuid"):
            fwd_out[candidate.uuid] = candidate
    out_size = activation_size(fwd_out)
    out_node = [n.name] if out_size > 0 else []
    return out_size, out_node
def _get_output_node_size(self, n):
    """Convenience wrapper: only the byte size from ``_get_output_node``."""
    size, _ = self._get_output_node(n)
    return size
def _add_active_node(self, n, active_list):
    """Append ``n``'s live output names (and its own name when it is a
    placeholder) to ``active_list`` in place, without duplicates."""
    _, new_active = self._get_output_node(n)
    if n.op == "placeholder":
        new_active.append(n.name)
    for name in new_active:
        if name not in active_list:
            active_list.append(name)
def _get_delete_node(self, user, user_to_last_uses, to_keep=None):
    """Compute what can be freed once ``user`` has executed.

    Args:
        user: the node whose execution triggers the frees.
        user_to_last_uses: map node -> nodes whose last use it is.
        to_keep: optional names that must not be freed here.

    Returns:
        ``(delete_size, delete_node)``: total bytes freed and the list of
        freed names (size-carrying outputs, plus placeholder names).
    """
    delete_size = 0
    delete_node = []
    if user.op not in ("output",):
        nodes_to_delete = user_to_last_uses.get(user, [])
        # filter out protected names before sizing
        if to_keep is not None:
            keep_list = []
            for n in nodes_to_delete:
                if n.name in to_keep:
                    keep_list.append(n)
            for n in keep_list:
                if n in nodes_to_delete:
                    nodes_to_delete.remove(n)
        if len(nodes_to_delete):
            out_node = [self._get_output_node(i) for i in nodes_to_delete]
            delete_size = sum([i[0] for i in out_node])
            # out_node[i] is (size, names) parallel to nodes_to_delete[i]
            for i in range(len(out_node)):
                if out_node[i][0] > 0:
                    delete_node.append(out_node[i][1][0])
                elif nodes_to_delete[i].op == "placeholder":
                    delete_node.append(nodes_to_delete[i].name)
                # elif any(j in nodes_to_delete[i].name for j in ['transpose', 'permute', 'view']):
                #     delete_node.append(nodes_to_delete[i].name)
    return delete_size, delete_node
def _get_delete_node_size(self, user, user_to_last_uses, to_keep):
    """Convenience wrapper: only the byte size from ``_get_delete_node``."""
    size, _ = self._get_delete_node(user, user_to_last_uses, to_keep)
    return size
def _remove_deactive_node(self, user, user_to_last_uses, active_list):
    """Remove from ``active_list`` (in place) the names that die once
    ``user`` has executed."""
    _, dead_names = self._get_delete_node(user, user_to_last_uses)
    for name in dead_names:
        if name in active_list:
            active_list.remove(name)
def _get_chunk_inputs_size(self, chunk_inputs, chunk_inputs_non_chunk, node_list, chunk_end_idx):
    """Total bytes of chunk inputs that can be freed at the chunk's end,
    i.e. inputs whose every user sits at or before ``chunk_end_idx``."""
    nodes_to_delete = []
    for chunk_input in chunk_inputs + chunk_inputs_non_chunk:
        chunk_input_users = chunk_input.users.keys()
        chunk_input_users_idx = [find_idx_by_name(i.name, node_list) for i in chunk_input_users]
        # freeable only when no user appears after the chunk region
        if all(i <= chunk_end_idx for i in chunk_input_users_idx):
            if chunk_input not in nodes_to_delete:
                nodes_to_delete.append(chunk_input)
    out_node = [self._get_output_node(i) for i in nodes_to_delete]
    delete_size = sum([i[0] for i in out_node])
    return delete_size
def
_get_last_usr
(
self
,
nodes
):
node_to_last_use
:
Dict
[
Node
,
Node
]
=
{}
user_to_last_uses
:
Dict
[
Node
,
List
[
Node
]]
=
{}
def
register_last_uses
(
n
:
Node
,
user
:
Node
):
if
n
not
in
node_to_last_use
:
node_to_last_use
[
n
]
=
user
user_to_last_uses
.
setdefault
(
user
,
[]).
append
(
n
)
for
node
in
reversed
(
nodes
):
map_arg
(
node
.
args
,
lambda
n
:
register_last_uses
(
n
,
node
))
map_arg
(
node
.
kwargs
,
lambda
n
:
register_last_uses
(
n
,
node
))
return
user_to_last_uses
def
_get_contiguous_memory
(
self
,
node
,
not_contiguous_list
,
delete
=
False
):
mem
=
0
not_contiguous_ops
=
[
"permute"
]
inherit_contiguous_ops
=
[
"transpose"
,
"view"
]
if
node
.
op
==
"call_function"
and
any
(
n
in
node
.
name
for
n
in
[
"matmul"
,
"reshape"
]
):
for
n
in
node
.
args
:
if
n
in
not_contiguous_list
:
# matmul won't change origin tensor, but create a tmp copy
mem
+=
self
.
_get_output_node_size
(
n
)
elif
node
.
op
==
"call_module"
:
for
n
in
node
.
args
:
if
n
in
not_contiguous_list
:
# module will just make origin tensor to contiguous
if
delete
:
not_contiguous_list
.
remove
(
n
)
elif
node
.
op
==
"call_method"
and
any
(
i
in
node
.
name
for
i
in
not_contiguous_ops
):
if
node
not
in
not_contiguous_list
:
not_contiguous_list
.
append
(
node
)
return
mem
def _get_chunk_ratio(self, node, chunk_node_dim, chunk_size):
    """Fraction of ``node`` materialized per chunk step; 1.0 when the node
    is not chunked or has no chunk dimension."""
    if node not in chunk_node_dim:
        return 1.0
    chunk_dim = chunk_node_dim[node]["chunk_dim"]
    if chunk_dim is None:
        return 1.0
    node_shape = get_node_shape(node)
    return float(chunk_size) / node_shape[chunk_dim]
def _get_chunk_delete_node_size(self, user, user_to_last_uses, chunk_ratio, chunk_inputs_names):
    """Bytes freed after ``user`` runs inside a chunk, scaled by the chunk
    ratio; chunk inputs are excluded since they live across iterations."""
    if user.op in ("placeholder", "output"):
        return 0
    dead_nodes = user_to_last_uses.get(user, [])
    return sum(
        self._get_output_node_size(n) * chunk_ratio
        for n in dead_nodes
        if n.name not in chunk_inputs_names
    )
def _print_mem_log(self, log, nodes, title=None):
    """Pretty-print per-node memory values, three entries per row."""
    if title:
        print(title)
    printed = 0
    for value, node in zip(log, nodes):
        print("%s:%.2f\t" % (node.name, value), end="")
        printed += 1
        if printed % 3 == 0:
            print("")
    print("\n")
def _print_compute_op_mem_log(self, log, nodes, title=None):
    """Like ``_print_mem_log`` but skips non-compute entries (placeholders,
    attributes, outputs, getitem/getattr); row breaks still follow the
    overall index, matching the original layout."""
    if title:
        print(title)
    for position, (value, node) in enumerate(zip(log, nodes)):
        if node.op in ["placeholder", "get_attr", "output"]:
            continue
        if any(marker in node.name for marker in ["getitem", "getattr"]):
            continue
        print("%s:%.2f\t" % (node.name, value), end="")
        if (position + 1) % 3 == 0:
            print("")
    print("\n")
def estimate_chunk_inference_mem(
    self,
    node_list,
    chunk_infos=None,
    print_mem=False,
):
    """Simulate inference activation memory over *node_list*, node by node.

    When *chunk_infos* is given, nodes inside each chunk region are scaled
    by a per-node chunk ratio and chunk inputs are kept alive until the
    region ends; otherwise plain (unchunked) execution is simulated.
    All sizes from the ``_get_*_size`` helpers are divided by ``1024**2``,
    so the logs are presumably MB assuming the helpers return bytes —
    TODO confirm.

    Args:
        node_list: ordered fx nodes of the traced graph.
        chunk_infos: optional list of dicts with keys "region", "inputs",
            "inputs_non_chunk", "outputs", "node_chunk_dim" and optionally
            "chunk_size" (defaults to 1).
        print_mem: if True, print the per-compute-op peak-memory log.

    Returns:
        (act_memory_peak_log, act_memory_after_node_log,
         active_node_list_log) — one entry per simulated node.
    """
    act_memory = 0.0
    act_memory_peak_log = []          # peak memory while executing each node
    act_memory_after_node_log = []    # memory after each node's dead values are freed
    active_node_list = []
    active_node_list_log = []
    not_contiguous_list = []
    user_to_last_uses = self._get_last_usr(node_list)
    # second mapping with free variables (placeholders) stripped:
    # placeholders must never be counted as freed by the simulation
    user_to_last_uses_no_free_var = self._get_last_usr(node_list)
    delete_free_var_from_last_use(user_to_last_uses_no_free_var)

    use_chunk = True if chunk_infos is not None else False
    chunk_within = False              # True while iterating inside a chunk region
    chunk_region_idx = None
    chunk_ratio = 1  # use it to estimate chunk mem
    chunk_inputs_names = []

    if use_chunk:
        # unpack per-region metadata from chunk_infos
        chunk_regions = [i["region"] for i in chunk_infos]
        chunk_starts = [i[0] for i in chunk_regions]
        chunk_ends = [i[1] for i in chunk_regions]
        chunk_inputs = [i["inputs"] for i in chunk_infos]
        chunk_inputs_non_chunk = [i["inputs_non_chunk"] for i in chunk_infos]
        chunk_inputs_names = [j.name for i in chunk_inputs for j in i] + [
            j.name for i in chunk_inputs_non_chunk for j in i
        ]
        chunk_outputs = [i["outputs"][0] for i in chunk_infos]
        chunk_node_dim = [i["node_chunk_dim"] for i in chunk_infos]
        chunk_sizes = [
            i["chunk_size"] if "chunk_size" in i else 1 for i in chunk_infos
        ]

    for idx, node in enumerate(node_list):
        # if node in chunk start nodes, change chunk ratio and add chunk_tensor
        if use_chunk and idx in chunk_starts:
            chunk_within = True
            chunk_region_idx = chunk_starts.index(idx)
            # the region's output tensor is allocated up front
            act_memory += self._get_output_node_size(
                chunk_outputs[chunk_region_idx]
            ) / (1024**2)

        # determine chunk ratio for current node
        if chunk_within:
            chunk_ratio = self._get_chunk_ratio(
                node,
                chunk_node_dim[chunk_region_idx],
                chunk_sizes[chunk_region_idx],
            )

        # if node is placeholder, just add the size of the node
        if node.op == "placeholder":
            act_memory += self._get_meta_node_size(node) * chunk_ratio / (1024**2)
            act_memory_peak_log.append(act_memory)
        # skip output
        elif node.op == "output":
            continue
        # no change for non compute node
        elif is_non_compute_node_except_placeholder(node):
            act_memory_peak_log.append(act_memory)
        # node is a compute op
        # calculate tmp, output node and delete node memory
        else:
            # forward memory
            # TODO: contiguous_memory still not accurate for matmul, view, reshape and transpose
            act_memory += (
                self._get_contiguous_memory(node, not_contiguous_list)
                * chunk_ratio
                / (1024**2)
            )
            act_memory += (
                self._get_output_node_size(node) * chunk_ratio / (1024**2)
            )
            # record max act memory
            act_memory_peak_log.append(act_memory)
            # delete useless memory
            act_memory -= (
                self._get_contiguous_memory(node, not_contiguous_list, delete=True)
                * chunk_ratio
                / (1024**2)
            )
            # delete unused vars not in chunk_input_list
            # we can't delete input nodes until chunk ends
            if chunk_within:
                act_memory -= self._get_chunk_delete_node_size(
                    node,
                    user_to_last_uses_no_free_var,
                    chunk_ratio,
                    chunk_inputs_names,
                ) / (1024**2)
            else:
                act_memory -= self._get_delete_node_size(
                    node, user_to_last_uses_no_free_var, chunk_inputs_names
                ) / (1024**2)

        # log active node, only effective without chunk
        self._add_active_node(node, active_node_list)
        self._remove_deactive_node(node, user_to_last_uses, active_node_list)

        # if node in chunk end nodes, restore chunk settings
        if use_chunk and idx in chunk_ends:
            act_memory -= (
                self._get_output_node_size(node) * chunk_ratio / (1024**2)
            )
            # chunk inputs can finally be freed once the region closes
            act_memory -= self._get_chunk_inputs_size(
                chunk_inputs[chunk_region_idx],
                chunk_inputs_non_chunk[chunk_region_idx],
                node_list,
                chunk_regions[chunk_region_idx][1],
            ) / (1024**2)
            chunk_within = False
            chunk_ratio = 1
            chunk_region_idx = None

        act_memory_after_node_log.append(act_memory)
        active_node_list_log.append(copy.deepcopy(active_node_list))

    if print_mem:
        print("with chunk" if use_chunk else "without chunk")
        # self._print_mem_log(act_memory_peak_log, node_list, "peak")
        # self._print_mem_log(act_memory_after_node_log, node_list, "after")
        self._print_compute_op_mem_log(act_memory_peak_log, node_list, "peak")
        # self._print_compute_op_mem_log(
        #     act_memory_after_node_log, node_list, "after"
        # )

    # param_memory = parameter_size(gm)
    # all_memory = act_memory + param_memory
    return act_memory_peak_log, act_memory_after_node_log, active_node_list_log
colossalai/autochunk/utils.py
0 → 100644
View file @
1a6d2a74
from
typing
import
Any
,
Callable
,
Dict
,
Iterable
,
List
,
Tuple
from
torch.fx.node
import
Node
def is_non_compute_node(node):
    """Return True if *node* performs no real computation.

    Matches placeholder/get_attr/output ops and getitem/getattr accessor
    nodes, by substring on the node's ``op`` and ``name``.
    """
    skip_ops = ("placeholder", "get_attr", "output")
    skip_names = ("getitem", "getattr")
    op_hit = any(op in node.op for op in skip_ops)
    name_hit = any(tag in node.name for tag in skip_names)
    return op_hit or name_hit
def get_node_shape(node):
    """Return the shape stored in the node's tensor metadata, or None.

    Nodes whose ``tensor_meta`` object carries no ``shape`` attribute
    (non-tensor results) yield None.
    """
    tensor_meta = node.meta["tensor_meta"]
    return tensor_meta.shape if hasattr(tensor_meta, "shape") else None
def is_non_compute_node_except_placeholder(node):
    """Like ``is_non_compute_node``, but placeholders count as compute nodes."""
    if any(tag in node.op for tag in ("get_attr", "output")):
        return True
    return any(tag in node.name for tag in ("getitem", "getattr"))
def is_non_compute_node_except_placeholder_output(node):
    """Like ``is_non_compute_node``, but placeholder and output nodes count as compute."""
    non_compute = "get_attr" in node.op or any(
        tag in node.name for tag in ("getitem", "getattr")
    )
    return non_compute
def find_idx_by_name(name, nodes_list):
    """Return the index of the first node in *nodes_list* named *name*.

    Raises:
        RuntimeError: if no node with that name exists.
    """
    match = next(
        (idx for idx, node in enumerate(nodes_list) if node.name == name),
        None,
    )
    if match is None:
        raise RuntimeError("name %s not found in node list" % name)
    return match
def delete_free_var_from_last_use(user_to_last_uses):
    """Strip placeholder (free-variable) nodes from every last-use list, in place.

    Placeholders are graph inputs and must never be treated as freeable,
    so they are removed from each user's last-use list.

    Bug fix: the previous implementation called ``value.remove(n)`` while
    iterating ``value`` itself, which skips the element following each
    removal — e.g. with two consecutive placeholders the second one was
    left behind. Iterating over a snapshot (``list(value)``) removes all
    placeholders while still mutating the original list objects, so any
    callers holding references to those lists see the update.
    """
    for value in user_to_last_uses.values():
        for n in list(value):  # snapshot: safe to mutate `value` below
            if n.op == "placeholder":
                value.remove(n)
def find_chunk_all_input_nodes(nodes: List[Node]):
    """Collect every node outside *nodes* that feeds some node inside it.

    Order follows first appearance; duplicates are skipped. No filtering
    is applied — non-compute producers are included too.
    """
    external_inputs = []
    for member in nodes:
        for producer in member._input_nodes.keys():
            is_external = producer not in nodes
            is_new = producer not in external_inputs
            if is_external and is_new:
                external_inputs.append(producer)
    return external_inputs
def find_chunk_compute_input_and_output_nodes(nodes: List[Node]):
    """Find the compute inputs and outputs of the region *nodes*.

    Returns ``(input_nodes, output_nodes)``:
    * input_nodes: compute nodes outside the region that feed a node
      inside it (non-compute producers such as get_attr or getitem
      accessors are ignored).
    * output_nodes: region nodes whose result is consumed by a compute
      node outside the region.
    """
    input_nodes = []
    output_nodes = []

    # a producer outside the region feeding a region node becomes an
    # input of the chunk/checkpoint function
    for member in nodes:
        for producer in member._input_nodes.keys():
            if producer in nodes or producer in input_nodes:
                continue
            if is_non_compute_node_except_placeholder(producer):
                continue
            input_nodes.append(producer)

    # a consumer outside the region marks the region node as an output
    for member in nodes:
        for consumer in member.users.keys():
            if consumer in nodes or member in output_nodes:
                continue
            if is_non_compute_node_except_placeholder_output(consumer):
                continue
            output_nodes.append(member)

    return input_nodes, output_nodes
tests/test_autochunk/benchmark_autochunk.py
View file @
1a6d2a74
...
@@ -3,7 +3,7 @@ import time
...
@@ -3,7 +3,7 @@ import time
import
torch
import
torch
import
torch.fx
import
torch.fx
from
colossalai.autochunk.chunk_codegen
import
ChunkCodeGen
from
colossalai.autochunk.
auto
chunk_codegen
import
Auto
ChunkCodeGen
from
colossalai.fx
import
ColoTracer
from
colossalai.fx
import
ColoTracer
from
colossalai.fx.graph_module
import
ColoGraphModule
from
colossalai.fx.graph_module
import
ColoGraphModule
from
colossalai.fx.passes.meta_info_prop
import
MetaInfoProp
from
colossalai.fx.passes.meta_info_prop
import
MetaInfoProp
...
@@ -49,25 +49,29 @@ def _build_autochunk(model, max_memory, node, pair):
...
@@ -49,25 +49,29 @@ def _build_autochunk(model, max_memory, node, pair):
"pair"
:
pair
.
to
(
torch
.
device
(
"meta"
)),
"pair"
:
pair
.
to
(
torch
.
device
(
"meta"
)),
},
},
)
)
gm_prop
=
torch
.
fx
.
symbolic_trace
(
model
)
# must use symbolic_trace
gm_prop
=
torch
.
fx
.
symbolic_trace
(
model
)
# must use symbolic_trace
interp
=
MetaInfoProp
(
gm_prop
)
interp
=
MetaInfoProp
(
gm_prop
)
interp
.
propagate
(
interp
.
propagate
(
MetaTensor
(
node
,
fake_device
=
"cuda:0"
),
MetaTensor
(
pair
,
fake_device
=
"cuda:0"
)
MetaTensor
(
node
,
fake_device
=
"cuda:0"
),
MetaTensor
(
pair
,
fake_device
=
"cuda:0"
)
)
)
# now run it twice to get meta info in graph module, not necessary
# now run it twice to get meta info in graph module, not necessary
gm
=
torch
.
fx
.
GraphModule
(
model
,
graph
)
gm
=
torch
.
fx
.
GraphModule
(
model
,
graph
)
interp
=
MetaInfoProp
(
gm
)
interp
=
MetaInfoProp
(
gm
)
interp
.
propagate
(
interp
.
propagate
(
MetaTensor
(
node
,
fake_device
=
"cuda:0"
),
MetaTensor
(
pair
,
fake_device
=
"cuda:0"
)
MetaTensor
(
node
,
fake_device
=
"cuda:0"
),
MetaTensor
(
pair
,
fake_device
=
"cuda:0"
)
)
)
# set code_gen
# set code_gen
codegen
=
ChunkCodeGen
(
gm_prop
,
max_memory
)
codegen
=
Auto
ChunkCodeGen
(
gm_prop
,
max_memory
)
graph
.
set_codegen
(
codegen
)
graph
.
set_codegen
(
codegen
)
gm
=
ColoGraphModule
(
model
,
graph
)
gm
=
ColoGraphModule
(
model
,
graph
)
gm
.
recompile
()
gm
.
recompile
()
# print
# print
code
=
graph
.
python_code
(
"self"
).
src
#
code = graph.python_code("self").src
print
(
code
)
#
print(code)
return
gm
return
gm
...
...
tests/test_autochunk/test_autochunk.py
View file @
1a6d2a74
...
@@ -4,7 +4,7 @@ import torch.fx
...
@@ -4,7 +4,7 @@ import torch.fx
import
torch.multiprocessing
as
mp
import
torch.multiprocessing
as
mp
import
colossalai
import
colossalai
from
colossalai.autochunk.chunk_codegen
import
ChunkCodeGen
from
colossalai.autochunk.
auto
chunk_codegen
import
Auto
ChunkCodeGen
from
colossalai.core
import
global_context
as
gpc
from
colossalai.core
import
global_context
as
gpc
from
colossalai.fx
import
ColoTracer
from
colossalai.fx
import
ColoTracer
from
colossalai.fx.graph_module
import
ColoGraphModule
from
colossalai.fx.graph_module
import
ColoGraphModule
...
@@ -82,7 +82,7 @@ def _run_offload_codegen(rank):
...
@@ -82,7 +82,7 @@ def _run_offload_codegen(rank):
MetaTensor
(
node
,
fake_device
=
"cuda:0"
),
MetaTensor
(
pair
,
fake_device
=
"cuda:0"
)
MetaTensor
(
node
,
fake_device
=
"cuda:0"
),
MetaTensor
(
pair
,
fake_device
=
"cuda:0"
)
)
)
codegen
=
ChunkCodeGen
(
gm_prop
)
codegen
=
Auto
ChunkCodeGen
(
gm_prop
)
graph
.
set_codegen
(
codegen
)
graph
.
set_codegen
(
codegen
)
gm
=
ColoGraphModule
(
model
,
graph
)
gm
=
ColoGraphModule
(
model
,
graph
)
gm
.
recompile
()
gm
.
recompile
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment