Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
cd5cf2bc
"git@developer.sourcefind.cn:OpenDAS/colossalai.git" did not exist on "5fcd7795cd646205cc90785c398a02c8ac475b69"
Unverified
Commit
cd5cf2bc
authored
Sep 15, 2022
by
Super Daniel
Committed by
GitHub
Sep 15, 2022
Browse files
[fx/tuning] tune performance on rotor with meta info. (#1599)
parent
a7cda6f5
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
96 additions
and
107 deletions
+96
-107
colossalai/fx/passes/algorithms/ckpt_solver_rotor.py
colossalai/fx/passes/algorithms/ckpt_solver_rotor.py
+14
-69
colossalai/fx/passes/algorithms/linearize.py
colossalai/fx/passes/algorithms/linearize.py
+5
-1
colossalai/fx/profiler/__init__.py
colossalai/fx/profiler/__init__.py
+1
-1
colossalai/fx/profiler/dataflow.py
colossalai/fx/profiler/dataflow.py
+16
-12
colossalai/fx/profiler/memory.py
colossalai/fx/profiler/memory.py
+34
-2
colossalai/fx/profiler/opcount.py
colossalai/fx/profiler/opcount.py
+1
-0
colossalai/fx/profiler/profiler.py
colossalai/fx/profiler/profiler.py
+25
-22
No files found.
colossalai/fx/passes/algorithms/ckpt_solver_rotor.py
View file @
cd5cf2bc
from
typing
import
List
,
Tuple
from
typing
import
List
,
Tuple
import
torch
from
torch.fx
import
Node
from
torch.fx
import
GraphModule
,
Node
from
colossalai.fx.graph_module
import
ColoGraphModule
from
colossalai.fx.graph_module
import
ColoGraphModule
from
colossalai.fx.profiler
import
parameter_size
from
colossalai.fx.profiler
import
activation_size
,
parameter_size
import
math
import
math
from
.linearize
import
linearize
from
.linearize
import
linearize
from
.utils
import
*
from
.utils
import
*
...
@@ -31,7 +30,7 @@ def _compute_table(chain: Chain, mmax) -> Tuple:
...
@@ -31,7 +30,7 @@ def _compute_table(chain: Chain, mmax) -> Tuple:
# Build table
# Build table
opt
=
[[{}
for
_
in
range
(
chain
.
length
+
1
)]
for
_
in
range
(
mmax
+
1
)]
opt
=
[[{}
for
_
in
range
(
chain
.
length
+
1
)]
for
_
in
range
(
mmax
+
1
)]
what
=
[[{}
for
_
in
range
(
chain
.
length
+
1
)]
for
_
in
range
(
mmax
+
1
)]
what
=
[[{}
for
_
in
range
(
chain
.
length
+
1
)]
for
_
in
range
(
mmax
+
1
)]
#
#
Last one is a dict because its indices go from i to l. Renumbering will wait for C implementation
# Last one is a dict because its indices go from i to l. Renumbering will wait for C implementation
# Initialize borders of the tables for lmax-lmin = 0
# Initialize borders of the tables for lmax-lmin = 0
for
m
in
range
(
mmax
+
1
):
for
m
in
range
(
mmax
+
1
):
...
@@ -115,43 +114,6 @@ def _discretize(mem_unit, values):
...
@@ -115,43 +114,6 @@ def _discretize(mem_unit, values):
return
[
math
.
ceil
(
value
/
mem_unit
)
for
value
in
values
]
return
[
math
.
ceil
(
value
/
mem_unit
)
for
value
in
values
]
def
_compute_size
(
obj
:
torch
.
Tensor
)
->
int
:
return
obj
.
numel
()
*
obj
.
element_size
()
def
_compute_output_size
(
node
:
List
[
Node
])
->
int
:
"""Compute the output size of a node
Args:
node (List[Node]): node, list of torch.fx.Node
Returns:
int: output size
"""
return
node
[
-
1
].
meta
[
'tensor_meta'
].
numel
*
torch
.
tensor
([],
dtype
=
node
[
-
1
].
meta
[
'tensor_meta'
].
dtype
).
element_size
()
def
_get_inplace
(
node
:
Node
)
->
bool
:
"""Get the inplace argument from torch.fx.Node
Args:
node (Node): torch.fx.Node
Returns:
bool: indicates whether this op is inplace
"""
is_inplace
=
False
if
node
.
op
==
"call_function"
:
is_inplace
=
node
.
kwargs
.
get
(
"inplace"
,
False
)
elif
node
.
op
==
"call_module"
:
is_inplace
=
getattr
(
node
.
graph
.
owning_module
.
get_submodule
(
node
.
target
),
"inplace"
,
False
)
return
is_inplace
def
_fwd_xbar
(
node
:
List
[
Node
])
->
int
:
def
_fwd_xbar
(
node
:
List
[
Node
])
->
int
:
"""Get the forward xbar of a node
"""Get the forward xbar of a node
...
@@ -221,46 +183,33 @@ def _get_bwd_mem_tmp(node: List[Node]) -> int:
...
@@ -221,46 +183,33 @@ def _get_bwd_mem_tmp(node: List[Node]) -> int:
for
k
,
v
in
deps
.
items
():
for
k
,
v
in
deps
.
items
():
if
v
>
0
:
if
v
>
0
:
deps_size
+=
k
.
meta
[
'bwd_mem_out'
]
deps_size
+=
k
.
meta
[
'bwd_mem_out'
]
if
v
==
float
(
'-inf'
):
deps_size
-=
k
.
meta
[
'fwd_mem_tmp'
]
+
k
.
meta
[
'fwd_mem_out'
]
return
deps_size
return
deps_size
bwd_mem_tmp
=
0
bwd_mem_tmp
=
0
deps
=
{}
deps
=
{}
# add all the users for last node into deps,
# as those nodes' gradient out will be stored in memory
for
child
in
node
[
-
1
].
users
:
deps
[
child
]
=
1
for
n
in
reversed
(
node
):
for
n
in
reversed
(
node
):
deps
[
n
]
=
len
(
n
.
all_input_nodes
)
bwd_mem_tmp
=
max
(
bwd_mem_tmp
,
_get_deps_size
()
+
n
.
meta
[
'bwd_mem_tmp'
])
bwd_mem_tmp
=
max
(
bwd_mem_tmp
,
_get_deps_size
()
+
n
.
meta
[
'bwd_mem_tmp'
])
deps
[
n
]
=
len
(
n
.
all_input_nodes
)
for
child
in
n
.
users
:
for
child
in
n
.
users
:
if
child
in
deps
:
if
child
in
deps
:
deps
[
child
]
-=
1
deps
[
child
]
-=
1
if
deps
[
child
]
<=
0
:
for
key
in
list
(
deps
.
keys
()):
deps
[
child
]
=
float
(
'-inf'
)
# free
if
deps
[
key
]
==
0
:
del
deps
[
key
]
return
bwd_mem_tmp
return
bwd_mem_tmp
def
_construct_chain
(
node_list
:
List
[
List
[
Node
]],
data
,
mem_unit
:
int
)
->
Chain
:
def
_construct_chain
(
node_list
:
List
[
List
[
Node
]],
input
,
mem_unit
:
int
)
->
Chain
:
fwd_time
=
[]
fwd_time
=
[]
bwd_time
=
[]
bwd_time
=
[]
xbar_sizes
=
[
activation_size
(
input
)]
if
isinstance
(
data
,
torch
.
Tensor
):
x_sizes
=
[
activation_size
(
input
)]
xbar_sizes
=
[
_compute_size
(
data
)]
x_sizes
=
[
_compute_size
(
data
)]
elif
isinstance
(
data
,
list
)
or
isinstance
(
data
,
tuple
):
xbar_sizes
=
[
sum
([
_compute_size
(
obj
)
for
obj
in
data
])]
x_sizes
=
[
sum
([
_compute_size
(
obj
)
for
obj
in
data
])]
elif
isinstance
(
data
,
dict
):
xbar_sizes
=
[
sum
([
_compute_size
(
obj
)
for
obj
in
data
.
values
()])]
x_sizes
=
[
sum
([
_compute_size
(
obj
)
for
obj
in
data
.
values
()])]
# currently we can't get the temp memory needed in fwd
# currently we can't get the temp memory needed in fwd
tmp_fwd
=
[
0
]
*
len
(
node_list
)
tmp_fwd
=
[
0
]
*
len
(
node_list
)
tmp_bwd
=
[]
tmp_bwd
=
[]
...
@@ -268,14 +217,10 @@ def _construct_chain(node_list: List[List[Node]], data, mem_unit: int) -> Chain:
...
@@ -268,14 +217,10 @@ def _construct_chain(node_list: List[List[Node]], data, mem_unit: int) -> Chain:
for
idx
,
node
in
enumerate
(
node_list
):
for
idx
,
node
in
enumerate
(
node_list
):
fwd_time
.
append
(
_fwd_time
(
node
))
fwd_time
.
append
(
_fwd_time
(
node
))
bwd_time
.
append
(
_bwd_time
(
node
))
bwd_time
.
append
(
_bwd_time
(
node
))
x_sizes
.
append
(
_compute_output_size
(
node
)
)
x_sizes
.
append
(
node
[
-
1
].
meta
[
'fwd_mem_out'
]
)
xbar_sizes
.
append
(
max
(
x_sizes
[
-
1
],
_fwd_xbar
(
node
)))
xbar_sizes
.
append
(
max
(
x_sizes
[
-
1
],
_fwd_xbar
(
node
)))
tmp_bwd
.
append
(
_get_bwd_mem_tmp
(
node
))
tmp_bwd
.
append
(
_get_bwd_mem_tmp
(
node
))
# if a node with only one inplace op, we need to let x_bar = 0
if
len
(
node
)
==
1
and
_get_inplace
(
node
[
0
]):
xbar_sizes
[
-
1
]
=
0
bwd_time
.
append
(
0
)
bwd_time
.
append
(
0
)
# currently we view loss backward temp as zero
# currently we view loss backward temp as zero
...
@@ -381,7 +326,7 @@ def solver_rotor(gm: ColoGraphModule,
...
@@ -381,7 +326,7 @@ def solver_rotor(gm: ColoGraphModule,
mem_limit
:
int
,
mem_limit
:
int
,
mem_slots
:
int
=
500
,
mem_slots
:
int
=
500
,
cnode
:
List
[
str
]
=
None
,
cnode
:
List
[
str
]
=
None
,
eps
:
float
=
0.0
2
)
->
ColoGraphModule
:
eps
:
float
=
0.0
)
->
ColoGraphModule
:
"""solver that automatically find activation checkpoint in rotor's manner
"""solver that automatically find activation checkpoint in rotor's manner
Args:
Args:
...
@@ -390,7 +335,7 @@ def solver_rotor(gm: ColoGraphModule,
...
@@ -390,7 +335,7 @@ def solver_rotor(gm: ColoGraphModule,
mem_limit (int): memory budget in Byte.
mem_limit (int): memory budget in Byte.
mem_slots (int, optional): number of slots for discretizing memory budget. Defaults to 500.
mem_slots (int, optional): number of slots for discretizing memory budget. Defaults to 500.
cnode (List[Node], optional): common node list for linearize. Defaults to None.
cnode (List[Node], optional): common node list for linearize. Defaults to None.
eps (float): epsilon for memory decay. Defaults to 0.0
2
eps (float): epsilon for memory decay. Defaults to 0.0
Returns:
Returns:
ColoGraphModule: annotated ColoGraphModuled with __sequence__ attribute
ColoGraphModule: annotated ColoGraphModuled with __sequence__ attribute
...
...
colossalai/fx/passes/algorithms/linearize.py
View file @
cd5cf2bc
from
typing
import
List
,
Any
from
typing
import
List
,
Any
from
torch.fx
import
GraphModule
,
Node
from
torch.fx
import
GraphModule
,
Node
from
colossalai.fx.profiler
import
is_inplace
# Common nodes are type of nodes that could be seen as attributes and remain
# Common nodes are type of nodes that could be seen as attributes and remain
# unchanged throughout the whole model, it will be used several times by
# unchanged throughout the whole model, it will be used several times by
...
@@ -41,6 +42,9 @@ def linearize(gm: GraphModule, cnode: List[str] = None) -> List[List[Node]]:
...
@@ -41,6 +42,9 @@ def linearize(gm: GraphModule, cnode: List[str] = None) -> List[List[Node]]:
Returns:
Returns:
List[List[Node]]: List of list, each inside list of Node presents
List[List[Node]]: List of list, each inside list of Node presents
the actual 'node' in linearized manner.
the actual 'node' in linearized manner.
Remarks:
We merge the inplace ops into the previous node.
"""
"""
def
_is_sink
()
->
bool
:
def
_is_sink
()
->
bool
:
...
@@ -50,7 +54,7 @@ def linearize(gm: GraphModule, cnode: List[str] = None) -> List[List[Node]]:
...
@@ -50,7 +54,7 @@ def linearize(gm: GraphModule, cnode: List[str] = None) -> List[List[Node]]:
bool
bool
"""
"""
return
not
sum
([
v
for
_
,
v
in
deps
.
items
()])
return
not
sum
([
v
for
_
,
v
in
deps
.
items
()])
and
not
any
(
map
(
is_inplace
,
n
.
users
))
# make sure that item in cnode is valid
# make sure that item in cnode is valid
if
cnode
:
if
cnode
:
...
...
colossalai/fx/profiler/__init__.py
View file @
cd5cf2bc
...
@@ -7,4 +7,4 @@ else:
...
@@ -7,4 +7,4 @@ else:
from
.experimental
import
meta_profiler_function
,
meta_profiler_module
,
profile_function
,
profile_method
,
profile_module
from
.experimental
import
meta_profiler_function
,
meta_profiler_module
,
profile_function
,
profile_method
,
profile_module
from
.dataflow
import
GraphInfo
from
.dataflow
import
GraphInfo
from
.memory
import
parameter_size
,
activation_size
from
.memory
import
parameter_size
,
activation_size
,
is_inplace
colossalai/fx/profiler/dataflow.py
View file @
cd5cf2bc
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
enum
import
Enum
from
enum
import
Enum
from
functools
import
partial
from
typing
import
Dict
from
typing
import
Dict
from
torch.fx
import
Graph
,
Node
from
torch.fx
import
Graph
,
Node
from
.memory
import
activation_size
from
.memory
import
activation_size
,
is_inplace
from
.
import
META_COMPATIBILITY
if
META_COMPATIBILITY
:
from
.memory
import
NORMALIZATION_ATEN
,
CLONE_ATEN
class
Phase
(
Enum
):
class
Phase
(
Enum
):
FORWARD
=
0
FORWARD
=
0
LOSS
=
1
BACKWARD
=
1
BACKWARD
=
2
PLACEHOLDER
=
2
PLACEHOLDER
=
3
@
dataclass
@
dataclass
...
@@ -86,8 +87,10 @@ def autograd_graph_analysis(graph: Graph) -> GraphInfo:
...
@@ -86,8 +87,10 @@ def autograd_graph_analysis(graph: Graph) -> GraphInfo:
def
_peak_memory
(
deps
:
Dict
[
Node
,
int
]):
def
_peak_memory
(
deps
:
Dict
[
Node
,
int
]):
peak_mem
=
0
peak_mem
=
0
for
k
,
v
in
deps
.
items
():
for
k
,
v
in
deps
.
items
():
if
v
>
0
:
if
v
>
0
and
is_phase
(
k
,
Phase
.
BACKWARD
)
and
not
any
(
map
(
is_inplace
,
k
.
users
))
:
peak_mem
+=
activation_size
(
k
.
meta
[
'out'
])
peak_mem
+=
activation_size
(
k
.
meta
[
'out'
])
if
v
<=
float
(
'-inf'
)
and
is_saved
(
k
)
and
(
k
.
target
not
in
NORMALIZATION_ATEN
):
peak_mem
-=
activation_size
(
k
.
meta
[
'out'
])
return
peak_mem
return
peak_mem
# deps is used to track all the memory dependencies of the graph.
# deps is used to track all the memory dependencies of the graph.
...
@@ -96,7 +99,7 @@ def autograd_graph_analysis(graph: Graph) -> GraphInfo:
...
@@ -96,7 +99,7 @@ def autograd_graph_analysis(graph: Graph) -> GraphInfo:
for
n
in
graph
.
nodes
:
for
n
in
graph
.
nodes
:
n
:
Node
n
:
Node
if
is_saved
(
n
)
and
not
any
(
map
(
partial
(
is_phase
,
phase
=
Phase
.
LOSS
)
,
n
.
users
)):
if
is_saved
(
n
)
and
(
n
.
target
not
in
NORMALIZATION_ATEN
)
or
any
(
map
(
lambda
x
:
x
.
target
in
CLONE_ATEN
,
n
.
users
)):
# A forward tensor who is marked `save` but is not
# A forward tensor who is marked `save` but is not
# an input to `loss` should be saved during forward.
# an input to `loss` should be saved during forward.
# If the tensor is a placeholder, then it belongs to `fwd_mem_in`.
# If the tensor is a placeholder, then it belongs to `fwd_mem_in`.
...
@@ -110,13 +113,14 @@ def autograd_graph_analysis(graph: Graph) -> GraphInfo:
...
@@ -110,13 +113,14 @@ def autograd_graph_analysis(graph: Graph) -> GraphInfo:
graph_info
.
fwd_mem_tmp
+=
activation_size
(
n
.
meta
[
'out'
])
graph_info
.
fwd_mem_tmp
+=
activation_size
(
n
.
meta
[
'out'
])
elif
is_phase
(
n
,
Phase
.
BACKWARD
):
elif
is_phase
(
n
,
Phase
.
BACKWARD
):
if
len
(
n
.
users
):
if
len
(
n
.
users
):
# liveness analysis is only used in backward
deps
[
n
]
=
len
(
n
.
users
)
graph_info
.
bwd_mem_tmp
=
max
(
graph_info
.
bwd_mem_tmp
,
_peak_memory
(
deps
))
graph_info
.
bwd_mem_tmp
=
max
(
graph_info
.
bwd_mem_tmp
,
_peak_memory
(
deps
))
for
input_n
in
n
.
all_input_nodes
:
if
input_n
in
deps
:
deps
[
input_n
]
-=
1
else
:
else
:
# TODO: some of the bwd_mem_out might be model parameters.
# basically a backward node without user is a `grad_out` node
# basically a backward node without user is a `grad_out` node
graph_info
.
bwd_mem_out
+=
activation_size
(
n
.
meta
[
'out'
])
graph_info
.
bwd_mem_out
+=
activation_size
(
n
.
meta
[
'out'
])
for
input_n
in
n
.
all_input_nodes
:
if
input_n
in
deps
:
deps
[
input_n
]
-=
1
if
deps
[
input_n
]
<=
0
:
deps
[
input_n
]
=
float
(
'-inf'
)
return
graph_info
return
graph_info
colossalai/fx/profiler/memory.py
View file @
cd5cf2bc
import
torch
import
torch
from
torch.fx
import
Node
from
typing
import
Union
,
Dict
,
List
,
Tuple
from
typing
import
Union
,
Dict
,
List
,
Tuple
from
operator
import
add
,
floordiv
,
getitem
,
mul
,
neg
,
setitem
,
sub
,
pos
from
operator
import
add
,
floordiv
,
getitem
,
mul
,
neg
,
setitem
,
sub
,
pos
from
.
import
META_COMPATIBILITY
from
.
import
META_COMPATIBILITY
__all__
=
[
'activation_size'
,
'parameter_size'
]
__all__
=
[
'activation_size'
,
'parameter_size'
,
'is_inplace'
]
if
META_COMPATIBILITY
:
if
META_COMPATIBILITY
:
aten
=
torch
.
ops
.
aten
aten
=
torch
.
ops
.
aten
...
@@ -21,6 +22,7 @@ if META_COMPATIBILITY:
...
@@ -21,6 +22,7 @@ if META_COMPATIBILITY:
aten
.
bernoulli_
.
float
,
aten
.
bernoulli_
.
float
,
# inplace reshaping
# inplace reshaping
aten
.
copy_
.
default
,
aten
.
detach
.
default
,
aten
.
detach
.
default
,
aten
.
t
.
default
,
aten
.
t
.
default
,
aten
.
transpose
.
int
,
aten
.
transpose
.
int
,
...
@@ -28,7 +30,17 @@ if META_COMPATIBILITY:
...
@@ -28,7 +30,17 @@ if META_COMPATIBILITY:
aten
.
_unsafe_view
.
default
,
aten
.
_unsafe_view
.
default
,
]
]
__all__
+=
[
'INPLACE_ATEN'
,
'WEIRD_OPS'
]
NORMALIZATION_ATEN
=
[
aten
.
native_batch_norm
.
default
,
aten
.
native_layer_norm
.
default
,
# aten.max_pool2d_with_indices.default,
]
CLONE_ATEN
=
[
aten
.
clone
.
default
,
]
__all__
+=
[
'INPLACE_ATEN'
,
'WEIRD_OPS'
,
'NORMALIZATION_ATEN'
,
'CLONE_ATEN'
]
else
:
else
:
# TODO fill out the inplace ops
# TODO fill out the inplace ops
...
@@ -106,3 +118,23 @@ def parameter_size(mod: torch.nn.Module) -> int:
...
@@ -106,3 +118,23 @@ def parameter_size(mod: torch.nn.Module) -> int:
for
param
in
mod
.
parameters
():
for
param
in
mod
.
parameters
():
param_size
+=
param
.
numel
()
*
torch
.
tensor
([],
dtype
=
param
.
dtype
).
element_size
()
param_size
+=
param
.
numel
()
*
torch
.
tensor
([],
dtype
=
param
.
dtype
).
element_size
()
return
param_size
return
param_size
def
is_inplace
(
n
:
Node
):
"""Get the inplace argument from torch.fx.Node
Args:
node (Node): torch.fx.Node
Returns:
bool: indicates whether this op is inplace
"""
inplace
=
False
if
n
.
op
==
"call_function"
:
inplace
=
n
.
kwargs
.
get
(
"inplace"
,
False
)
if
META_COMPATIBILITY
and
n
.
target
in
INPLACE_ATEN
:
inplace
=
True
elif
n
.
op
==
"call_module"
:
inplace
=
getattr
(
n
.
graph
.
owning_module
.
get_submodule
(
n
.
target
),
"inplace"
,
False
)
return
inplace
colossalai/fx/profiler/opcount.py
View file @
cd5cf2bc
...
@@ -222,6 +222,7 @@ flop_mapping = {
...
@@ -222,6 +222,7 @@ flop_mapping = {
aten
.
_adaptive_avg_pool2d_backward
.
default
:
elementwise_flop_counter
(
0
,
1
),
aten
.
_adaptive_avg_pool2d_backward
.
default
:
elementwise_flop_counter
(
0
,
1
),
aten
.
_adaptive_avg_pool3d
.
default
:
elementwise_flop_counter
(
1
,
0
),
aten
.
_adaptive_avg_pool3d
.
default
:
elementwise_flop_counter
(
1
,
0
),
aten
.
_adaptive_avg_pool3d_backward
.
default
:
elementwise_flop_counter
(
0
,
1
),
aten
.
_adaptive_avg_pool3d_backward
.
default
:
elementwise_flop_counter
(
0
,
1
),
aten
.
embedding_dense_backward
.
default
:
elementwise_flop_counter
(
0
,
1
),
}
}
elementwise_flop_aten
=
[
elementwise_flop_aten
=
[
...
...
colossalai/fx/profiler/profiler.py
View file @
cd5cf2bc
from
dataclasses
import
dataclass
from
enum
import
auto
from
typing
import
Callable
,
Any
,
Dict
,
Tuple
from
typing
import
Callable
,
Any
,
Dict
,
Tuple
import
torch
import
torch
from
torch.fx
import
Graph
,
Node
from
torch.fx
import
Graph
,
Node
from
torch.fx.node
import
Argument
,
Target
from
torch.fx.node
import
Argument
,
Target
from
torch.utils._pytree
import
tree_map
from
torch.utils._pytree
import
tree_map
from
.dataflow
import
GraphInfo
,
autograd_graph_analysis
,
Phase
from
.dataflow
import
GraphInfo
,
autograd_graph_analysis
,
Phase
from
.memory
import
WEIRD_OPS
,
activation_size
from
.memory
import
WEIRD_OPS
from
.tensor
import
MetaTensor
from
.tensor
import
MetaTensor
from
.opcount
import
flop_mapping
from
.opcount
import
flop_mapping
...
@@ -23,7 +21,7 @@ def is_autogradable(x):
...
@@ -23,7 +21,7 @@ def is_autogradable(x):
return
isinstance
(
x
,
torch
.
Tensor
)
and
x
.
is_floating_point
()
return
isinstance
(
x
,
torch
.
Tensor
)
and
x
.
is_floating_point
()
def
_profile
(
target
:
Callable
,
*
args
,
inplace
=
False
,
**
kwargs
)
->
Tuple
[
Any
,
...]:
def
_profile
(
target
:
Callable
,
*
args
,
**
kwargs
)
->
Tuple
[
Any
,
...]:
"""
"""
Profile a Callable function with args and kwargs.
Profile a Callable function with args and kwargs.
...
@@ -42,7 +40,6 @@ def _profile(target: Callable, *args, inplace=False, **kwargs) -> Tuple[Any, ...
...
@@ -42,7 +40,6 @@ def _profile(target: Callable, *args, inplace=False, **kwargs) -> Tuple[Any, ...
# `flop_count`` serves as a global dictionary to store results.
# `flop_count`` serves as a global dictionary to store results.
flop_count
=
{
flop_count
=
{
Phase
.
FORWARD
:
0
,
Phase
.
FORWARD
:
0
,
Phase
.
LOSS
:
0
,
Phase
.
BACKWARD
:
0
,
Phase
.
BACKWARD
:
0
,
}
}
...
@@ -71,6 +68,10 @@ def _profile(target: Callable, *args, inplace=False, **kwargs) -> Tuple[Any, ...
...
@@ -71,6 +68,10 @@ def _profile(target: Callable, *args, inplace=False, **kwargs) -> Tuple[Any, ...
kwargs_node
=
tree_map
(
get_node
,
kwargs
)
kwargs_node
=
tree_map
(
get_node
,
kwargs
)
node
=
subgraph
.
create_node
(
'call_function'
,
func
,
args_node
,
kwargs_node
)
node
=
subgraph
.
create_node
(
'call_function'
,
func
,
args_node
,
kwargs_node
)
# do not allocate on `cpu`
if
'device'
in
kwargs
:
kwargs
[
'device'
]
=
'meta'
def
unwrap
(
x
):
def
unwrap
(
x
):
# if x is a `nn.Parameter`, we can first wrap it with `FlopTensor`
# if x is a `nn.Parameter`, we can first wrap it with `FlopTensor`
if
isinstance
(
x
,
torch
.
Tensor
)
and
not
hasattr
(
x
,
'_tensor'
):
if
isinstance
(
x
,
torch
.
Tensor
)
and
not
hasattr
(
x
,
'_tensor'
):
...
@@ -101,13 +102,13 @@ def _profile(target: Callable, *args, inplace=False, **kwargs) -> Tuple[Any, ...
...
@@ -101,13 +102,13 @@ def _profile(target: Callable, *args, inplace=False, **kwargs) -> Tuple[Any, ...
if
target
not
in
WEIRD_OPS
:
if
target
not
in
WEIRD_OPS
:
def
wrap
(
x
):
def
wrap
(
x
):
return
FlopTensor
(
x
.
detach
().
requires_grad_
(
return
FlopTensor
(
True
))
if
is_autogradable
(
x
)
and
not
inplace
and
not
hasattr
(
x
,
'_tensor'
)
else
x
x
.
detach
().
requires_grad_
(
True
))
if
is_autogradable
(
x
)
and
not
hasattr
(
x
,
'_tensor'
)
else
x
else
:
else
:
def
wrap
(
x
):
def
wrap
(
x
):
return
FlopTensor
(
x
.
detach
().
requires_grad_
(
return
FlopTensor
(
False
))
if
is_autogradable
(
x
)
and
not
inplace
and
not
hasattr
(
x
,
'_tensor'
)
else
x
x
.
detach
().
requires_grad_
(
False
))
if
is_autogradable
(
x
)
and
not
hasattr
(
x
,
'_tensor'
)
else
x
# Basically, we need to detach the args and kwargs from the outer graph.
# Basically, we need to detach the args and kwargs from the outer graph.
args
=
tree_map
(
wrap
,
args
)
args
=
tree_map
(
wrap
,
args
)
...
@@ -125,7 +126,7 @@ def _profile(target: Callable, *args, inplace=False, **kwargs) -> Tuple[Any, ...
...
@@ -125,7 +126,7 @@ def _profile(target: Callable, *args, inplace=False, **kwargs) -> Tuple[Any, ...
tree_map
(
set_placeholder
,
kwargs
)
tree_map
(
set_placeholder
,
kwargs
)
def
pack
(
x
):
def
pack
(
x
):
if
isinstance
(
x
,
FlopTensor
):
if
isinstance
(
x
,
FlopTensor
)
and
not
isinstance
(
x
,
torch
.
nn
.
Parameter
)
:
x
.
_node
.
meta
[
'saved'
]
=
True
x
.
_node
.
meta
[
'saved'
]
=
True
return
x
return
x
...
@@ -143,13 +144,15 @@ def _profile(target: Callable, *args, inplace=False, **kwargs) -> Tuple[Any, ...
...
@@ -143,13 +144,15 @@ def _profile(target: Callable, *args, inplace=False, **kwargs) -> Tuple[Any, ...
else
:
else
:
out
=
target
(
*
args
,
**
kwargs
)
out
=
target
(
*
args
,
**
kwargs
)
# If the output is not a floating point `torch.Tensor` or it does not
# If the output is not a floating point `torch.Tensor` or it does not
# requires grad, then we should not run backward for this node.
# requires grad, then we should not run backward for this node.
if
is_autogradable
(
out
)
and
out
.
requires_grad
:
if
is_autogradable
(
out
)
and
out
.
requires_grad
:
phase
=
Phase
.
LOSS
phase
=
Phase
.
BACKWARD
loss
=
out
.
sum
()
if
isinstance
(
out
,
FlopTensor
):
phase
=
Phase
.
BACKWARD
out
.
_node
.
meta
[
'save'
]
=
False
loss
.
backward
()
grad
=
torch
.
empty_like
(
out
.
_tensor
,
device
=
'meta'
)
if
isinstance
(
out
,
FlopTensor
)
else
torch
.
empty_like
(
out
,
device
=
'meta'
)
torch
.
autograd
.
backward
(
out
,
FlopTensor
(
grad
))
graph_info
=
autograd_graph_analysis
(
subgraph
)
graph_info
=
autograd_graph_analysis
(
subgraph
)
graph_info
.
fwd_flop
,
graph_info
.
bwd_flop
=
flop_count
[
Phase
.
FORWARD
],
flop_count
[
Phase
.
BACKWARD
]
graph_info
.
fwd_flop
,
graph_info
.
bwd_flop
=
flop_count
[
Phase
.
FORWARD
],
flop_count
[
Phase
.
BACKWARD
]
...
@@ -172,7 +175,7 @@ def profile_function(target: 'Target') -> Callable:
...
@@ -172,7 +175,7 @@ def profile_function(target: 'Target') -> Callable:
Examples:
Examples:
>>> input = torch.rand(100, 100, 100, 100, device='meta')
>>> input = torch.rand(100, 100, 100, 100, device='meta')
>>> func = torch.nn.functional.relu
>>> func = torch.nn.functional.relu
>>> output, meta_info = profile_function(func)(input
, inplace=False
)
>>> output, meta_info = profile_function(func)(input)
"""
"""
def
f
(
*
args
:
Tuple
[
Argument
,
...],
**
kwargs
:
Dict
[
str
,
Any
])
->
Any
:
def
f
(
*
args
:
Tuple
[
Argument
,
...],
**
kwargs
:
Dict
[
str
,
Any
])
->
Any
:
...
@@ -183,7 +186,7 @@ def profile_function(target: 'Target') -> Callable:
...
@@ -183,7 +186,7 @@ def profile_function(target: 'Target') -> Callable:
args
=
tree_map
(
lambda
x
:
x
.
to
(
'meta'
)
if
isinstance
(
x
,
torch
.
Tensor
)
else
x
,
args
)
args
=
tree_map
(
lambda
x
:
x
.
to
(
'meta'
)
if
isinstance
(
x
,
torch
.
Tensor
)
else
x
,
args
)
kwargs
=
tree_map
(
lambda
x
:
x
.
to
(
'meta'
)
if
isinstance
(
x
,
torch
.
Tensor
)
else
x
,
kwargs
)
kwargs
=
tree_map
(
lambda
x
:
x
.
to
(
'meta'
)
if
isinstance
(
x
,
torch
.
Tensor
)
else
x
,
kwargs
)
out
=
func
(
*
args
,
**
kwargs
)
out
=
func
(
*
args
,
**
kwargs
)
return
out
,
GraphInfo
(
out
.
numel
(),
out
.
numel
(),
activation_size
((
args
,
kwargs
)),
0
,
activation_size
(
out
)
,
0
)
return
out
,
GraphInfo
(
out
.
numel
(),
out
.
numel
(),
0
,
0
,
0
,
0
)
out
,
meta
=
_profile
(
func
,
*
args
,
**
kwargs
)
out
,
meta
=
_profile
(
func
,
*
args
,
**
kwargs
)
return
out
,
meta
return
out
,
meta
...
@@ -201,7 +204,7 @@ def profile_method(target: 'Target') -> Callable:
...
@@ -201,7 +204,7 @@ def profile_method(target: 'Target') -> Callable:
def
f
(
*
args
:
Tuple
[
Argument
,
...],
**
kwargs
:
Dict
[
str
,
Any
])
->
Any
:
def
f
(
*
args
:
Tuple
[
Argument
,
...],
**
kwargs
:
Dict
[
str
,
Any
])
->
Any
:
# execute the method and return the result
# execute the method and return the result
assert
isinstance
(
target
,
str
),
f
'
{
target
}
instance is not str.'
assert
isinstance
(
target
,
str
),
f
'
{
target
}
instance is not str.'
out
,
meta
=
_profile
(
target
,
*
args
,
inplace
=
False
,
**
kwargs
)
out
,
meta
=
_profile
(
target
,
*
args
,
**
kwargs
)
return
out
,
meta
return
out
,
meta
return
f
return
f
...
@@ -230,8 +233,8 @@ def profile_module(module: torch.nn.Module) -> Callable:
...
@@ -230,8 +233,8 @@ def profile_module(module: torch.nn.Module) -> Callable:
args
=
tree_map
(
lambda
x
:
x
.
to
(
'meta'
),
args
)
args
=
tree_map
(
lambda
x
:
x
.
to
(
'meta'
),
args
)
kwargs
=
tree_map
(
lambda
x
:
x
.
to
(
'meta'
),
kwargs
)
kwargs
=
tree_map
(
lambda
x
:
x
.
to
(
'meta'
),
kwargs
)
out
=
func
(
*
args
,
**
kwargs
)
out
=
func
(
*
args
,
**
kwargs
)
return
out
,
GraphInfo
(
out
.
numel
(),
out
.
numel
(),
activation_size
((
args
,
kwargs
)),
0
,
activation_size
(
out
)
,
0
)
return
out
,
GraphInfo
(
out
.
numel
(),
out
.
numel
(),
0
,
0
,
0
,
0
)
out
,
meta
=
_profile
(
func
,
*
args
,
inplace
=
getattr
(
module
,
'inplace'
,
False
),
**
kwargs
)
out
,
meta
=
_profile
(
func
,
*
args
,
**
kwargs
)
return
out
,
meta
return
out
,
meta
f
.
__name__
=
module
.
__class__
.
__name__
f
.
__name__
=
module
.
__class__
.
__name__
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment