Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
9e768b59
Commit
9e768b59
authored
Oct 10, 2023
by
zhuwenwen
Browse files
Merge branch 'main' of
https://github.com/hpcaitech/ColossalAI
parents
7bc5a8e3
8aed02b9
Changes
442
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
75 additions
and
64 deletions
+75
-64
colossalai/fx/graph_module.py
colossalai/fx/graph_module.py
+33
-23
colossalai/fx/passes/adding_split_node_pass.py
colossalai/fx/passes/adding_split_node_pass.py
+42
-41
No files found.
Too many changes to show.
To preserve performance only
442 of 442+
files are displayed.
Plain diff
Email patch
colossalai/fx/graph_module.py
View file @
9e768b59
import
os
import
warnings
from
pathlib
import
Path
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Set
,
Type
,
Union
from
typing
import
Any
,
Dict
,
Optional
,
Union
import
torch
import
torch.nn
as
nn
from
torch.nn.modules.module
import
_addindent
try
:
from
torch.fx.graph
import
Graph
,
PythonCode
,
_custom_builtins
,
_is_from_torch
,
_PyTreeCodeGen
from
torch.fx.graph_module
import
GraphModule
,
_EvalCacheLoader
,
_exec_with_source
,
_forward_from_src
,
_WrappedCall
from
torch.fx.graph
import
Graph
,
PythonCode
,
_PyTreeCodeGen
from
torch.fx.graph_module
import
GraphModule
,
_exec_with_source
,
_forward_from_src
,
_WrappedCall
from
colossalai.fx.codegen.activation_checkpoint_codegen
import
ActivationCheckpointCodeGen
COLOGM
=
True
except
:
from
torch.fx.graph
import
Graph
from
torch.fx.graph_module
import
GraphModule
COLOGM
=
False
if
COLOGM
:
class
ColoGraphModule
(
GraphModule
):
def
__init__
(
self
,
def
__init__
(
self
,
root
:
Union
[
torch
.
nn
.
Module
,
Dict
[
str
,
Any
]],
graph
:
Graph
,
class_name
:
str
=
'GraphModule'
,
ckpt_codegen
:
bool
=
True
):
class_name
:
str
=
"GraphModule"
,
ckpt_codegen
:
bool
=
True
,
):
if
ckpt_codegen
:
graph
.
set_codegen
(
ActivationCheckpointCodeGen
())
super
().
__init__
(
root
,
graph
,
class_name
)
...
...
@@ -60,7 +63,7 @@ if COLOGM:
if
isinstance
(
self
.
_graph
.
_codegen
,
_PyTreeCodeGen
):
self
.
_in_spec
=
self
.
_graph
.
_codegen
.
pytree_info
.
in_spec
self
.
_out_spec
=
self
.
_graph
.
_codegen
.
pytree_info
.
out_spec
python_code
=
self
.
_graph
.
python_code
(
root_module
=
'
self
'
)
python_code
=
self
.
_graph
.
python_code
(
root_module
=
"
self
"
)
self
.
_code
=
python_code
.
src
# To split ckpt functions code and forward code
...
...
@@ -83,7 +86,7 @@ if COLOGM:
# bypass patching of torch.nn.Module.__call__ done while symbolic tracing.
cls_call
=
cls
.
__call__
if
"__call__"
in
vars
(
cls
)
else
None
if
'
_wrapped_call
'
not
in
vars
(
cls
):
if
"
_wrapped_call
"
not
in
vars
(
cls
):
cls
.
_wrapped_call
=
_WrappedCall
(
cls
,
cls_call
)
# type: ignore[attr-defined]
def
call_wrapped
(
self
,
*
args
,
**
kwargs
):
...
...
@@ -108,7 +111,7 @@ if COLOGM:
"""
folder
=
Path
(
folder
)
Path
(
folder
).
mkdir
(
exist_ok
=
True
)
torch
.
save
(
self
.
state_dict
(),
folder
/
'
state_dict.pt
'
)
torch
.
save
(
self
.
state_dict
(),
folder
/
"
state_dict.pt
"
)
tab
=
" "
*
4
# we add import colossalai here
...
...
@@ -125,7 +128,13 @@ class {module_name}(torch.nn.Module):
def
_gen_model_repr
(
module_name
:
str
,
module
:
torch
.
nn
.
Module
)
->
Optional
[
str
]:
safe_reprs
=
[
nn
.
Linear
,
nn
.
Conv1d
,
nn
.
Conv2d
,
nn
.
Conv3d
,
nn
.
BatchNorm1d
,
nn
.
BatchNorm2d
,
nn
.
BatchNorm3d
nn
.
Linear
,
nn
.
Conv1d
,
nn
.
Conv2d
,
nn
.
Conv3d
,
nn
.
BatchNorm1d
,
nn
.
BatchNorm2d
,
nn
.
BatchNorm3d
,
]
if
type
(
module
)
in
safe_reprs
:
return
f
"
{
module
.
__repr__
()
}
"
...
...
@@ -136,10 +145,10 @@ class {module_name}(torch.nn.Module):
for
module_name
,
module
in
self
.
named_children
():
module_str
=
_gen_model_repr
(
module_name
,
module
)
if
module_str
is
None
:
module_file
=
folder
/
f
'
{
module_name
}
.pt
'
module_file
=
folder
/
f
"
{
module_name
}
.pt
"
torch
.
save
(
module
,
module_file
)
blobified_modules
.
append
(
module_name
)
module_repr
=
module
.
__repr__
().
replace
(
'
\r
'
,
' '
).
replace
(
'
\n
'
,
' '
)
module_repr
=
module
.
__repr__
().
replace
(
"
\r
"
,
" "
).
replace
(
"
\n
"
,
" "
)
module_str
=
f
"torch.load(r'
{
module_file
}
') #
{
module_repr
}
"
model_str
+=
f
"
{
tab
*
2
}
self.
{
module_name
}
=
{
module_str
}
\n
"
...
...
@@ -156,19 +165,20 @@ class {module_name}(torch.nn.Module):
model_str
+=
f
"
{
tab
*
2
}
self.load_state_dict(torch.load(r'
{
folder
}
/state_dict.pt'))
\n
"
model_str
+=
f
"
{
_addindent
(
self
.
code
,
4
)
}
\n
"
module_file
=
folder
/
'
module.py
'
module_file
=
folder
/
"
module.py
"
module_file
.
write_text
(
model_str
)
init_file
=
folder
/
'
__init__.py
'
init_file
.
write_text
(
'
from .module import *
'
)
init_file
=
folder
/
"
__init__.py
"
init_file
.
write_text
(
"
from .module import *
"
)
if
len
(
blobified_modules
)
>
0
:
warnings
.
warn
(
"Was not able to save the following children modules as reprs -"
f
"saved as pickled files instead:
{
blobified_modules
}
"
)
warnings
.
warn
(
"Was not able to save the following children modules as reprs -"
f
"saved as pickled files instead:
{
blobified_modules
}
"
)
else
:
class
ColoGraphModule
(
GraphModule
):
def
__init__
(
self
,
root
:
Union
[
torch
.
nn
.
Module
,
Dict
[
str
,
Any
]],
graph
:
Graph
,
class_name
:
str
=
'GraphModule'
):
def
__init__
(
self
,
root
:
Union
[
torch
.
nn
.
Module
,
Dict
[
str
,
Any
]],
graph
:
Graph
,
class_name
:
str
=
"GraphModule"
):
super
().
__init__
(
root
,
graph
,
class_name
)
colossalai/fx/passes/adding_split_node_pass.py
View file @
9e768b59
import
numpy
as
np
import
torch
import
tqdm
from
torch.fx
import
symbolic_trace
from
torch.fx.node
import
Node
from
colossalai.fx.passes.split_module
import
split_module
...
...
@@ -29,15 +27,15 @@ def construct_blocks(gm: torch.fx.GraphModule, limit=0.01):
accumulate_bwd_flop
=
0
block_nodes
=
[]
for
node
in
gm
.
graph
.
nodes
:
if
'
block_split
'
in
node
.
name
:
if
"
block_split
"
in
node
.
name
:
continue
accumulate_fwd_flop
+=
node
.
fwd_flop
accumulate_bwd_flop
+=
node
.
bwd_flop
if
accumulate_fwd_flop
+
accumulate_bwd_flop
>=
per_block_flop
:
with
gm
.
graph
.
inserting_after
(
node
):
block_node
=
gm
.
graph
.
create_node
(
'
call_function
'
,
block_split
)
setattr
(
block_node
,
'
fwd_flop
'
,
accumulate_fwd_flop
)
setattr
(
block_node
,
'
bwd_flop
'
,
accumulate_bwd_flop
)
block_node
=
gm
.
graph
.
create_node
(
"
call_function
"
,
block_split
)
setattr
(
block_node
,
"
fwd_flop
"
,
accumulate_fwd_flop
)
setattr
(
block_node
,
"
bwd_flop
"
,
accumulate_bwd_flop
)
accumulate_fwd_flop
=
0
accumulate_bwd_flop
=
0
block_nodes
.
append
(
block_node
)
...
...
@@ -47,7 +45,7 @@ def construct_blocks(gm: torch.fx.GraphModule, limit=0.01):
def
remove_blocks
(
gm
:
torch
.
fx
.
GraphModule
):
for
node
in
gm
.
graph
.
nodes
:
if
(
node
.
op
,
node
.
target
)
==
(
'
call_function
'
,
block_split
):
if
(
node
.
op
,
node
.
target
)
==
(
"
call_function
"
,
block_split
):
gm
.
graph
.
erase_node
(
node
)
...
...
@@ -55,8 +53,8 @@ def get_compute_costs(node_list):
num_nodes
=
len
(
node_list
)
all_compute_cost
=
np
.
full
((
num_nodes
,
num_nodes
),
np
.
inf
,
dtype
=
np
.
float64
)
for
start
in
tqdm
.
tqdm
(
range
(
num_nodes
),
desc
=
'
start pos
'
,
position
=
0
):
for
end
in
tqdm
.
tqdm
(
range
(
start
,
num_nodes
),
desc
=
'
end pos
'
,
position
=
1
,
leave
=
False
):
for
start
in
tqdm
.
tqdm
(
range
(
num_nodes
),
desc
=
"
start pos
"
,
position
=
0
):
for
end
in
tqdm
.
tqdm
(
range
(
start
,
num_nodes
),
desc
=
"
end pos
"
,
position
=
1
,
leave
=
False
):
selected_flops
=
[(
node_list
[
i
].
fwd_flop
+
node_list
[
i
].
bwd_flop
)
for
i
in
range
(
start
,
end
+
1
)]
all_compute_cost
[
start
,
end
]
=
sum
(
selected_flops
)
...
...
@@ -78,12 +76,14 @@ def do_dp_split_gpipe_impl(num_nodes, num_stages, num_microbatches, compute_cost
# record start node index for next stage in this partition
f_argmin
=
np
.
full
((
num_stages
+
1
,
num_nodes
+
1
),
-
1
,
dtype
=
np
.
int32
)
f
[
0
,
num_nodes
]
=
0
for
s
in
tqdm
.
tqdm
(
range
(
1
,
num_stages
+
1
),
desc
=
'stage'
,
position
=
2
,
leave
=
False
):
# pylint: disable=too-many-nested-blocks
for
i
in
tqdm
.
tqdm
(
range
(
num_nodes
-
1
,
-
1
,
-
1
),
desc
=
'start node'
,
position
=
3
,
leave
=
False
):
for
k
in
tqdm
.
tqdm
(
range
(
num_nodes
,
i
,
-
1
),
desc
=
'mid node'
,
position
=
4
,
leave
=
False
):
for
s
in
tqdm
.
tqdm
(
range
(
1
,
num_stages
+
1
),
desc
=
"stage"
,
position
=
2
,
leave
=
False
):
# pylint: disable=too-many-nested-blocks
for
i
in
tqdm
.
tqdm
(
range
(
num_nodes
-
1
,
-
1
,
-
1
),
desc
=
"start node"
,
position
=
3
,
leave
=
False
):
for
k
in
tqdm
.
tqdm
(
range
(
num_nodes
,
i
,
-
1
),
desc
=
"mid node"
,
position
=
4
,
leave
=
False
):
stage_cost
=
compute_costs
[
i
,
k
-
1
]
new_cost
=
f
[
s
-
1
,
k
]
+
stage_cost
if
(
stage_cost
<=
max_compute_cost
and
new_cost
<
f
[
s
,
i
]
)
:
if
stage_cost
<=
max_compute_cost
and
new_cost
<
f
[
s
,
i
]:
f
[
s
,
i
]
=
new_cost
f_stage_max
[
s
,
i
]
=
max
(
f_stage_max
[
s
-
1
,
k
],
stage_cost
)
f_argmin
[
s
,
i
]
=
k
...
...
@@ -122,8 +122,9 @@ def do_dp_split_gpipe(node_list, compute_costs, num_stages: int, num_microbatche
if
max_compute_cost
-
last_max_compute_cost
<
gap
:
continue
cost
,
solution
=
do_dp_split_gpipe_impl
(
len
(
node_list
),
num_stages
,
num_microbatches
,
compute_costs
,
max_compute_cost
)
cost
,
solution
=
do_dp_split_gpipe_impl
(
len
(
node_list
),
num_stages
,
num_microbatches
,
compute_costs
,
max_compute_cost
)
if
cost
<
best_cost
:
best_cost
=
cost
...
...
@@ -137,15 +138,15 @@ def do_dp_split_gpipe(node_list, compute_costs, num_stages: int, num_microbatche
# split_mode:
# 'node': fx_node
# 'block': many fx_nodes construct a block
def
gpipe_dp_split_pass
(
gm
:
torch
.
fx
.
GraphModule
,
pp_size
:
int
,
num_microbatches
:
int
,
mode
=
'
block
'
,
block_limit
=
0.01
):
assert
mode
in
[
'
node
'
,
'
block
'
]
def
gpipe_dp_split_pass
(
gm
:
torch
.
fx
.
GraphModule
,
pp_size
:
int
,
num_microbatches
:
int
,
mode
=
"
block
"
,
block_limit
=
0.01
):
assert
mode
in
[
"
node
"
,
"
block
"
]
# nodes or blocks will be used in partition.
node_list
=
[]
if
mode
==
'
node
'
:
if
mode
==
"
node
"
:
for
node
in
gm
.
graph
.
nodes
:
node_list
.
append
(
node
)
elif
mode
==
'
block
'
:
elif
mode
==
"
block
"
:
node_list
=
construct_blocks
(
gm
,
limit
=
block_limit
)
else
:
pass
...
...
@@ -154,16 +155,16 @@ def gpipe_dp_split_pass(gm: torch.fx.GraphModule, pp_size: int, num_microbatches
best_cost
,
best_solution
=
do_dp_split_gpipe
(
node_list
,
compute_costs
,
pp_size
,
num_microbatches
)
for
(
_
,
next_start_node
)
in
best_solution
:
for
_
,
next_start_node
in
best_solution
:
if
pp_size
<=
1
:
break
node
=
node_list
[
next_start_node
]
with
gm
.
graph
.
inserting_before
(
node
):
split_node
=
gm
.
graph
.
create_node
(
'
call_function
'
,
pipe_split
)
split_node
=
gm
.
graph
.
create_node
(
"
call_function
"
,
pipe_split
)
pp_size
-=
1
# remove block node if possible
if
mode
==
'
block
'
:
if
mode
==
"
block
"
:
remove_blocks
(
gm
)
gm
.
recompile
()
...
...
@@ -178,7 +179,7 @@ def avgcompute_split_pass(gm: torch.fx.GraphModule, pp_size: int):
# To use avgcompute_split_pass, we need run meta_info_prop interpreter first.
# If nodes don't have meta info, this pass will fall back to normal balanced split pass.
check_node
=
list
(
mod_graph
.
nodes
)[
0
]
if
'
tensor_meta
'
not
in
check_node
.
meta
:
if
"
tensor_meta
"
not
in
check_node
.
meta
:
return
balanced_split_pass
(
gm
,
pp_size
)
total_fwd_flop
=
0
...
...
@@ -190,7 +191,7 @@ def avgcompute_split_pass(gm: torch.fx.GraphModule, pp_size: int):
for
node
in
mod_graph
.
nodes
:
if
pp_size
<=
1
:
break
if
'
pipe_split
'
in
node
.
name
:
if
"
pipe_split
"
in
node
.
name
:
continue
accumulate_fwd_flop
+=
node
.
fwd_flop
if
accumulate_fwd_flop
>=
partition_flop
:
...
...
@@ -199,14 +200,14 @@ def avgcompute_split_pass(gm: torch.fx.GraphModule, pp_size: int):
pp_size
-=
1
partition_flop
=
total_fwd_flop
//
pp_size
with
mod_graph
.
inserting_after
(
node
):
split_node
=
mod_graph
.
create_node
(
'
call_function
'
,
pipe_split
)
split_node
=
mod_graph
.
create_node
(
"
call_function
"
,
pipe_split
)
gm
.
recompile
()
return
gm
def
avgnode_split_pass
(
gm
:
torch
.
fx
.
GraphModule
,
pp_size
:
int
):
"""
In avgnode_split_pass, simpl
i
y split graph by node number.
In avgnode_split_pass, simply split graph by node number.
"""
mod_graph
=
gm
.
graph
avg_num_node
=
len
(
mod_graph
.
nodes
)
//
pp_size
...
...
@@ -218,12 +219,12 @@ def avgnode_split_pass(gm: torch.fx.GraphModule, pp_size: int):
if
accumulate_num_node
>=
avg_num_node
:
accumulate_num_node
=
0
pp_size
-=
1
if
node
.
next
.
op
==
'
output
'
:
if
node
.
next
.
op
==
"
output
"
:
with
mod_graph
.
inserting_before
(
node
):
split_node
=
mod_graph
.
create_node
(
'
call_function
'
,
pipe_split
)
split_node
=
mod_graph
.
create_node
(
"
call_function
"
,
pipe_split
)
else
:
with
mod_graph
.
inserting_after
(
node
):
split_node
=
mod_graph
.
create_node
(
'
call_function
'
,
pipe_split
)
split_node
=
mod_graph
.
create_node
(
"
call_function
"
,
pipe_split
)
gm
.
recompile
()
return
gm
...
...
@@ -250,18 +251,18 @@ def balanced_split_pass(gm: torch.fx.GraphModule, pp_size: int):
pp_size
-=
1
# If the next node is output node, we will insert split annotation before
# node to make sure there is at least one node in last partition.
if
node
.
next
.
op
==
'
output
'
:
if
node
.
next
.
op
==
"
output
"
:
with
mod_graph
.
inserting_before
(
node
):
split_node
=
mod_graph
.
create_node
(
'
call_function
'
,
pipe_split
)
split_node
=
mod_graph
.
create_node
(
"
call_function
"
,
pipe_split
)
else
:
with
mod_graph
.
inserting_after
(
node
):
split_node
=
mod_graph
.
create_node
(
'
call_function
'
,
pipe_split
)
split_node
=
mod_graph
.
create_node
(
"
call_function
"
,
pipe_split
)
if
pp_size
>
1
:
node_counter
=
0
for
node
in
mod_graph
.
nodes
:
if
pp_size
<=
1
:
break
if
node
.
op
==
'
placeholder
'
:
if
node
.
op
==
"
placeholder
"
:
continue
elif
node_counter
==
0
:
node_counter
+=
1
...
...
@@ -269,7 +270,7 @@ def balanced_split_pass(gm: torch.fx.GraphModule, pp_size: int):
pp_size
-=
1
node_counter
=
0
with
mod_graph
.
inserting_before
(
node
):
split_node
=
mod_graph
.
create_node
(
'
call_function
'
,
pipe_split
)
split_node
=
mod_graph
.
create_node
(
"
call_function
"
,
pipe_split
)
gm
.
recompile
()
return
gm
...
...
@@ -283,7 +284,7 @@ def balanced_split_pass_v2(gm: torch.fx.GraphModule, pp_size: int):
# To use balanced_split_pass_v2, we need run meta_info_prop interpreter first.
# If nodes don't have meta info, this pass will fall back to normal balanced split pass.
check_node
=
list
(
mod_graph
.
nodes
)[
0
]
if
'
tensor_meta
'
not
in
check_node
.
meta
:
if
"
tensor_meta
"
not
in
check_node
.
meta
:
return
balanced_split_pass
(
gm
,
pp_size
)
total_element_size
=
0
...
...
@@ -295,7 +296,7 @@ def balanced_split_pass_v2(gm: torch.fx.GraphModule, pp_size: int):
for
node
in
mod_graph
.
nodes
:
if
pp_size
<=
1
:
break
if
'
pipe_split
'
in
node
.
name
:
if
"
pipe_split
"
in
node
.
name
:
continue
accumulate_node_size
+=
node
.
node_size
if
accumulate_node_size
>=
partition_size
:
...
...
@@ -304,7 +305,7 @@ def balanced_split_pass_v2(gm: torch.fx.GraphModule, pp_size: int):
pp_size
-=
1
partition_size
=
total_element_size
//
pp_size
with
mod_graph
.
inserting_after
(
node
):
split_node
=
mod_graph
.
create_node
(
'
call_function
'
,
pipe_split
)
split_node
=
mod_graph
.
create_node
(
"
call_function
"
,
pipe_split
)
gm
.
recompile
()
return
gm
...
...
@@ -333,7 +334,7 @@ def uniform_split_pass(gm: torch.fx.GraphModule, pp_size: int):
accumulate_layer_amount
=
0
pp_size
-=
1
with
mod_graph
.
inserting_after
(
node
):
split_node
=
mod_graph
.
create_node
(
'
call_function
'
,
pipe_split
)
split_node
=
mod_graph
.
create_node
(
"
call_function
"
,
pipe_split
)
gm
.
recompile
()
return
gm
...
...
@@ -346,7 +347,7 @@ def split_with_split_nodes_pass(annotated_gm: torch.fx.GraphModule, merge_output
def
split_callback
(
n
:
torch
.
fx
.
Node
):
nonlocal
part_idx
if
(
n
.
op
,
n
.
target
)
==
(
'
call_function
'
,
pipe_split
):
if
(
n
.
op
,
n
.
target
)
==
(
"
call_function
"
,
pipe_split
):
part_idx
+=
1
return
part_idx
...
...
@@ -355,7 +356,7 @@ def split_with_split_nodes_pass(annotated_gm: torch.fx.GraphModule, merge_output
for
name
,
submodule
in
split_mod
.
named_modules
():
if
isinstance
(
submodule
,
torch
.
fx
.
GraphModule
):
for
node
in
submodule
.
graph
.
nodes
:
if
(
node
.
op
,
node
.
target
)
==
(
'
call_function
'
,
pipe_split
):
if
(
node
.
op
,
node
.
target
)
==
(
"
call_function
"
,
pipe_split
):
submodule
.
graph
.
erase_node
(
node
)
submodule
.
recompile
()
split_submodules
.
append
(
submodule
)
...
...
Prev
1
…
19
20
21
22
23
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment