OpenDAS / ColossalAI · Commits
Commit 6d559ea6 (unverified)
Authored Nov 10, 2022 by Super Daniel; committed by GitHub, Nov 10, 2022

[sc] add examples for auto checkpoint. (#1880)

Parent: 51597f6a
Changes: 2 changed files, with 943 additions and 0 deletions (+943 / -0)
examples/tutorial/auto_parallel/auto_ckpt_demo.ipynb    +878 / -0
examples/tutorial/auto_parallel/bench_utils.py          +65 / -0
examples/tutorial/auto_parallel/auto_ckpt_demo.ipynb (new file, 0 → 100644)

[Notebook diff collapsed; contents not shown.]
examples/tutorial/auto_parallel/bench_utils.py (new file, 0 → 100644)
import time
from functools import partial
from typing import Callable, Tuple

import numpy as np
import torch
import torchvision.models as tm

from colossalai.auto_parallel.checkpoint import CheckpointSolverRotor
from colossalai.fx import metainfo_trace


def bench(gm: torch.fx.GraphModule,
          criterion: torch.nn.Module,
          data_gen: Callable,
          num_steps: int = 5) -> Tuple[float, float]:
    """Benchmark one GraphModule: run `num_steps` training steps and return
    (peak memory above the pre-run baseline in MB, best step time in ms)."""
    gm.train()
    gm.cuda()
    step_time = float('inf')
    torch.cuda.synchronize()
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    cached = torch.cuda.max_memory_allocated(device="cuda")    # baseline: weights etc.
    try:
        for _ in range(num_steps):
            args, label = data_gen()
            output, loss = None, None
            torch.cuda.synchronize(device="cuda")
            start = time.time()
            output = gm(*args)
            loss = criterion(output, label)
            loss.backward()
            torch.cuda.synchronize(device="cuda")
            step_time = min(step_time, time.time() - start)
            # Drop gradients between steps so they don't accumulate into the peak.
            for child in gm.children():
                for param in child.parameters():
                    param.grad = None
            del args, label, output, loss
    except:
        # Typically a CUDA OOM; free what we can and report what was observed.
        del args, label, output, loss
    gm.to("cpu")
    torch.cuda.empty_cache()
    return (torch.cuda.max_memory_allocated(device="cuda") - cached) / 1024**2, step_time * 1.0e3


def bench_rotor(gm: torch.fx.GraphModule,
                criterion: torch.nn.Module,
                data_gen: Callable,
                num_steps: int = 5,
                sample_points: int = 20,
                free_memory: int = torch.cuda.mem_get_info()[0]) -> Tuple[list, list]:
    """Sweep `sample_points` memory budgets from `free_memory // 5` up to
    `free_memory`, solve each budget with CheckpointSolverRotor, and benchmark
    the resulting graph. Returns the (peak memory, step time) histories."""
    peak_hist, step_hist = [], []
    for budget in np.linspace(free_memory // 5, free_memory, sample_points):
        gm = metainfo_trace(gm, *data_gen()[0])    # re-trace with meta info for the solver
        solver = CheckpointSolverRotor(gm.graph, free_memory=budget)
        try:
            gm.graph = solver.solve()
            peak_memory, step_time = bench(gm,
                                           criterion,
                                           partial(data_gen, batch_size=2048, shape=(3, 224, 224)),
                                           num_steps=num_steps)
        except:
            # Solver failure or OOM under this budget: record the budget itself
            # (in MB) and an infinite step time as a sentinel.
            peak_memory, step_time = budget / 1024**2, float('inf')
        peak_hist.append(peak_memory)
        step_hist.append(step_time)
    return peak_hist, step_hist
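For context, here is a minimal driver sketch (not part of this commit) showing how these helpers might be used. It assumes a plain torch.fx.symbolic_trace suffices to obtain the GraphModule (the demo notebook may use ColossalAI's own tracer instead), and it defines a hypothetical data_gen matching the batch_size/shape keyword signature that bench_rotor's partial(...) call implies:

    # Hypothetical usage sketch -- not part of this commit.
    import torch
    import torchvision.models as tm

    def data_gen(batch_size: int = 128, shape=(3, 224, 224)):
        # Must return (args, label): bench calls gm(*args) and criterion(output, label).
        data = torch.rand(batch_size, *shape, device="cuda")
        label = torch.randint(0, 1000, (batch_size,), device="cuda")
        return (data,), label

    model = tm.resnet152()
    gm = torch.fx.symbolic_trace(model)    # assumption: a plain fx trace suffices here
    criterion = torch.nn.CrossEntropyLoss()

    # Sweep checkpointing budgets from 1/5 of free GPU memory up to all of it;
    # each sample yields (peak memory in MB, best step time in ms).
    peak_hist, step_hist = bench_rotor(gm, criterion, data_gen,
                                       num_steps=5, sample_points=10)

Plotting step_hist against peak_hist then gives the time/memory tradeoff curve of the rotor activation-checkpointing solver, which is presumably what the accompanying auto_ckpt_demo.ipynb visualizes.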