Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
0f02b8c6
Unverified
Commit
0f02b8c6
authored
Jan 19, 2023
by
Ziyue Jiang
Committed by
GitHub
Jan 19, 2023
Browse files
add avg partition (#2483)
Co-authored-by:
Ziyue Jiang
<
ziyue.jiang@gmail.com
>
parent
99d9713b
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
38 additions
and
1 deletion
+38
-1
colossalai/fx/passes/adding_split_node_pass.py
colossalai/fx/passes/adding_split_node_pass.py
+36
-0
colossalai/fx/passes/meta_info_prop.py
colossalai/fx/passes/meta_info_prop.py
+2
-1
No files found.
colossalai/fx/passes/adding_split_node_pass.py
View file @
0f02b8c6
...
...
@@ -9,6 +9,40 @@ def pipe_split():
pass
def avgcompute_split_pass(gm: torch.fx.GraphModule, pp_size: int):
    """Insert ``pipe_split`` nodes so each pipeline stage has roughly equal forward FLOPs.

    Walks the graph accumulating each node's ``fwd_flop``; whenever the running
    total reaches the average FLOPs per remaining stage, a ``pipe_split``
    call_function node is inserted after the current node and the per-stage
    target is recomputed from the FLOPs still unassigned.

    Args:
        gm: the traced ``torch.fx.GraphModule`` to partition in place.
        pp_size: number of pipeline-parallel stages to split into.

    Returns:
        The same ``GraphModule``, recompiled, with split points inserted.
        Falls back to ``balanced_split_pass(gm, pp_size)`` when nodes carry
        no meta information.
    """
    mod_graph = gm.graph
    # Using avgcompute_split_pass requires running the meta_info_prop
    # interpreter first. If nodes don't have meta info, this pass falls
    # back to the normal balanced split pass.
    check_node = list(mod_graph.nodes)[0]
    if 'tensor_meta' not in check_node.meta:
        return balanced_split_pass(gm, pp_size)

    # Total forward FLOPs across the whole graph.
    total_fwd_flop = 0
    for node in mod_graph.nodes:
        total_fwd_flop += node.fwd_flop

    partition_flop = total_fwd_flop // pp_size
    accumulate_fwd_flop = 0
    for node in mod_graph.nodes:
        # Once only one stage remains, everything left belongs to it.
        if pp_size <= 1:
            break
        # Don't count split markers themselves.
        if 'pipe_split' in node.name:
            continue
        accumulate_fwd_flop += node.fwd_flop
        if accumulate_fwd_flop >= partition_flop:
            # Close this stage and rebalance the target over the
            # remaining FLOPs and remaining stages.
            total_fwd_flop = total_fwd_flop - accumulate_fwd_flop
            accumulate_fwd_flop = 0
            pp_size -= 1
            partition_flop = total_fwd_flop // pp_size
            with mod_graph.inserting_after(node):
                mod_graph.create_node('call_function', pipe_split)
    gm.recompile()
    return gm
def
avgnode_split_pass
(
gm
:
torch
.
fx
.
GraphModule
,
pp_size
:
int
):
"""
In avgnode_split_pass, simpliy split graph by node number.
...
...
@@ -104,8 +138,10 @@ def balanced_split_pass_v2(gm: torch.fx.GraphModule, pp_size: int):
continue
accumulate_node_size
+=
node
.
node_size
if
accumulate_node_size
>=
partition_size
:
total_element_size
=
total_element_size
-
accumulate_node_size
accumulate_node_size
=
0
pp_size
-=
1
partition_size
=
total_element_size
//
pp_size
with
mod_graph
.
inserting_after
(
node
):
split_node
=
mod_graph
.
create_node
(
'call_function'
,
pipe_split
)
gm
.
recompile
()
...
...
colossalai/fx/passes/meta_info_prop.py
View file @
0f02b8c6
...
...
@@ -112,7 +112,8 @@ class MetaInfoProp(torch.fx.Interpreter):
n
.
meta
[
'tensor_meta'
]
=
tensor_meta
n
.
meta
=
{
**
n
.
meta
,
**
asdict
(
meta_info
)}
# extend MetaInfo to `n.meta`
# TODO: the attribute node_size should be removed in the future
setattr
(
n
,
'node_size'
,
activation_size
(
n
.
meta
.
get
(
'fwd_in'
,
0
))
+
activation_size
(
n
.
meta
.
get
(
'fwd_tmp'
,
0
)))
setattr
(
n
,
'node_size'
,
activation_size
(
n
.
meta
.
get
(
'fwd_out'
,
0
))
+
activation_size
(
n
.
meta
.
get
(
'fwd_tmp'
,
0
)))
setattr
(
n
,
'fwd_flop'
,
n
.
meta
.
get
(
'fwd_flop'
,
0
))
n
.
meta
[
'type'
]
=
type
(
result
)
# retain the autograd graph
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment