Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
7d4abaa5
Commit
7d4abaa5
authored
Jan 10, 2023
by
oahzxl
Browse files
add doc
parent
1be0ac3c
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
113 additions
and
16 deletions
+113
-16
colossalai/autochunk/autochunk_codegen.py
colossalai/autochunk/autochunk_codegen.py
+87
-12
colossalai/autochunk/estimate_memory.py
colossalai/autochunk/estimate_memory.py
+19
-3
colossalai/autochunk/reorder_graph.py
colossalai/autochunk/reorder_graph.py
+7
-1
No files found.
colossalai/autochunk/autochunk_codegen.py
View file @
7d4abaa5
...
...
@@ -20,11 +20,22 @@ from .search_chunk import SearchChunk
from
.utils
import
delete_free_var_from_last_use
,
find_idx_by_name
,
get_node_shape
def
_gen_chunk_slice_dim
(
chunk_dim
,
chunk_idx_name
,
shape
):
def
_gen_chunk_slice_dim
(
chunk_dim
:
int
,
chunk_indice_name
:
str
,
shape
:
List
)
->
str
:
"""
Generate chunk slice string, e.g. [:, :, chunk_indice_name:chunk_indice_name + chunk_size, :]
Args:
chunk_dim (int)
chunk_indice_name (str): chunk indice name
shape (List): node shape
Returns:
new_shape (str): return slice
"""
new_shape
=
"["
for
idx
,
i
in
enumerate
(
shape
):
for
idx
,
_
in
enumerate
(
shape
):
if
idx
==
chunk_dim
:
new_shape
+=
"%s:%s + chunk_size"
%
(
chunk_i
dx
_name
,
chunk_i
dx
_name
)
new_shape
+=
"%s:%s + chunk_size"
%
(
chunk_i
ndice
_name
,
chunk_i
ndice
_name
)
else
:
new_shape
+=
":"
new_shape
+=
", "
...
...
@@ -32,7 +43,26 @@ def _gen_chunk_slice_dim(chunk_dim, chunk_idx_name, shape):
return
new_shape
def
_gen_loop_start
(
chunk_input
,
chunk_output
,
chunk_ouput_dim
,
chunk_size
=
2
):
def
_gen_loop_start
(
chunk_input
:
List
[
Node
],
chunk_output
:
Node
,
chunk_ouput_dim
:
int
,
chunk_size
=
2
)
->
str
:
"""
Generate chunk loop start
eg. chunk_result = torch.empty([100, 100], dtype=input_node.dtype, device=input_node.device)
chunk_size = 32
for chunk_idx in range(0, 100, 32):
......
Args:
chunk_input (List[Node]): chunk input node
chunk_output (Node): chunk output node
chunk_ouput_dim (int): chunk output node chunk dim
chunk_size (int): chunk size. Defaults to 2.
Returns:
context (str): generated str
"""
input_node
=
chunk_input
[
0
]
out_shape
=
get_node_shape
(
chunk_output
)
out_str
=
str
(
list
(
out_shape
))
...
...
@@ -45,8 +75,28 @@ def _gen_loop_start(chunk_input, chunk_output, chunk_ouput_dim, chunk_size=2):
def
_gen_loop_end
(
chunk_inputs
,
chunk_non_compute_inputs
,
chunk_outputs
,
chunk_outputs_dim
,
node_list
):
chunk_inputs
:
List
[
Node
],
chunk_non_compute_inputs
:
List
[
Node
],
chunk_outputs
:
Node
,
chunk_outputs_dim
:
int
,
node_list
:
List
[
Node
],
)
->
str
:
"""
Generate chunk loop end
eg. chunk_result[chunk_idx:chunk_idx + chunk_size] = output_node
output_node = chunk_result; xx = None; xx = None
Args:
chunk_inputs (List[Node]): chunk input node
chunk_non_compute_inputs (List[Node]): input node without chunk
chunk_outputs (Node): chunk output node
chunk_outputs_dim (int): chunk output node chunk dim
node_list (List)
Returns:
context (str): generated str
"""
chunk_outputs_name
=
chunk_outputs
.
name
chunk_outputs_idx
=
find_idx_by_name
(
chunk_outputs_name
,
node_list
)
chunk_output_shape
=
chunk_outputs
.
meta
[
"tensor_meta"
].
shape
...
...
@@ -76,7 +126,10 @@ def _gen_loop_end(
return
context
def
_replace_name
(
context
,
name_from
,
name_to
):
def
_replace_name
(
context
:
str
,
name_from
:
str
,
name_to
:
str
)
->
str
:
"""
replace node name
"""
patterns
=
[(
" "
,
" "
),
(
" "
,
"."
),
(
" "
,
","
),
(
"("
,
")"
),
(
"("
,
","
),
(
" "
,
")"
)]
for
p
in
patterns
:
source
=
p
[
0
]
+
name_from
+
p
[
1
]
...
...
@@ -86,7 +139,10 @@ def _replace_name(context, name_from, name_to):
return
context
def
_replace_reshape_size
(
context
,
node_name
,
reshape_size_dict
):
def
_replace_reshape_size
(
context
:
str
,
node_name
:
str
,
reshape_size_dict
:
Dict
)
->
str
:
"""
replace reshape size, some may have changed due to chunk
"""
if
node_name
not
in
reshape_size_dict
:
return
context
for
size_name
,
size_value
in
reshape_size_dict
[
node_name
].
items
():
...
...
@@ -94,7 +150,17 @@ def _replace_reshape_size(context, node_name, reshape_size_dict):
return
context
def
_replace_ones_like
(
search_chunk
:
SearchChunk
,
chunk_infos
,
region_idx
,
node_idx
,
node
,
body
):
def
_replace_ones_like
(
search_chunk
:
SearchChunk
,
chunk_infos
:
List
[
Dict
],
region_idx
:
int
,
node_idx
:
int
,
node
:
Node
,
body
:
List
[
str
],
)
->
List
[
str
]:
"""
add chunk slice for new tensor op such as ones like
"""
if
"ones_like"
in
node
.
name
:
meta_node
=
search_chunk
.
trace_indice
.
node_list
[
node_idx
]
chunk_dim
=
chunk_infos
[
region_idx
][
"node_chunk_dim"
][
meta_node
][
"chunk_dim"
]
...
...
@@ -114,7 +180,16 @@ def _replace_ones_like(search_chunk: SearchChunk, chunk_infos, region_idx, node_
return
body
def
_replace_input_var
(
chunk_inputs
,
region_idx
,
chunk_inputs_dim
,
node_idx
,
body
):
def
_replace_input_node
(
chunk_inputs
:
List
[
Node
],
region_idx
:
int
,
chunk_inputs_dim
:
Dict
,
node_idx
:
int
,
body
:
List
[
str
],
)
->
List
[
str
]:
"""
add chunk slice for input nodes
"""
for
input_node_idx
,
input_node
in
enumerate
(
chunk_inputs
[
region_idx
]):
for
idx
,
dim
in
chunk_inputs_dim
[
region_idx
][
input_node_idx
].
items
():
if
idx
==
node_idx
:
...
...
@@ -138,7 +213,7 @@ def emit_code_with_chunk(
"""
Emit code with chunk according to chunk_infos.
It will generate a for loop in chunk regions, and
It will generate a for loop in chunk regions, and
replace inputs and outputs of regions with chunked variables.
Args:
...
...
@@ -193,7 +268,7 @@ def emit_code_with_chunk(
if
within_chunk_region
:
emit_node_func
(
node
,
body
)
# replace input var with chunk var
body
=
_replace_input_
var
(
body
=
_replace_input_
node
(
chunk_inputs
,
region_idx
,
chunk_inputs_dim
,
node_idx
,
body
)
# ones like
...
...
colossalai/autochunk/estimate_memory.py
View file @
7d4abaa5
...
...
@@ -15,6 +15,10 @@ from .utils import (
class
EstimateMemory
(
object
):
"""
Estimate memory with chunk
"""
def
__init__
(
self
)
->
None
:
pass
...
...
@@ -31,8 +35,6 @@ class EstimateMemory(object):
}
out_size
=
activation_size
(
fwd_out
)
out_node
=
[
n
.
name
]
if
out_size
>
0
else
[]
# if any(i in n.name for i in ['transpose', 'permute', 'view']):
# out_size = 0
return
out_size
,
out_node
def
_get_output_node_size
(
self
,
n
):
...
...
@@ -184,10 +186,24 @@ class EstimateMemory(object):
def
estimate_chunk_inference_mem
(
self
,
node_list
,
node_list
:
List
,
chunk_infos
=
None
,
print_mem
=
False
,
):
"""
Estimate inference memory with chunk
Args:
node_list (List): list of graph nodes whose memory is to be estimated
chunk_infos (Dict): Chunk information. Defaults to None.
print_mem (bool): Whether to print peak memory of every node. Defaults to False.
Returns:
act_memory_peak_log (List): peak memory of every node
act_memory_after_node_log (List): memory after executing every node
active_node_list_log (List): active nodes of every node. active nodes refer to
nodes generated but not deleted.
"""
act_memory
=
0.0
act_memory_peak_log
=
[]
act_memory_after_node_log
=
[]
...
...
colossalai/autochunk/reorder_graph.py
View file @
7d4abaa5
...
...
@@ -3,6 +3,10 @@ from .utils import find_idx_by_name
class
ReorderGraph
(
object
):
"""
Reorder node list and indice trace list
"""
def
__init__
(
self
,
trace_indice
:
TraceIndice
)
->
None
:
self
.
trace_indice
=
trace_indice
self
.
all_reorder_map
=
{
...
...
@@ -60,7 +64,9 @@ class ReorderGraph(object):
def
_reorder_idx_trace
(
self
,
reorder_map
):
# reorder list
new_idx_trace_list
=
[
None
for
_
in
range
(
len
(
self
.
trace_indice
.
indice_trace_list
))]
new_idx_trace_list
=
[
None
for
_
in
range
(
len
(
self
.
trace_indice
.
indice_trace_list
))
]
for
old_idx
,
new_idx
in
reorder_map
.
items
():
new_idx_trace_list
[
new_idx
]
=
self
.
trace_indice
.
indice_trace_list
[
old_idx
]
self
.
trace_indice
.
indice_trace_list
=
new_idx_trace_list
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment