Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
cc55ff0a
Unverified
Commit
cc55ff0a
authored
Nov 10, 2022
by
Super Daniel
Committed by
GitHub
Nov 10, 2022
Browse files
[autoparallel] user-friendly API for CheckpointSolver. (#1879)
Merge for SC tutorial
parent
448248b2
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
16 additions
and
17 deletions
+16
-17
colossalai/auto_parallel/checkpoint/ckpt_solver_base.py
colossalai/auto_parallel/checkpoint/ckpt_solver_base.py
+10
-6
colossalai/auto_parallel/checkpoint/ckpt_solver_rotor.py
colossalai/auto_parallel/checkpoint/ckpt_solver_rotor.py
+6
-11
No files found.
colossalai/auto_parallel/checkpoint/ckpt_solver_base.py
View file @
cc55ff0a
...
...
@@ -2,6 +2,7 @@ from abc import ABC, abstractmethod
from
copy
import
deepcopy
from
typing
import
Any
,
List
import
torch
from
torch.fx
import
Graph
,
Node
from
colossalai.fx.codegen.activation_checkpoint_codegen
import
ActivationCheckpointCodeGen
...
...
@@ -17,13 +18,17 @@ def _copy_output(src: Graph, dst: Graph):
n_dst
.
meta
=
n_src
.
meta
def
_get_param_size
(
module
:
torch
.
nn
.
Module
):
"""Get the size of the parameters in the module"""
return
sum
([
p
.
numel
()
*
torch
.
tensor
([],
dtype
=
p
.
dtype
).
element_size
()
for
p
in
module
.
parameters
()])
class
CheckpointSolverBase
(
ABC
):
def
__init__
(
self
,
graph
:
Graph
,
memory_budget
:
float
=
-
1.0
,
parameter_size
:
float
=
0
,
free_memory
:
float
=
-
1.0
,
requires_linearize
:
bool
=
False
,
cnode
:
List
[
str
]
=
None
,
):
...
...
@@ -37,8 +42,7 @@ class CheckpointSolverBase(ABC):
Args:
graph (Graph): The computing graph to be optimized.
memory_budget (float): Memory constraint for the solution.
parameter_size (float): The size of parameter of this model. Use `parameter_size(model)` to estimate.
free_memory (float): Memory constraint for the solution.
requires_linearize (bool): Whether the graph needs to be linearized.
cnode (List[str], optional): Common node List, should be the subset of input. Default to None.
...
...
@@ -58,8 +62,8 @@ class CheckpointSolverBase(ABC):
raise
RuntimeError
(
"Nodes meta information hasn't been prepared! Please run MetaInfoProp before constructing the solver!"
)
self
.
memory
_budget
=
memory_budget
self
.
parameter_size
=
parameter_size
self
.
free_
memory
=
free_memory
self
.
parameter_size
=
_get_param_size
(
self
.
graph
.
owning_module
)
self
.
cnode
=
cnode
self
.
requires_linearize
=
requires_linearize
if
self
.
requires_linearize
:
...
...
colossalai/auto_parallel/checkpoint/ckpt_solver_rotor.py
View file @
cc55ff0a
...
...
@@ -22,12 +22,7 @@ __all__ = ['CheckpointSolverRotor']
class
CheckpointSolverRotor
(
CheckpointSolverBase
):
def
__init__
(
self
,
graph
:
Graph
,
memory_budget
:
float
=
-
1
,
parameter_size
:
float
=
0
,
cnode
:
List
[
str
]
=
None
,
memory_slots
:
int
=
500
):
def
__init__
(
self
,
graph
:
Graph
,
free_memory
:
float
=
-
1
,
cnode
:
List
[
str
]
=
None
,
memory_slots
:
int
=
500
):
"""This is the simple implementation of dynamic programming algorithm rotor
in https://hal.inria.fr/hal-02352969. Some code is adapted from
https://gitlab.inria.fr/hiepacs/rotor.
...
...
@@ -36,22 +31,22 @@ class CheckpointSolverRotor(CheckpointSolverBase):
Assume that we have a `GraphModule`, and we already applied the `MetaInfoProp`
to the graph to retrieve all information needed, then we could use the following
code to find a solution using `CheckpointSolverRotor`:
>>> solver = CheckpointSolverRotor(gm.graph, memory
_budget=memory_budget, parameter_size=parameter_size
)
>>> solver = CheckpointSolverRotor(gm.graph,
free_
memory
=torch.cuda.mem_get_info(device=0)[0]
)
>>> rotor_graph = solver.solve(force_python=True) # otherwise use C solver
>>> gm.graph = rotor_graph # set the graph to a new graph
Args:
graph (Graph): The computing graph to be optimized.
memory
_budget
(float, optional): Memory constraint for the solution, unit is byte.
parameter_size (float, optional): The size of parameter of this model, unit is byte. Use `parameter_size(model)` to estimate
.
free_
memory (float, optional): Memory constraint for the solution, unit is byte.
Use ``torch.cuda.mem_get_info(device=0)[0]`` to estimate the free_memory. Defaults to -1
.
cnode (List[str], optional): Common node List, should be the subset of input. Defaults to None.
memory_slots (int, optional): Number of slots for discretizing memory budget. Defaults to 500.
"""
super
().
__init__
(
graph
,
memory
_budget
,
parameter_size
,
True
,
cnode
)
super
().
__init__
(
graph
,
free_
memory
,
True
,
cnode
)
self
.
memory_slots
=
memory_slots
# construct chain
unit
=
self
.
memory
_budget
//
self
.
memory_slots
unit
=
self
.
free_
memory
//
self
.
memory_slots
self
.
chain
=
self
.
_construct_chain
(
self
.
graph
,
self
.
node_list
)
self
.
chain
.
discretize_all
(
unit
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment