Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
1a6d2a74
Commit
1a6d2a74
authored
Jan 06, 2023
by
oahzxl
Browse files
take apart chunk code gen
parent
d1f07731
Changes
8
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
2408 additions
and
6 deletions
+2408
-6
colossalai/autochunk/autochunk_codegen.py
colossalai/autochunk/autochunk_codegen.py
+497
-0
colossalai/autochunk/chunk_region_search.py
colossalai/autochunk/chunk_region_search.py
+211
-0
colossalai/autochunk/chunk_selector.py
colossalai/autochunk/chunk_selector.py
+221
-0
colossalai/autochunk/index_tracer.py
colossalai/autochunk/index_tracer.py
+1056
-0
colossalai/autochunk/memory_estiamtor.py
colossalai/autochunk/memory_estiamtor.py
+318
-0
colossalai/autochunk/utils.py
colossalai/autochunk/utils.py
+95
-0
tests/test_autochunk/benchmark_autochunk.py
tests/test_autochunk/benchmark_autochunk.py
+8
-4
tests/test_autochunk/test_autochunk.py
tests/test_autochunk/test_autochunk.py
+2
-2
No files found.
colossalai/autochunk/autochunk_codegen.py
0 → 100644
View file @
1a6d2a74
This diff is collapsed.
Click to expand it.
colossalai/autochunk/chunk_region_search.py
0 → 100644
View file @
1a6d2a74
from
.index_tracer
import
IndexTracer
from
.memory_estiamtor
import
MemoryEstimator
from
.chunk_selector
import
ChunkSelector
import
copy
from
.utils
import
is_non_compute_node
,
is_non_compute_node_except_placeholder
,
get_node_shape
class ChunkRegionSearch(object):
    """Search an FX graph for regions that can be executed chunk-by-chunk.

    Drives the whole autochunk search: traces tensor-index flow through the
    graph, estimates activation memory, and repeatedly picks the chunk region
    that best reduces the peak until no further improvement is found.
    """

    def __init__(self, gm, max_memory=None) -> None:
        # gm: a traced fx GraphModule; max_memory: optional budget in MB,
        # forwarded to ChunkSelector (presumably — TODO confirm unit with caller).
        self.gm = gm
        self.index_tracer = IndexTracer(list(gm.graph.nodes))
        self.index_tracer.trace_index()
        self.memory_estimator = MemoryEstimator(self.index_tracer)
        self.chunk_selector = ChunkSelector(
            self.index_tracer, self.memory_estimator, max_memory=max_memory
        )

    def _find_peak_node(self, mem_peak):
        """Return the index of the node with the highest peak memory."""
        max_value = max(mem_peak)
        max_idx = mem_peak.index(max_value)
        return max_idx

    def _get_free_var(self):
        """Return indices of all placeholder (free-variable) nodes."""
        free_var_idx = []
        for idx, n in enumerate(self.index_tracer.node_list):
            if n.op == "placeholder":
                free_var_idx.append(idx)
        return free_var_idx

    def _get_min_free_var(self, active_node_list, free_vars):
        """Return the smallest active-node count, skipping free-var positions.

        NOTE(review): 999 acts as a sentinel upper bound — assumes active
        lists are always shorter than 999; confirm for very large graphs.
        """
        min_len = 999
        for idx, n in enumerate(active_node_list):
            if idx in free_vars:
                continue
            if len(n) < min_len:
                min_len = len(n)
        return min_len

    def _search_max_chunk_region(self, active_node, peak_node, chunk_regions):
        """Expand outward from peak_node to the widest legal chunk region.

        The region boundary is where the number of simultaneously active
        nodes rises back above a threshold. Returns (start, end) indices,
        or None when the candidate region is already covered by an
        existing region in chunk_regions.
        """
        free_vars = self._get_free_var()
        free_var_num = len(free_vars)
        active_node_num = [len(i) for i in active_node]
        min_active_node_num = min(active_node_num[free_var_num:])
        threshold = max(free_var_num, min_active_node_num)

        # from peak_node to free_var: walk backwards until activity rises
        # above threshold after first dipping below it
        inside_flag = False
        chunk_region_start = free_var_num
        for i in range(peak_node, -1, -1):
            if active_node_num[i] <= threshold:
                inside_flag = True
            if inside_flag and active_node_num[i] > threshold:
                chunk_region_start = i + 1
                break

        # from peak_node to len-2: same scan in the forward direction
        inside_flag = False
        chunk_region_end = len(active_node) - 1
        for i in range(peak_node, len(active_node)):
            if active_node_num[i] <= threshold:
                inside_flag = True
            if inside_flag and active_node_num[i] > threshold:
                chunk_region_end = i
                break

        # clip against regions already chosen so chunk regions never overlap
        for i in chunk_regions:
            region = i["region"]
            if chunk_region_start >= region[0] and chunk_region_end <= region[1]:
                # fully inside an existing region: nothing new to chunk
                return None
            elif (
                region[0] <= chunk_region_start <= region[1]
                and chunk_region_end > region[1]
            ):
                chunk_region_start = region[1] + 1
            elif (
                region[0] <= chunk_region_end <= region[1]
                and chunk_region_start < region[0]
            ):
                chunk_region_end = region[0] - 1
        return chunk_region_start, chunk_region_end

    def _is_not_compute(self, trace, chunk_range, dim_idx):
        """Return True if dim_idx of this trace is not computed inside chunk_range.

        NOTE(review): appears unused within this class — verify callers
        before removing.
        """
        if trace["idx"][dim_idx] not in trace["compute"]:
            return True
        if trace["idx"][dim_idx] in trace["compute"] and all(
            i < chunk_range[0] or i > chunk_range[1]
            for i in trace["compute"][trace["idx"][dim_idx]]
        ):
            return True
        return False

    def _find_free_dim(self, input_trace, output_trace, start_idx, end_idx):
        """Enumerate (start_dim, end_dim) pairs that admit chunking.

        For every dimension of the region's output and every dimension of
        each input, run the index-tracer checks and flow search; collect
        every chunk_info that survives all filters.
        """
        start_traces = input_trace[start_idx]
        end_trace = output_trace[end_idx]
        end_node = self.index_tracer.node_list[end_idx]
        chunk_infos = []
        for end_dim, _ in enumerate(end_trace["idx"]):
            # only single-input start nodes are supported here
            if len(start_traces) > 1:
                continue
            for start_node, start_trace in start_traces.items():
                for start_dim, _ in enumerate(start_trace["idx"]):
                    # dim size cannot be 1
                    if (
                        get_node_shape(end_node)[end_dim] == 1
                        or get_node_shape(start_node)[start_dim] == 1
                    ):
                        continue
                    # check index source align
                    if not self.index_tracer.check_index_source(
                        start_dim, start_node, start_idx, end_dim, end_node
                    ):
                        continue
                    # check index compute
                    if not self.index_tracer.check_index_compute(
                        start_idx, end_dim, end_node, end_idx
                    ):
                        continue
                    # flow search
                    chunk_info = self.index_tracer.flow_search(
                        start_idx, start_dim, end_idx, end_dim
                    )
                    if chunk_info is None:
                        continue
                    # check index compute
                    if not self.index_tracer.check_index_duplicate(chunk_info):
                        continue
                    chunk_infos.append(chunk_info)
        return chunk_infos

    def _search_possible_chunk_regions(self, max_chunk_region, peak_node):
        """Collect all legal chunk regions that span peak_node.

        Tries every (start, end) pair with start <= peak_node <= end inside
        max_chunk_region, skipping non-compute endpoints.
        """
        possible_chunk_region = []
        output_trace = copy.deepcopy(self.index_tracer.idx_trace_list)
        input_trace = []  # trace of a node's input nodes
        for _, n in enumerate(self.index_tracer.node_list):
            cur_trace = {}
            for arg in n.args:
                # only real Node args (same type as n) with actual compute
                if type(arg) == type(n) and not is_non_compute_node_except_placeholder(
                    arg
                ):
                    cur_trace[arg] = self.index_tracer._find_trace_from_node(arg)
            input_trace.append(cur_trace)
        for start_idx in range(max_chunk_region[0], peak_node + 1):
            for end_idx in range(peak_node, max_chunk_region[1] + 1):
                # skip non compute nodes
                if is_non_compute_node(
                    self.index_tracer.node_list[start_idx]
                ) or is_non_compute_node(self.index_tracer.node_list[end_idx]):
                    continue
                # select free dim
                chunk_info = self._find_free_dim(
                    input_trace, output_trace, start_idx, end_idx
                )
                if len(chunk_info) > 0:
                    possible_chunk_region.extend(chunk_info)
        return possible_chunk_region

    def _step_search(self, mem_peak, active_node, chunk_regions):
        """One search iteration: find the peak, enumerate candidate regions
        around it, and return the selector's best (reordered) region, or
        None when no further chunking is possible."""
        peak_node = self._find_peak_node(mem_peak)
        max_chunk_region = self._search_max_chunk_region(
            active_node, peak_node, chunk_regions
        )
        if max_chunk_region == None:
            return None
        possible_chunk_regions = self._search_possible_chunk_regions(
            max_chunk_region, peak_node
        )
        best_chunk_region = self.chunk_selector._select_best_chunk_region(
            possible_chunk_regions, chunk_regions, peak_node, max_chunk_region, mem_peak
        )
        best_chunk_region = self.index_tracer.reorder_all(best_chunk_region)
        return best_chunk_region

    def _stop_search(self, init_mem_peak, mem_peak):
        """Stop once the current peak drops below the initial median peak."""
        sorted_init_mem_peak = sorted(init_mem_peak)
        if max(mem_peak) < sorted_init_mem_peak[int(len(sorted_init_mem_peak) * 0.5)]:
            return True
        return False

    def search_region(self):
        """Run the full iterative search and return the list of chunk infos.

        Re-estimates memory after each accepted region; terminates when
        either no region is found or the stop criterion is met. The final
        estimate is run once more with print_mem=True for logging.
        """
        chunk_infos = []
        (
            init_mem_peak,
            _,
            active_node,
        ) = self.memory_estimator.estimate_chunk_inference_mem(
            self.index_tracer.node_list
        )
        mem_peak = init_mem_peak
        while True:
            chunk_info = self._step_search(mem_peak, active_node, chunk_infos)
            if chunk_info is None:
                break
            chunk_infos.append(chunk_info)
            (
                mem_peak,
                _,
                active_node,
            ) = self.memory_estimator.estimate_chunk_inference_mem(
                self.index_tracer.node_list, chunk_infos
            )
            if self._stop_search(init_mem_peak, mem_peak):
                break
        self.memory_estimator.estimate_chunk_inference_mem(
            self.index_tracer.node_list, chunk_infos, print_mem=True
        )
        return chunk_infos
colossalai/autochunk/chunk_selector.py
0 → 100644
View file @
1a6d2a74
from
.index_tracer
import
IndexTracer
from
.memory_estiamtor
import
MemoryEstimator
from
.utils
import
is_non_compute_node
class ChunkSelector(object):
    """Choose the best chunk region from a set of candidates.

    Two strategies: "fit_memory" (a max_memory budget in MB was given —
    pick the shortest region that fits, then binary-search the largest
    chunk size that still fits) and "min_memory" (no budget — pick the
    region with the lowest resulting peak, chunk size 1).
    """

    def __init__(
        self,
        index_tracer: IndexTracer,
        memory_estimator: MemoryEstimator,
        max_memory=None,
    ):
        self.index_tracer = index_tracer
        self.memory_estimator = memory_estimator
        if max_memory is not None:
            # NOTE(review): "stratge" is a typo for "strategy"; kept as-is
            # because other code may read this attribute.
            self.stratge = "fit_memory"
            self.max_memory = max_memory  # MB
        else:
            self.stratge = "min_memory"

    def _select_best_chunk_region(
        self, possible_chunk_regions, chunk_infos, peak_node, max_chunk_region, mem_peak
    ):
        """Dispatch to the strategy chosen at construction time."""
        if self.stratge == "min_memory":
            best_region = self._select_min_memory_chunk_region(
                possible_chunk_regions,
                chunk_infos,
                peak_node,
                max_chunk_region,
                mem_peak,
            )
        elif self.stratge == "fit_memory":
            best_region = self._select_fit_memory_chunk_region(
                possible_chunk_regions,
                chunk_infos,
                peak_node,
                max_chunk_region,
                mem_peak,
            )
        else:
            raise RuntimeError()
        return best_region

    def _select_fit_memory_chunk_region(
        self, possible_chunk_regions, chunk_infos, peak_node, max_chunk_region, mem_peak
    ):
        """Pick the shortest candidate region whose peak fits max_memory.

        Returns None when the graph already fits or no legal candidate
        remains; raises RuntimeError when candidates exist but none fits.
        Note: mutates possible_chunk_regions in place (removes illegal ones).
        """
        # stop chunk if max memory satisfy memory limit
        if max(mem_peak) < self.max_memory:
            return None
        # remove illegal regions
        illegal_regions = []
        for i in possible_chunk_regions:
            if not self._is_legal_region(i, chunk_infos):
                illegal_regions.append(i)
        for i in illegal_regions:
            if i in possible_chunk_regions:
                possible_chunk_regions.remove(i)
        if len(possible_chunk_regions) == 0:
            return None
        # get mem for chunk region
        regions_dict = []
        for region in possible_chunk_regions:
            cur_region = region.copy()
            cur_node_list, cur_region = self.index_tracer.tmp_reorder(
                self.index_tracer.node_list, cur_region
            )
            cur_chunk_infos = chunk_infos + [cur_region]
            cur_mem_peak = self.memory_estimator.estimate_chunk_inference_mem(
                cur_node_list, cur_chunk_infos
            )[0]
            cur_chunk_region_peak = cur_mem_peak[
                max_chunk_region[0] : max_chunk_region[1] + 1
            ]
            cur_chunk_region_max_peak = max(cur_chunk_region_peak)
            if cur_chunk_region_max_peak < self.max_memory:
                regions_dict.append(
                    {
                        "chunk_info": region,
                        "chunk_max_mem": cur_chunk_region_max_peak,
                        "chunk_len": self._get_compute_node_num(
                            region["region"][0], region["region"][1]
                        ),
                        "reorder_chunk_info": cur_region,
                        "reorder_node_list": cur_node_list,
                    }
                )
        # no region found
        if len(regions_dict) == 0:
            raise RuntimeError("Search failed. Try a larger memory threshold.")
        # select the min chunk len
        chunk_len = [i["chunk_len"] for i in regions_dict]
        best_region_idx = chunk_len.index(min(chunk_len))
        best_region = regions_dict[best_region_idx]
        # get max chunk size
        best_region = self._get_fit_chunk_size(best_region, chunk_infos)
        return best_region

    def _get_fit_chunk_size(self, chunk_region_dict, chunk_infos):
        """Grow chunk_size by doubling until memory exceeds the budget,
        then binary-search the exact largest size that still fits."""
        chunk_size = 1
        reorder_chunk_info = chunk_region_dict["reorder_chunk_info"]
        reorder_chunk_info["chunk_size"] = chunk_size
        cur_chunk_max_mem = 0
        # search a region
        while cur_chunk_max_mem < self.max_memory:
            chunk_size *= 2
            reorder_chunk_info["chunk_size"] = chunk_size
            cur_chunk_infos = chunk_infos + [reorder_chunk_info]
            cur_mem_peak = self.memory_estimator.estimate_chunk_inference_mem(
                chunk_region_dict["reorder_node_list"], cur_chunk_infos
            )[0]
            cur_chunk_max_mem = max(
                cur_mem_peak[
                    reorder_chunk_info["region"][0] : reorder_chunk_info["region"][1]
                    + 1
                ]
            )
        # search exact size
        chunk_info = chunk_region_dict["chunk_info"]
        chunk_info["chunk_size"] = self._chunk_size_binary_search(
            chunk_size // 2, chunk_size, chunk_region_dict, chunk_infos
        )
        return chunk_info

    def _chunk_size_binary_search(self, l, r, chunk_region_dict, chunk_infos):
        """Binary search in [l, r] for the largest chunk size under budget.

        Uses a coarser step (gap=4) for l >= 16 to cut estimator calls;
        returns the lower bound l of the final interval.
        """
        if l >= 16:
            gap = 4
        else:
            gap = 1
        chunk_info = chunk_region_dict["reorder_chunk_info"]
        while r >= l + gap:
            mid = int((l + r) / 2 + 0.5)
            chunk_info["chunk_size"] = mid
            cur_chunk_infos = chunk_infos + [chunk_info]
            cur_mem_peak = self.memory_estimator.estimate_chunk_inference_mem(
                chunk_region_dict["reorder_node_list"], cur_chunk_infos
            )[0]
            cur_chunk_max_mem = max(
                cur_mem_peak[
                    chunk_info["region"][0] : chunk_info["region"][1] + 1
                ]
            )
            if cur_chunk_max_mem >= self.max_memory:
                r = mid - gap
            else:
                l = mid + gap
        return l

    def _get_compute_node_num(self, start, end):
        """Count compute nodes in node_list[start:end + 1] (inclusive)."""
        count = 0
        for i in self.index_tracer.node_list[start : end + 1]:
            if not is_non_compute_node(i):
                count += 1
        return count

    def _select_min_memory_chunk_region(
        self, possible_chunk_regions, chunk_infos, peak_node, max_chunk_region, mem_peak
    ):
        """Pick the candidate with the lowest resulting peak memory.

        No budget is involved; the chosen region always gets chunk_size 1.
        Note: mutates possible_chunk_regions in place (removes illegal ones).
        """
        # remove illegal regions
        illegal_regions = []
        for i in possible_chunk_regions:
            if not self._is_legal_region(i, chunk_infos):
                illegal_regions.append(i)
        for i in illegal_regions:
            if i in possible_chunk_regions:
                possible_chunk_regions.remove(i)
        if len(possible_chunk_regions) == 0:
            return None
        # get mem for chunk region
        regions_dict = []
        for region in possible_chunk_regions:
            cur_region = region.copy()
            cur_node_list, cur_region = self.index_tracer.tmp_reorder(
                self.index_tracer.node_list, cur_region
            )
            cur_chunk_infos = chunk_infos + [cur_region]
            cur_mem_peak = self.memory_estimator.estimate_chunk_inference_mem(
                cur_node_list, cur_chunk_infos
            )[0]
            cur_chunk_region_peak = cur_mem_peak[
                max_chunk_region[0] : max_chunk_region[1] + 1
            ]
            cur_chunk_region_max_peak = max(cur_chunk_region_peak)
            regions_dict.append(
                {
                    "chunk_info": region,
                    "chunk_max_mem": cur_chunk_region_max_peak,
                    "chunk_len": self._get_compute_node_num(
                        region["region"][0], region["region"][1]
                    ),
                    "reorder_chunk_info": cur_region,
                    "reorder_node_list": cur_node_list,
                }
            )
        # select the min mem
        chunk_max_mem = [i["chunk_max_mem"] for i in regions_dict]
        best_region_idx = chunk_max_mem.index(min(chunk_max_mem))
        best_region = regions_dict[best_region_idx]["chunk_info"]
        if best_region is not None:
            best_region["chunk_size"] = 1
        return best_region

    def _is_legal_region(self, cur_chunk_info, chunk_infos):
        """A candidate is legal iff it is new, non-empty, and entirely
        outside every already-selected region (no overlap allowed)."""
        (chunk_region_start, chunk_region_end) = cur_chunk_info["region"]
        if cur_chunk_info in chunk_infos:
            return False
        if chunk_region_end < chunk_region_start:
            return False
        for i in chunk_infos:
            region = i["region"]
            if not (
                (chunk_region_start > region[1] and chunk_region_end > region[1])
                or (chunk_region_start < region[0] and chunk_region_end < region[0])
            ):
                return False
        return True
colossalai/autochunk/
chunk_codegen
.py
→
colossalai/autochunk/
index_tracer
.py
View file @
1a6d2a74
This diff is collapsed.
Click to expand it.
colossalai/autochunk/memory_estiamtor.py
0 → 100644
View file @
1a6d2a74
import
copy
from
typing
import
Any
,
Callable
,
Dict
,
Iterable
,
List
,
Tuple
import
torch
from
torch.fx.node
import
Node
,
map_arg
from
colossalai.fx.profiler
import
activation_size
,
parameter_size
from
.index_tracer
import
IndexTracer
from
.utils
import
(
delete_free_var_from_last_use
,
find_idx_by_name
,
get_node_shape
,
is_non_compute_node_except_placeholder
,
)
class MemoryEstimator(object):
    """Estimate activation memory of an FX graph, with or without chunking.

    Walks the node list in execution order, simulating allocation of each
    node's output and freeing of tensors at their last use; inside a chunk
    region, sizes are scaled by chunk_size / chunked-dim size.
    """

    def __init__(self, index_tracer: IndexTracer) -> None:
        # NOTE(review): index_tracer is accepted but never stored or used;
        # presumably kept for interface symmetry — confirm before removing.
        pass

    def _get_meta_node_size(self, x):
        """Return the byte size of node x's tensor from its meta info."""
        x = x.meta["tensor_meta"]
        x = x.numel * torch.tensor([], dtype=x.dtype).element_size()
        return x

    def _get_output_node(self, n):
        """Return (output byte size, [node name] if it produces output else [])."""
        fwd_out = {
            x.uuid: x
            for x in n.meta["fwd_out"]
            if isinstance(x, torch.Tensor) and hasattr(x, "uuid")
        }
        out_size = activation_size(fwd_out)
        out_node = [n.name] if out_size > 0 else []
        # if any(i in n.name for i in ['transpose', 'permute', 'view']):
        #     out_size = 0
        return out_size, out_node

    def _get_output_node_size(self, n):
        """Byte size of n's forward output."""
        return self._get_output_node(n)[0]

    def _add_active_node(self, n, active_list):
        """Append n's output name (and placeholder names) to active_list,
        mutating it in place; duplicates are skipped."""
        new_active = self._get_output_node(n)[1]
        if n.op == "placeholder":
            new_active.append(n.name)
        for i in new_active:
            if i not in active_list:
                active_list.append(i)

    def _get_delete_node(self, user, user_to_last_uses, to_keep=None):
        """Return (freed byte size, freed node names) once `user` has run.

        Nodes whose names are in to_keep are retained. NOTE(review): this
        mutates the list stored inside user_to_last_uses when to_keep is
        given — verify callers expect that side effect.
        """
        delete_size = 0
        delete_node = []
        if user.op not in ("output",):
            nodes_to_delete = user_to_last_uses.get(user, [])
            if to_keep is not None:
                keep_list = []
                for n in nodes_to_delete:
                    if n.name in to_keep:
                        keep_list.append(n)
                for n in keep_list:
                    if n in nodes_to_delete:
                        nodes_to_delete.remove(n)
            if len(nodes_to_delete):
                out_node = [self._get_output_node(i) for i in nodes_to_delete]
                delete_size = sum([i[0] for i in out_node])
                for i in range(len(out_node)):
                    if out_node[i][0] > 0:
                        delete_node.append(out_node[i][1][0])
                    elif nodes_to_delete[i].op == "placeholder":
                        delete_node.append(nodes_to_delete[i].name)
                    # elif any(j in nodes_to_delete[i].name for j in ['transpose', 'permute', 'view']):
                    #     delete_node.append(nodes_to_delete[i].name)
        return delete_size, delete_node

    def _get_delete_node_size(self, user, user_to_last_uses, to_keep):
        """Byte size freed after `user` runs (see _get_delete_node)."""
        return self._get_delete_node(user, user_to_last_uses, to_keep)[0]

    def _remove_deactive_node(self, user, user_to_last_uses, active_list):
        """Drop names that die at `user` from active_list, in place."""
        delete_node = self._get_delete_node(user, user_to_last_uses)[1]
        for i in delete_node:
            if i in active_list:
                active_list.remove(i)

    def _get_chunk_inputs_size(
        self, chunk_inputs, chunk_inputs_non_chunk, node_list, chunk_end_idx
    ):
        """Bytes of chunk inputs that can be freed when the chunk ends,
        i.e. inputs whose every user lies at or before chunk_end_idx."""
        nodes_to_delete = []
        for chunk_input in chunk_inputs + chunk_inputs_non_chunk:
            chunk_input_users = chunk_input.users.keys()
            chunk_input_users_idx = [
                find_idx_by_name(i.name, node_list) for i in chunk_input_users
            ]
            if all(i <= chunk_end_idx for i in chunk_input_users_idx):
                if chunk_input not in nodes_to_delete:
                    nodes_to_delete.append(chunk_input)
        out_node = [self._get_output_node(i) for i in nodes_to_delete]
        delete_size = sum([i[0] for i in out_node])
        return delete_size

    def _get_last_usr(self, nodes):
        """Map each node to the list of nodes whose last use it is.

        Mirrors torch.fx's code-gen liveness pass: scan in reverse so the
        first user seen is the last in execution order.
        """
        node_to_last_use: Dict[Node, Node] = {}
        user_to_last_uses: Dict[Node, List[Node]] = {}

        def register_last_uses(n: Node, user: Node):
            if n not in node_to_last_use:
                node_to_last_use[n] = user
                user_to_last_uses.setdefault(user, []).append(n)

        for node in reversed(nodes):
            map_arg(node.args, lambda n: register_last_uses(n, node))
            map_arg(node.kwargs, lambda n: register_last_uses(n, node))
        return user_to_last_uses

    def _get_contiguous_memory(self, node, not_contiguous_list, delete=False):
        """Extra bytes needed to make non-contiguous inputs contiguous.

        Tracks which nodes currently hold non-contiguous outputs in
        not_contiguous_list (mutated in place). With delete=True a
        call_module consuming such a node removes it from the list.
        """
        mem = 0
        not_contiguous_ops = ["permute"]
        inherit_contiguous_ops = ["transpose", "view"]

        if node.op == "call_function" and any(
            n in node.name for n in ["matmul", "reshape"]
        ):
            for n in node.args:
                if n in not_contiguous_list:
                    # matmul won't change origin tensor, but create a tmp copy
                    mem += self._get_output_node_size(n)
        elif node.op == "call_module":
            for n in node.args:
                if n in not_contiguous_list:
                    # module will just make origin tensor to contiguous
                    if delete:
                        not_contiguous_list.remove(n)
        elif node.op == "call_method" and any(
            i in node.name for i in not_contiguous_ops
        ):
            if node not in not_contiguous_list:
                not_contiguous_list.append(node)
        return mem

    def _get_chunk_ratio(self, node, chunk_node_dim, chunk_size):
        """Fraction of the node's memory live per chunk iteration:
        chunk_size / size of the chunked dimension, or 1.0 if unchunked."""
        if node not in chunk_node_dim:
            return 1.0
        node_shape = get_node_shape(node)
        chunk_dim = chunk_node_dim[node]["chunk_dim"]
        if chunk_dim is None:
            return 1.0
        else:
            return float(chunk_size) / node_shape[chunk_dim]

    def _get_chunk_delete_node_size(
        self, user, user_to_last_uses, chunk_ratio, chunk_inputs_names
    ):
        """Bytes freed after `user` runs inside a chunk region.

        Chunk inputs are excluded — they stay alive until the chunk ends —
        and freed sizes are scaled by chunk_ratio.
        """
        # if any(j in user.name for j in ['transpose', 'permute', 'view']):
        #     return 0
        if user.op in ("placeholder", "output"):
            return 0
        nodes_to_delete = user_to_last_uses.get(user, [])
        delete_size = 0
        for n in nodes_to_delete:
            if n.name in chunk_inputs_names:
                continue
            delete_size += self._get_output_node_size(n) * chunk_ratio
        return delete_size

    def _print_mem_log(self, log, nodes, title=None):
        """Debug print: per-node memory values, three per line."""
        if title:
            print(title)
        for idx, (l, n) in enumerate(zip(log, nodes)):
            print("%s:%.2f \t" % (n.name, l), end="")
            if (idx + 1) % 3 == 0:
                print("")
        print("\n")

    def _print_compute_op_mem_log(self, log, nodes, title=None):
        """Debug print like _print_mem_log but skipping non-compute nodes."""
        if title:
            print(title)
        for idx, (l, n) in enumerate(zip(log, nodes)):
            if n.op in ["placeholder", "get_attr", "output"]:
                continue
            if any(i in n.name for i in ["getitem", "getattr"]):
                continue
            print("%s:%.2f \t" % (n.name, l), end="")
            if (idx + 1) % 3 == 0:
                print("")
        print("\n")

    def estimate_chunk_inference_mem(
        self,
        node_list,
        chunk_infos=None,
        print_mem=False,
    ):
        """Simulate inference over node_list and return per-node memory logs.

        Returns (act_memory_peak_log, act_memory_after_node_log,
        active_node_list_log); all memory values are in MB. When chunk_infos
        is given, nodes inside each chunk region are scaled by their chunk
        ratio and chunk inputs are kept alive until the region ends.
        """
        act_memory = 0.0
        act_memory_peak_log = []
        act_memory_after_node_log = []
        active_node_list = []
        active_node_list_log = []
        not_contiguous_list = []
        user_to_last_uses = self._get_last_usr(node_list)
        user_to_last_uses_no_free_var = self._get_last_usr(node_list)
        delete_free_var_from_last_use(user_to_last_uses_no_free_var)

        use_chunk = True if chunk_infos is not None else False
        chunk_within = False
        chunk_region_idx = None
        chunk_ratio = 1  # use it to estimate chunk mem
        chunk_inputs_names = []

        if use_chunk:
            chunk_regions = [i["region"] for i in chunk_infos]
            chunk_starts = [i[0] for i in chunk_regions]
            chunk_ends = [i[1] for i in chunk_regions]
            chunk_inputs = [i["inputs"] for i in chunk_infos]
            chunk_inputs_non_chunk = [i["inputs_non_chunk"] for i in chunk_infos]
            chunk_inputs_names = [j.name for i in chunk_inputs for j in i] + [
                j.name for i in chunk_inputs_non_chunk for j in i
            ]
            chunk_outputs = [i["outputs"][0] for i in chunk_infos]
            chunk_node_dim = [i["node_chunk_dim"] for i in chunk_infos]
            chunk_sizes = [
                i["chunk_size"] if "chunk_size" in i else 1 for i in chunk_infos
            ]

        for idx, node in enumerate(node_list):
            # if node in chunk start nodes, change chunk ratio and add chunk_tensor
            if use_chunk and idx in chunk_starts:
                chunk_within = True
                chunk_region_idx = chunk_starts.index(idx)
                act_memory += self._get_output_node_size(
                    chunk_outputs[chunk_region_idx]
                ) / (1024**2)

            # determine chunk ratio for current node
            if chunk_within:
                chunk_ratio = self._get_chunk_ratio(
                    node,
                    chunk_node_dim[chunk_region_idx],
                    chunk_sizes[chunk_region_idx],
                )

            # if node is placeholder, just add the size of the node
            if node.op == "placeholder":
                act_memory += self._get_meta_node_size(node) * chunk_ratio / (1024**2)
                act_memory_peak_log.append(act_memory)
            # skip output
            elif node.op == "output":
                continue
            # no change for non compute node
            elif is_non_compute_node_except_placeholder(node):
                act_memory_peak_log.append(act_memory)
            # node is a compute op
            # calculate tmp, output node and delete node memory
            else:
                # forward memory
                # TODO: contiguous_memory still not accurate for matmul, view, reshape and transpose
                act_memory += (
                    self._get_contiguous_memory(node, not_contiguous_list)
                    * chunk_ratio
                    / (1024**2)
                )
                act_memory += (
                    self._get_output_node_size(node) * chunk_ratio / (1024**2)
                )
                # record max act memory
                act_memory_peak_log.append(act_memory)
                # delete useless memory
                act_memory -= (
                    self._get_contiguous_memory(node, not_contiguous_list, delete=True)
                    * chunk_ratio
                    / (1024**2)
                )
                # delete unused vars not in chunk_input_list
                # we can't delete input nodes until chunk ends
                if chunk_within:
                    act_memory -= self._get_chunk_delete_node_size(
                        node,
                        user_to_last_uses_no_free_var,
                        chunk_ratio,
                        chunk_inputs_names,
                    ) / (1024**2)
                else:
                    act_memory -= self._get_delete_node_size(
                        node, user_to_last_uses_no_free_var, chunk_inputs_names
                    ) / (1024**2)

            # log active node, only effective without chunk
            self._add_active_node(node, active_node_list)
            self._remove_deactive_node(node, user_to_last_uses, active_node_list)

            # if node in chunk end nodes, restore chunk settings
            if use_chunk and idx in chunk_ends:
                act_memory -= (
                    self._get_output_node_size(node) * chunk_ratio / (1024**2)
                )
                act_memory -= self._get_chunk_inputs_size(
                    chunk_inputs[chunk_region_idx],
                    chunk_inputs_non_chunk[chunk_region_idx],
                    node_list,
                    chunk_regions[chunk_region_idx][1],
                ) / (1024**2)
                chunk_within = False
                chunk_ratio = 1
                chunk_region_idx = None

            act_memory_after_node_log.append(act_memory)
            active_node_list_log.append(copy.deepcopy(active_node_list))

        if print_mem:
            print("with chunk" if use_chunk else "without chunk")
            # self._print_mem_log(act_memory_peak_log, node_list, "peak")
            # self._print_mem_log(act_memory_after_node_log, node_list, "after")
            self._print_compute_op_mem_log(act_memory_peak_log, node_list, "peak")
            # self._print_compute_op_mem_log(
            #     act_memory_after_node_log, node_list, "after"
            # )

        # param_memory = parameter_size(gm)
        # all_memory = act_memory + param_memory
        return act_memory_peak_log, act_memory_after_node_log, active_node_list_log
colossalai/autochunk/utils.py
0 → 100644
View file @
1a6d2a74
from
typing
import
Any
,
Callable
,
Dict
,
Iterable
,
List
,
Tuple
from
torch.fx.node
import
Node
def is_non_compute_node(node):
    """Return True when *node* performs no real computation.

    A node is non-compute when its op contains "placeholder", "get_attr"
    or "output", or when its name contains "getitem" or "getattr".
    """
    non_compute_op_tags = ("placeholder", "get_attr", "output")
    helper_name_tags = ("getitem", "getattr")
    op_is_non_compute = any(tag in node.op for tag in non_compute_op_tags)
    name_is_helper = any(tag in node.name for tag in helper_name_tags)
    return op_is_non_compute or name_is_helper
def get_node_shape(node):
    """Return the shape recorded in node.meta["tensor_meta"], or None.

    None is returned when the tensor meta carries no ``shape`` attribute
    (e.g. for non-tensor outputs).
    """
    tensor_meta = node.meta["tensor_meta"]
    return getattr(tensor_meta, "shape", None)
def is_non_compute_node_except_placeholder(node):
    """Like is_non_compute_node, but placeholders count as compute.

    True when the op contains "get_attr" or "output", or the name contains
    "getitem" or "getattr"; placeholders therefore return False.
    """
    non_compute_op_tags = ("get_attr", "output")
    helper_name_tags = ("getitem", "getattr")
    op_is_non_compute = any(tag in node.op for tag in non_compute_op_tags)
    name_is_helper = any(tag in node.name for tag in helper_name_tags)
    return op_is_non_compute or name_is_helper
def is_non_compute_node_except_placeholder_output(node):
    """Like is_non_compute_node, but placeholders AND outputs count as compute.

    True only when the op contains "get_attr", or the name contains
    "getitem" or "getattr".
    """
    non_compute_op_tags = ("get_attr",)
    helper_name_tags = ("getitem", "getattr")
    op_is_non_compute = any(tag in node.op for tag in non_compute_op_tags)
    name_is_helper = any(tag in node.name for tag in helper_name_tags)
    return op_is_non_compute or name_is_helper
def find_idx_by_name(name, nodes_list):
    """Return the index of the first node in *nodes_list* named *name*.

    Raises RuntimeError when no node in the list has that name.
    """
    matches = (pos for pos, candidate in enumerate(nodes_list) if candidate.name == name)
    found = next(matches, None)
    if found is None:
        raise RuntimeError("name %s not found in node list" % name)
    return found
def delete_free_var_from_last_use(user_to_last_uses):
    """Strip placeholder (free-variable) nodes from every last-use list.

    Mutates the lists inside *user_to_last_uses* in place so free variables
    are never treated as deletable activations.

    Bug fix: the original removed elements from the very list it was
    iterating, which skips the element right after each removal — e.g. two
    consecutive placeholders left the second one behind. Iterating over a
    snapshot copy removes all of them.
    """
    for key, value in user_to_last_uses.items():
        # iterate a snapshot so in-place .remove() cannot skip elements
        for n in list(value):
            if n.op == "placeholder":
                value.remove(n)
def find_chunk_all_input_nodes(nodes: List[Node]):
    """
    Find non-compute input and output node names.
    input nodes are nodes used in the list
    output nodes are nodes will use nodes in the list
    """
    external_inputs = []
    for member in nodes:
        for producer in member._input_nodes.keys():
            comes_from_outside = producer not in nodes
            not_seen_yet = producer not in external_inputs
            if comes_from_outside and not_seen_yet:
                external_inputs.append(producer)
    return external_inputs
def find_chunk_compute_input_and_output_nodes(nodes: List[Node]):
    """
    Find non-compute input and output node names.
    input nodes are nodes used in the list
    output nodes are nodes will use nodes in the list
    """
    input_nodes = []
    output_nodes = []

    # if a node has an input node which is not in the node list
    # we treat that input node as the input of the checkpoint function
    for member in nodes:
        for producer in member._input_nodes.keys():
            already_known = producer in nodes or producer in input_nodes
            if not already_known and not is_non_compute_node_except_placeholder(
                producer
            ):
                input_nodes.append(producer)

    # if a node has a user node which is not in the node list
    # we treat that user node as the node receiving the current node output
    for member in nodes:
        for consumer in member.users.keys():
            consumed_outside = consumer not in nodes
            not_recorded = member not in output_nodes
            if (
                consumed_outside
                and not_recorded
                and not is_non_compute_node_except_placeholder_output(consumer)
            ):
                output_nodes.append(member)

    return input_nodes, output_nodes
tests/test_autochunk/benchmark_autochunk.py
View file @
1a6d2a74
...
...
@@ -3,7 +3,7 @@ import time
import
torch
import
torch.fx
from
colossalai.autochunk.chunk_codegen
import
ChunkCodeGen
from
colossalai.autochunk.
auto
chunk_codegen
import
Auto
ChunkCodeGen
from
colossalai.fx
import
ColoTracer
from
colossalai.fx.graph_module
import
ColoGraphModule
from
colossalai.fx.passes.meta_info_prop
import
MetaInfoProp
...
...
@@ -49,25 +49,29 @@ def _build_autochunk(model, max_memory, node, pair):
"pair"
:
pair
.
to
(
torch
.
device
(
"meta"
)),
},
)
gm_prop
=
torch
.
fx
.
symbolic_trace
(
model
)
# must use symbolic_trace
interp
=
MetaInfoProp
(
gm_prop
)
interp
.
propagate
(
MetaTensor
(
node
,
fake_device
=
"cuda:0"
),
MetaTensor
(
pair
,
fake_device
=
"cuda:0"
)
)
# now run it twice to get meta info in graph module, not necessary
gm
=
torch
.
fx
.
GraphModule
(
model
,
graph
)
interp
=
MetaInfoProp
(
gm
)
interp
.
propagate
(
MetaTensor
(
node
,
fake_device
=
"cuda:0"
),
MetaTensor
(
pair
,
fake_device
=
"cuda:0"
)
)
# set code_gen
codegen
=
ChunkCodeGen
(
gm_prop
,
max_memory
)
codegen
=
Auto
ChunkCodeGen
(
gm_prop
,
max_memory
)
graph
.
set_codegen
(
codegen
)
gm
=
ColoGraphModule
(
model
,
graph
)
gm
.
recompile
()
# print
code
=
graph
.
python_code
(
"self"
).
src
print
(
code
)
#
code = graph.python_code("self").src
#
print(code)
return
gm
...
...
tests/test_autochunk/test_autochunk.py
View file @
1a6d2a74
...
...
@@ -4,7 +4,7 @@ import torch.fx
import
torch.multiprocessing
as
mp
import
colossalai
from
colossalai.autochunk.chunk_codegen
import
ChunkCodeGen
from
colossalai.autochunk.
auto
chunk_codegen
import
Auto
ChunkCodeGen
from
colossalai.core
import
global_context
as
gpc
from
colossalai.fx
import
ColoTracer
from
colossalai.fx.graph_module
import
ColoGraphModule
...
...
@@ -82,7 +82,7 @@ def _run_offload_codegen(rank):
MetaTensor
(
node
,
fake_device
=
"cuda:0"
),
MetaTensor
(
pair
,
fake_device
=
"cuda:0"
)
)
codegen
=
ChunkCodeGen
(
gm_prop
)
codegen
=
Auto
ChunkCodeGen
(
gm_prop
)
graph
.
set_codegen
(
codegen
)
gm
=
ColoGraphModule
(
model
,
graph
)
gm
.
recompile
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment