Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
31d2f03d
Unverified
Commit
31d2f03d
authored
Oct 12, 2022
by
Boyuan Yao
Committed by
GitHub
Oct 12, 2022
Browse files
[autoparallel] fix C version rotor inconsistency (#1691)
parent
363fc286
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
54 additions
and
22 deletions
+54
-22
colossalai/fx/passes/algorithms/ckpt_solver_rotor.py
colossalai/fx/passes/algorithms/ckpt_solver_rotor.py
+25
-5
colossalai/fx/passes/algorithms/dynamic_programs.c
colossalai/fx/passes/algorithms/dynamic_programs.c
+18
-15
tests/test_fx/test_ckpt_solvers/test_C_solver_consistency.py
tests/test_fx/test_ckpt_solvers/test_C_solver_consistency.py
+11
-2
No files found.
colossalai/fx/passes/algorithms/ckpt_solver_rotor.py
View file @
31d2f03d
...
...
@@ -10,6 +10,9 @@ from .operation import ForwardCheck, ForwardEnable, ForwardNograd, Backward, Los
from
colossalai.fx.codegen.activation_checkpoint_codegen
import
_find_nested_ckpt_regions
from
colossalai.logging
import
get_dist_logger
# global variable to indicate whether the solver has failed
SOLVER_FAILED
=
False
# this is the python compute table code from rotor
# https://gitlab.inria.fr/hiepacs/rotor
...
...
@@ -87,9 +90,17 @@ def _rec(chain: Chain, lmin, lmax, cmem, opt_table):
opt
,
what
=
opt_table
sequence
=
Sequence
(
Function
(
"Persistent"
,
lmax
-
lmin
,
cmem
))
if
opt
[
cmem
][
lmin
][
lmax
]
==
float
(
"inf"
):
raise
ValueError
(
"Can not process this chain from index {lmin} to {lmax} with memory {cmem}"
.
format
(
lmin
=
lmin
,
lmax
=
lmax
,
cmem
=
cmem
))
# use the logger to announce that the solver has failed
logger
=
get_dist_logger
()
logger
.
info
(
"Can not process this chain from index {lmin} to {lmax} with memory {cmem}"
.
format
(
lmin
=
lmin
,
lmax
=
lmax
,
cmem
=
cmem
))
# set the global indicator SOLVER_FAILED to True
global
SOLVER_FAILED
SOLVER_FAILED
=
True
return
sequence
if
lmin
==
lmax
:
if
lmin
==
chain
.
length
:
sequence
.
insert
(
Loss
())
...
...
@@ -406,9 +417,18 @@ def solver_rotor(gm: ColoGraphModule,
# found sequence
sequence
=
_rec
(
chain
,
0
,
chain
.
length
,
mem_slots
-
chain
.
cweight
[
0
],
opt_table
)
_annotate_from_sequence
(
sequence
,
node_list
)
# if solver failed, we don't need to annotate the graph
if
not
SOLVER_FAILED
:
_annotate_from_sequence
(
sequence
,
node_list
)
# set __sequence__ attribute to GraphModule
setattr
(
gm
,
"__sequence__"
,
sequence
)
if
SOLVER_FAILED
:
setattr
(
gm
,
"__sequence__"
,
None
)
else
:
setattr
(
gm
,
"__sequence__"
,
sequence
)
# set __opttable__ attribute to GraphModule
setattr
(
gm
,
"__opttable__"
,
opt_table
[
0
])
gm
.
recompile
()
return
gm
colossalai/fx/passes/algorithms/dynamic_programs.c
View file @
31d2f03d
...
...
@@ -94,13 +94,16 @@ static PyObject* persistent_compute_table(PyObject* self, PyObject* args) {
OPT
(
m
,
i
,
i
)
=
INFINITY
;
for
(
long
m
=
0
;
m
<=
mmax
;
++
m
)
for
(
long
i
=
0
;
i
<=
chain_length
;
++
i
)
{
long
maxCostFWD
=
0
;
for
(
long
l
=
i
+
1
;
l
<=
chain_length
;
++
l
)
{
long
mmin
=
cw
[
l
+
1
]
+
cw
[
i
+
1
]
+
fwd_tmp
[
i
];
if
(
l
>
i
+
1
)
{
maxCostFWD
=
fmaxl
(
maxCostFWD
,
cw
[
l
-
1
]
+
cw
[
l
]
+
fwd_tmp
[
l
-
1
]);
mmin
=
fmaxl
(
mmin
,
cw
[
l
+
1
]
+
maxCostFWD
);
for
(
long
d
=
1
;
d
<=
chain_length
;
++
d
)
{
for
(
long
i
=
0
;
i
<=
chain_length
-
d
;
++
i
)
{
long
idx
=
i
+
d
;
long
mmin
=
cw
[
idx
+
1
]
+
cw
[
i
+
1
]
+
fwd_tmp
[
i
];
if
(
idx
>
i
+
1
)
{
long
maxCostFWD
=
0
;
for
(
long
j
=
i
+
1
;
j
<
idx
;
j
++
)
{
maxCostFWD
=
fmaxl
(
maxCostFWD
,
cw
[
j
]
+
cw
[
j
+
1
]
+
fwd_tmp
[
j
]);
}
mmin
=
fmaxl
(
mmin
,
cw
[
idx
+
1
]
+
maxCostFWD
);
}
if
((
m
>=
mmin
))
{
long
bestLeaf
=
-
1
;
...
...
@@ -108,10 +111,10 @@ static PyObject* persistent_compute_table(PyObject* self, PyObject* args) {
double
bestLeafCost
=
INFINITY
;
/// sumFw + OPT(m-cw[i+1], i+1, l) + OPT(m, i, i); // Value for j =
/// i+1
for
(
long
j
=
i
+
1
;
j
<=
l
;
++
j
)
{
for
(
long
j
=
i
+
1
;
j
<=
idx
;
++
j
)
{
sumFw
+=
fw
[
j
-
1
];
if
(
m
>=
cw
[
j
])
{
double
cost
=
sumFw
+
OPT
(
m
-
cw
[
j
],
j
,
l
)
+
OPT
(
m
,
i
,
j
-
1
);
double
cost
=
sumFw
+
OPT
(
m
-
cw
[
j
],
j
,
idx
)
+
OPT
(
m
,
i
,
j
-
1
);
if
(
cost
<
bestLeafCost
)
{
bestLeafCost
=
cost
;
bestLeaf
=
j
;
...
...
@@ -120,16 +123,16 @@ static PyObject* persistent_compute_table(PyObject* self, PyObject* args) {
}
double
chainCost
=
INFINITY
;
if
(
m
>=
cbw
[
i
+
1
])
chainCost
=
OPT
(
m
,
i
,
i
)
+
OPT
(
m
-
cbw
[
i
+
1
],
i
+
1
,
l
);
chainCost
=
OPT
(
m
,
i
,
i
)
+
OPT
(
m
-
cbw
[
i
+
1
],
i
+
1
,
idx
);
if
(
bestLeafCost
<=
chainCost
)
{
OPT
(
m
,
i
,
l
)
=
bestLeafCost
;
WHAT
(
m
,
i
,
l
)
=
bestLeaf
;
OPT
(
m
,
i
,
idx
)
=
bestLeafCost
;
WHAT
(
m
,
i
,
idx
)
=
bestLeaf
;
}
else
{
OPT
(
m
,
i
,
l
)
=
chainCost
;
WHAT
(
m
,
i
,
l
)
=
-
1
;
OPT
(
m
,
i
,
idx
)
=
chainCost
;
WHAT
(
m
,
i
,
idx
)
=
-
1
;
}
}
else
OPT
(
m
,
i
,
l
)
=
INFINITY
;
OPT
(
m
,
i
,
idx
)
=
INFINITY
;
}
}
...
...
tests/test_fx/test_ckpt_solvers/test_C_solver_consistency.py
View file @
31d2f03d
...
...
@@ -26,7 +26,7 @@ except:
def
_run_C_solver_consistency_test
(
rank
=
0
):
colossalai
.
launch
(
config
=
{},
rank
=
rank
,
world_size
=
1
,
host
=
'localhost'
,
port
=
free_port
(),
backend
=
'nccl'
)
for
M
,
mem_budget
in
[(
tm
.
resnet
18
,
2
000
),
(
tm
.
resnet50
,
80
0
0
)]:
for
M
,
mem_budget
in
[(
tm
.
resnet
50
,
4
000
),
(
tm
.
densenet121
,
80
8
0
)]:
model
=
M
()
data
=
torch
.
rand
(
128
,
3
,
224
,
224
,
device
=
'meta'
)
...
...
@@ -41,15 +41,24 @@ def _run_C_solver_consistency_test(rank=0):
# python solver
gm
=
solver_rotor
(
gm
,
data_meta
,
mem_budget
*
1024
*
1024
,
force_python
=
True
)
sequence_python
:
Sequence
=
copy
.
deepcopy
(
gm
.
__sequence__
)
opt_python
=
copy
.
deepcopy
(
gm
.
__opttable__
)
# C solver
gm
=
solver_rotor
(
gm
,
data_meta
,
mem_budget
*
1024
*
1024
)
sequence_C
:
Sequence
=
copy
.
deepcopy
(
gm
.
__sequence__
)
opt_C
=
copy
.
deepcopy
(
gm
.
__opttable__
)
# make sure the opt_tables are the same
for
m
in
range
(
len
(
opt_python
)):
for
d
in
range
(
1
,
len
(
opt_python
[
0
])):
for
i
in
range
(
len
(
opt_python
[
0
])
-
d
):
assert
opt_python
[
m
][
i
][
i
+
d
]
==
opt_C
[
m
][
i
][
i
+
d
],
\
f
"item (
{
m
}
,
{
i
}
,
{
i
+
d
}
) is not consistent with python version!
\n
python version:
{
opt_python
[
m
][
i
][
i
+
d
]
}
\n
C version:
{
opt_C
[
m
][
i
][
i
+
d
]
}
"
sequence_python
=
sequence_python
.
list_operations
()
sequence_C
=
sequence_C
.
list_operations
()
# make sure the s
olution
s are the same
# make sure the s
equence
s are the same
assert
len
(
sequence_python
)
==
len
(
sequence_C
)
and
\
all
(
python_op
.
__repr__
()
==
C_op
.
__repr__
()
for
(
python_op
,
C_op
)
in
zip
(
sequence_python
,
sequence_C
))
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment