OpenDAS / ColossalAI · Commits

Commit 08f2920e, authored Apr 23, 2023 by zhuwenwen

    init colossalai, support dtk2304

Parent: da3f0934
Pipeline #237 failed with stages in 0 seconds
Changes: 380 · Pipelines: 1
Showing 20 changed files, with 1896 additions and 0 deletions (+1896 -0):

- colossalai/fx/profiler/profiler.py (+409)
- colossalai/fx/profiler/shard_utils.py (+114)
- colossalai/fx/profiler/tensor.py (+140)
- colossalai/fx/proxy.py (+127)
- colossalai/fx/tracer/__init__.py (+5)
- colossalai/fx/tracer/_meta_trace.py (+133)
- colossalai/fx/tracer/_symbolic_trace.py (+54)
- colossalai/fx/tracer/_tracer_utils.py (+50)
- colossalai/fx/tracer/bias_addition_patch/__init__.py (+2)
- colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/__init__.py (+4)
- colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/addbmm.py (+75)
- colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/addmm.py (+60)
- colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/bias_addition_function.py (+115)
- colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/linear.py (+25)
- colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_module/__init__.py (+3)
- colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_module/bias_addition_module.py (+111)
- colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_module/conv.py (+56)
- colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_module/linear.py (+17)
- colossalai/fx/tracer/experimental.py (+394)
- colossalai/fx/tracer/meta_patch/__init__.py (+2)
colossalai/fx/profiler/profiler.py (new file, mode 100644)

```python
import time
from functools import partial
from typing import Any, Callable, Dict, Tuple

import torch
from torch.fx import Graph, Node
from torch.fx.node import Argument, Target
from torch.nn.parameter import Parameter
from torch.utils._pytree import tree_map

from .._compatibility import compatibility
from .constants import ALIAS_ATEN, OUTPUT_SAVED_MOD, OUTPUT_SAVED_OPS
from .dataflow import GraphInfo, Phase, autograd_graph_analysis, is_phase
from .memory_utils import activation_size, parameter_size
from .opcount import flop_mapping
from .tensor import MetaTensor

__all__ = ['profile_function', 'profile_module', 'profile_method']

# super-dainiu: this cache should be global, otherwise it cannot
# track duplicated tensors between nodes
cache = set()

# a global identifier for inplace ops
do_not_cache = False


def normalize_tuple(x):
    if not isinstance(x, tuple):
        return (x,)
    return x


def is_autogradable(x):
    return isinstance(x, torch.Tensor) and x.is_floating_point()


def detach_variables(x):
    if isinstance(x, torch.Tensor):
        requires_grad = x.requires_grad
        x = x.detach()
        x.requires_grad = requires_grad
    return x


@compatibility(is_backward_compatible=True)
def _profile_concrete(target: Callable, *args, **kwargs) -> Tuple[Tuple[Any, ...], GraphInfo]:
    """Profile a Callable function with args and kwargs on concrete devices, by https://github.com/Cypher30

    To profile the actual forward memory, we first run target in a torch.no_grad() context to get
    fwd_mem_out, then we run target with grad enabled and find the extra memory kept alive,
    i.e. the memory allocated minus fwd_mem_out.
    To profile the actual backward memory, we first make dummy gradients for torch.autograd.backward, then
    find bwd_mem_tmp as the memory peak during the process minus bwd_mem_out (which is actually equal to
    the size of args and kwargs).
    We also add time stamps to profile the real forward and backward time.

    Args:
        target (Callable): A Callable function
        args (Any): Arguments
        kwargs (Any): Arguments

    Returns:
        Tuple[Tuple[Any, ...], GraphInfo]: Output for next node & memory cost and real forward and backward
        time.
    """
    graphinfo = GraphInfo()

    # detach input from the graph
    args = tree_map(detach_variables, args)
    kwargs = tree_map(detach_variables, kwargs)
    if isinstance(target, str):
        # args[0] is the `self` object for this method call
        self_obj, *args_tail = args

        # calculate fwd_mem_out
        mem_stamp0 = torch.cuda.memory_allocated()
        with torch.no_grad():
            out = getattr(self_obj, target)(*args_tail, **kwargs)
        mem_stamp1 = torch.cuda.memory_allocated()
        graphinfo.fwd_mem_out = mem_stamp1 - mem_stamp0
        del out

        # calculate fwd_mem_tmp & fwd_time
        mem_stamp0 = torch.cuda.memory_allocated()
        fwd_time0 = time.time()
        out = getattr(self_obj, target)(*args_tail, **kwargs)
        fwd_time1 = time.time()
        graphinfo.fwd_time = fwd_time1 - fwd_time0
        mem_stamp1 = torch.cuda.memory_allocated()
        graphinfo.fwd_mem_tmp = mem_stamp1 - mem_stamp0 - graphinfo.fwd_mem_out

        # calculate bwd_mem_tmp & bwd_time
        grad_tensors = tree_map(lambda x: torch.ones_like(x) if isinstance(x, torch.Tensor) else None, out)
        torch.cuda.reset_peak_memory_stats()
        mem_stamp0 = torch.cuda.memory_allocated()
        bwd_time0 = time.time()
        torch.autograd.backward(out, grad_tensors=grad_tensors)
        bwd_time1 = time.time()
        graphinfo.bwd_time = bwd_time1 - bwd_time0
        mem_stamp1 = torch.cuda.max_memory_allocated()

        # calculate bwd memory stats
        # NOTE: the module should add param to bwd_mem_out for bwd_mem_tmp calculation
        graphinfo.bwd_mem_out = activation_size(args) + activation_size(kwargs)
        graphinfo.bwd_mem_out += parameter_size(target.__self__) if hasattr(target.__self__, "parameters") else 0
        graphinfo.bwd_mem_tmp = mem_stamp1 - mem_stamp0 - graphinfo.bwd_mem_out

    else:
        # calculate fwd_mem_out
        mem_stamp0 = torch.cuda.memory_allocated()
        with torch.no_grad():
            out = target(*args, **kwargs)
        mem_stamp1 = torch.cuda.memory_allocated()
        graphinfo.fwd_mem_out = mem_stamp1 - mem_stamp0
        del out

        # calculate fwd_mem_tmp & fwd_time
        mem_stamp0 = torch.cuda.memory_allocated()
        fwd_time0 = time.time()
        out = target(*args, **kwargs)
        fwd_time1 = time.time()
        graphinfo.fwd_time = fwd_time1 - fwd_time0
        mem_stamp1 = torch.cuda.memory_allocated()
        graphinfo.fwd_mem_tmp = mem_stamp1 - mem_stamp0 - graphinfo.fwd_mem_out

        # calculate bwd_mem_tmp & bwd_time
        grad_tensors = tree_map(lambda x: torch.ones_like(x) if isinstance(x, torch.Tensor) else None, out)
        torch.cuda.reset_peak_memory_stats()
        mem_stamp0 = torch.cuda.memory_allocated()
        bwd_time0 = time.time()
        torch.autograd.backward(out, grad_tensors=grad_tensors)
        bwd_time1 = time.time()
        graphinfo.bwd_time = bwd_time1 - bwd_time0
        mem_stamp1 = torch.cuda.max_memory_allocated()

        # calculate bwd memory stats
        # NOTE: the module should add param to bwd_mem_out for bwd_mem_tmp calculation
        graphinfo.bwd_mem_out = activation_size(args) + activation_size(kwargs)
        graphinfo.bwd_mem_out += parameter_size(target.__self__) if hasattr(target.__self__, "parameters") else 0
        graphinfo.bwd_mem_tmp = mem_stamp1 - mem_stamp0 - graphinfo.bwd_mem_out

    return tree_map(detach_variables, out), graphinfo


@compatibility(is_backward_compatible=False)
def _profile_meta(target: Callable, *args, **kwargs) -> Tuple[Tuple[Any, ...], GraphInfo]:
    """
    Profile a Callable function with args and kwargs on meta devices.

    Args:
        target (Callable): A Callable function
        args (Any): Argument
        kwargs (Any): Argument

    Returns:
        out (Tuple[Any, ...]): The argument value that was retrieved.
        meta_info (GraphInfo): The memory cost and FLOPs estimated with `MetaTensor`.
    """
    # This subgraph traces aten level ops inside one node.
    subgraph = Graph()

    # `flop_count` serves as a global dictionary to store results.
    flop_count = {
        Phase.FORWARD: 0,
        Phase.BACKWARD: 0,
    }

    # FlopTensor not only gets the flop statistics of a single node,
    # it also builds a full autograd graph for this node.
    # This makes sure we can analyze the dependencies of memory, and
    # decide which forward intermediate results should be kept until
    # backward is executed.
    # Hopefully, this attempt will provide a better estimation of memory.
    class FlopTensor(MetaTensor):

        _node: Node = None

        def __repr__(self):
            if self.grad_fn:
                return f"FlopTensor({self._tensor}, fake_device='{self.device}', size={tuple(self.shape)}, grad_fn={self.grad_fn})"
            return f"FlopTensor({self._tensor}, fake_device='{self.device}', size={tuple(self.shape)}, requires_grad={self.requires_grad})"

        @classmethod
        def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
            args_node = tree_map(lambda x: x._node if isinstance(x, FlopTensor) else None, args)
            kwargs_node = tree_map(lambda x: x._node if isinstance(x, FlopTensor) else None, kwargs)
            node = subgraph.create_node('call_function', func, args_node, kwargs_node)

            out = super().__torch_dispatch__(func, types, args, kwargs)

            flop_count[phase] += flop_mapping[func](args, normalize_tuple(out))
            node.meta['phase'] = phase

            # super-dainiu: in `nn.MultiheadAttention` this weird thing occurs,
            # i.e. `Phase.PLACEHOLDER` tensors are aliased and saved during
            # `Phase.FORWARD`
            if phase == Phase.FORWARD:
                if all(map(partial(is_phase, phase=Phase.PLACEHOLDER), node.all_input_nodes)) and func in ALIAS_ATEN:
                    node.meta['phase'] = Phase.PLACEHOLDER

            # TODO(yby): specify `saved_tensors` for backward memory estimation
            node.meta['saved_tensor'] = []
            if phase == Phase.BACKWARD:
                node.meta['saved_tensor'] = normalize_tuple(out)

            def wrap(x):
                if isinstance(x, MetaTensor):
                    x = FlopTensor(x)
                    x._node = node
                return x

            out = tree_map(wrap, out)
            return out

    def wrap(x):
        if isinstance(x, torch.Tensor):
            x = FlopTensor(x)
            if is_autogradable(x):
                x.requires_grad_(True)
            x._node = subgraph.create_node('placeholder',
                                           'placeholder', (subgraph._root,),
                                           name=subgraph._graph_namespace.create_name('input', x._tensor))
            x._node.meta['phase'] = Phase.PLACEHOLDER
            x._node.meta['saved_tensor'] = []
        return x

    # Basically, we need to detach the args and kwargs from the outer graph.
    args = tree_map(wrap, args)
    kwargs = tree_map(wrap, kwargs)

    def pack(x):
        global cache, do_not_cache
        if isinstance(x, FlopTensor) and not x._tensor.data_ptr() in cache:
            tensor = x._tensor.detach()
            tensor.data_ptr = x._tensor.data_ptr
            x._node.meta['saved_tensor'] += [tensor]
            if not do_not_cache:
                cache.add(x._tensor.data_ptr())
        return x

    def unpack(x):
        return x

    # `phase` will mark the phase of autograd from outside scope.
    phase = Phase.FORWARD
    # mark saved tensors with saved_tensors_hooks
    with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
        if isinstance(target, str):
            # args[0] is the `self` object for this method call
            self_obj, *args_tail = args
            out = getattr(self_obj, target)(*args_tail, **kwargs)
        else:
            out = target(*args, **kwargs)

        # If the output is not a floating point `torch.Tensor` or it does not
        # require grad, then we should not run backward for this node.
        if all(map(lambda x: is_autogradable(x) and x.requires_grad, normalize_tuple(out))):
            grad_out = [torch.zeros_like(t) for t in normalize_tuple(out)]
            phase = Phase.BACKWARD
            torch.autograd.backward(
                out,
                grad_out,
            )

    graph_info = autograd_graph_analysis(subgraph)
    graph_info.fwd_flop, graph_info.bwd_flop = flop_count[Phase.FORWARD], flop_count[Phase.BACKWARD]

    def extract_tensor(x: Any):
        if isinstance(x, MetaTensor):
            tensor = x._tensor.detach()
            tensor.data_ptr = x._tensor.data_ptr
            return tensor
        if not isinstance(x, torch.finfo):
            return x

    graph_info.fwd_out = list(map(extract_tensor, normalize_tuple(out)))

    def unwrap(x):
        return MetaTensor(x) if isinstance(x, torch.Tensor) else x

    return tree_map(unwrap, out), graph_info


@compatibility(is_backward_compatible=True)
def profile_function(target: 'Target', device: str = 'meta') -> Callable:
    """
    Wrap a `call_function` node or `torch.nn.functional` in order to
    record the memory cost and FLOPs of the execution.

    Warnings:
        You may only use tensors with `device=meta` for this wrapped function.
        Only original `torch.nn.functional` are available.

    Examples:
        >>> input = torch.rand(100, 100, 100, 100, device='meta')
        >>> func = torch.nn.functional.relu
        >>> output, meta_info = profile_function(func)(input)
    """

    def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any:

        # find the grad for parameter in args and kwargs
        param_size = 0

        def get_param_size(x):
            nonlocal param_size
            if isinstance(x, Parameter):
                param_size += activation_size(x)

        tree_map(get_param_size, args)
        tree_map(get_param_size, kwargs)

        # If there is an argument that this `call_function` is inplace, we should
        # still run the profiling but discard some results regarding `target`
        global do_not_cache

        inplace = kwargs.get('inplace', False)
        if target in OUTPUT_SAVED_OPS:
            do_not_cache = True
        if inplace:
            do_not_cache = True
            kwargs['inplace'] = False
        if device == 'meta':
            out, meta = _profile_meta(func, *args, **kwargs)
        else:
            out, meta = _profile_concrete(func, *args, **kwargs)
        if inplace:
            kwargs['inplace'] = True
            meta.bwd_mem_tmp = 0
            meta.bwd_mem_out = 0
        do_not_cache = False

        meta.bwd_mem_out -= param_size
        return out, meta

    f.__name__ = target.__name__
    func = target
    return f


@compatibility(is_backward_compatible=True)
def profile_method(target: 'Target', device: str = 'meta') -> Callable:
    """
    Wrap a `call_method` node to
    record the memory cost and FLOPs of the execution.
    """

    def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any:
        # execute the method and return the result
        assert isinstance(target, str), f'{target} instance is not str.'
        if device == 'meta':
            out, meta = _profile_meta(target, *args, **kwargs)
        else:
            out, meta = _profile_concrete(target, *args, **kwargs)
        return out, meta

    return f


@compatibility(is_backward_compatible=True)
def profile_module(module: torch.nn.Module, device: str = 'meta') -> Callable:
    """
    Wrap a `call_module` node or `torch.nn` in order to
    record the memory cost and FLOPs of the execution.

    Warnings:
        You may only use tensors with `device=meta` for this wrapped function.
        Only original `torch.nn` are available.

    Example:
        >>> input = torch.rand(4, 3, 224, 224, device='meta')
        >>> mod = torch.nn.Conv2d(3, 128, 3)
        >>> output, meta_info = profile_module(mod)(input)
    """

    def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any:

        # calculate parameter size
        param_size = parameter_size(module)

        # If there is an argument that this `call_module` is inplace, we should
        # still run the profiling but discard some results regarding `module`.
        global do_not_cache

        inplace = getattr(module, 'inplace', False)
        if type(module) in OUTPUT_SAVED_MOD:
            do_not_cache = True
        if inplace:
            do_not_cache = True
            module.inplace = False
        if device == 'meta':
            out, meta = _profile_meta(func, *args, **kwargs)
        else:
            out, meta = _profile_concrete(func, *args, **kwargs)
        if inplace:
            module.inplace = True
            meta.bwd_mem_tmp = 0
            meta.bwd_mem_out = 0
        do_not_cache = False

        # grad for param will not be counted
        meta.bwd_mem_out -= param_size
        return out, meta

    f.__name__ = module.__class__.__name__
    func = module.forward
    return f
```
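For orientation, here is a minimal usage sketch of the wrappers above. It assumes `profile_function` is re-exported from `colossalai.fx.profiler` (the package `__init__` is not part of the files shown in this diff), and that the op appears in `flop_mapping`:

```python
# A minimal sketch, assuming `profile_function` is importable as below.
import torch
from colossalai.fx.profiler import profile_function

# Tensors must live on `device='meta'` for the default meta-device path.
x = torch.rand(128, 1024, device='meta', requires_grad=True)

# Wrap the functional op, then call the wrapper like the original function.
# It returns the (meta) output plus a GraphInfo with FLOP/memory estimates.
out, meta_info = profile_function(torch.nn.functional.relu)(x)
print(meta_info.fwd_flop, meta_info.bwd_flop)
```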
colossalai/fx/profiler/shard_utils.py (new file, mode 100644)

```python
import torch
from torch.fx import Node

from .._compatibility import compatibility, is_compatible_with_meta
from .memory_utils import activation_size

if is_compatible_with_meta():
    from .constants import OUTPUT_SAVED_MOD, OUTPUT_SAVED_OPS

__all__ = ["calculate_fwd_in", "calculate_fwd_tmp", "calculate_fwd_out"]


@compatibility(is_backward_compatible=False)
def calculate_fwd_in(n: Node) -> int:
    """A helper function to calculate `fwd_in` (with sharding spec)

    Args:
        n (Node): a node from the graph

    Returns:
        fwd_in (int): the result of `fwd_in`
    """
    # TODO(super-dainiu): should divide the memory by sharding spec
    return activation_size(n.meta["fwd_in"])


@compatibility(is_backward_compatible=False)
def calculate_fwd_tmp(n: Node) -> int:
    """A helper function to calculate `fwd_tmp` (with sharding spec)
    Currently, `torch.nn.ReLU` behaves weirdly, so we have to patch it for accuracy.

    Args:
        n (Node): a node from the graph

    Returns:
        fwd_tmp (int): the result of `fwd_tmp`
    """

    # TODO(super-dainiu): should divide the memory by sharding spec
    def is_relu_like_node(n: Node) -> bool:
        """Check if a node is a ReLU-like node.
        ReLU-like nodes have the following properties:
        - They are either `call_function` or `call_module`
        - Their output tensors are directly saved for backward
        - Their input tensors are not saved for backward

        An example is `torch.nn.functional.softmax` which has (forward + backward):
        def forward(self, input_2):
            _softmax_default = torch.ops.aten._softmax.default(input_2, None, None); input_2 = None
            zeros_like_default = torch.ops.aten.zeros_like.default(_softmax_default, dtype = None, layout = None, device = None, pin_memory = None)
            detach_default = torch.ops.aten.detach.default(_softmax_default); _softmax_default = None
            _softmax_backward_data_default = torch.ops.aten._softmax_backward_data.default(zeros_like_default, detach_default, None, None); zeros_like_default = detach_default = None
            detach_default_1 = torch.ops.aten.detach.default(_softmax_backward_data_default); _softmax_backward_data_default = None
            detach_default_2 = torch.ops.aten.detach.default(detach_default_1); detach_default_1 = None

        Args:
            n (Node): A node from the graph

        Returns:
            bool: Whether the node is a ReLU-like node
        """
        if n.op == 'call_function':
            return n.target in OUTPUT_SAVED_OPS
        elif n.op == 'call_module':
            return type(n.graph.owning_module.get_submodule(n.target)) in OUTPUT_SAVED_MOD
        return False

    if not is_relu_like_node(n):
        return activation_size(n.meta["fwd_tmp"])
    return 0


@compatibility(is_backward_compatible=False)
def calculate_fwd_out(n: Node) -> int:
    """A helper function to calculate `fwd_out` (with sharding spec)

    Args:
        n (Node): a node from the graph

    Returns:
        fwd_out (int): the result of `fwd_out`
    """

    # TODO(super-dainiu): should divide the memory by sharding spec
    def intersect(a, b):
        return {k: a[k] for k in a if k in b}

    fwd_in = dict()
    for u in n.users:
        fwd_in.update({x.data_ptr(): x for x in u.meta["fwd_in"] if isinstance(x, torch.Tensor)})
    fwd_out = {x.data_ptr(): x for x in n.meta["fwd_out"] if isinstance(x, torch.Tensor)}
    return activation_size(intersect(fwd_in, fwd_out))


def calculate_fwd_time(n: Node) -> float:
    """A helper function to calculate `fwd_time` (with sharding spec)

    Args:
        n (Node): a node from the graph

    Returns:
        fwd_time (float): the result of `fwd_time`
    """
    # TODO(super-dainiu): should divide the time by the number of GPUs as well as TFLOPs
    return n.meta["fwd_flop"]


def calculate_bwd_time(n: Node) -> float:
    """A helper function to calculate `bwd_time` (with sharding spec)

    Args:
        n (Node): a node from the graph

    Returns:
        bwd_time (float): the result of `bwd_time`
    """
    # TODO(super-dainiu): should divide the time by the number of GPUs as well as TFLOPs
    return n.meta["bwd_flop"]
```
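These helpers read `fwd_in`, `fwd_tmp` and `fwd_out` entries from `node.meta`, which must have been populated by an earlier profiling pass not shown in this diff. A rough consumption sketch, under that assumption:

```python
# Sketch: aggregate per-node forward memory over an already-profiled graph.
# Assumes `gm` is a torch.fx.GraphModule whose nodes carry the `fwd_in`,
# `fwd_tmp` and `fwd_out` meta entries that these helpers read.
from colossalai.fx.profiler.shard_utils import (calculate_fwd_in, calculate_fwd_out, calculate_fwd_tmp)


def total_forward_memory(gm) -> int:
    total = 0
    for node in gm.graph.nodes:
        # fwd_out only counts tensors actually consumed by users, so the
        # three terms together approximate the forward activation footprint.
        total += calculate_fwd_in(node) + calculate_fwd_tmp(node) + calculate_fwd_out(node)
    return total
```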
colossalai/fx/profiler/tensor.py (new file, mode 100644)

```python
import uuid
from copy import deepcopy
from typing import Optional

import torch
from torch.types import _bool, _device, _dtype
from torch.utils._pytree import tree_flatten, tree_map

from .._compatibility import compatibility
from .constants import ALIAS_ATEN

__all__ = ['MetaTensor']


def set_data_ptr(x):
    if isinstance(x, torch.Tensor):
        if not x.data_ptr():
            data_ptr = uuid.uuid4()
            x.data_ptr = lambda: data_ptr


@compatibility(is_backward_compatible=False)
class MetaTensor(torch.Tensor):
    """
    A wrapping tensor that hacks `torch.autograd` without patching more `torch.ops.aten` ops.
    `fake_device` is the device that `MetaTensor` is supposed to run on.
    """

    _tensor: torch.Tensor

    __slots__ = ['_tensor']

    @staticmethod
    def __new__(cls, elem, fake_device=None):
        # Avoid multiple wrapping
        if isinstance(elem, MetaTensor):
            fake_device = elem.device if fake_device is None else fake_device
            elem = elem._tensor

        # The wrapping tensor (MetaTensor) shouldn't hold any
        # memory for the class in question, but it should still
        # advertise the same device as before
        r = torch.Tensor._make_wrapper_subclass(
            cls,
            elem.size(),
            strides=elem.stride(),
            storage_offset=elem.storage_offset(),
            dtype=elem.dtype,
            layout=elem.layout,
            device=fake_device if fake_device is not None else elem.device,
            requires_grad=elem.requires_grad)    # deceive the frontend for aten selections
        r._tensor = elem
        # ...the real tensor is held as an element on the tensor.
        if not r._tensor.is_meta:
            r._tensor = r._tensor.to(torch.device('meta'))
        # only tensor not on `meta` should be copied to `meta`
        set_data_ptr(r._tensor)
        return r

    def __repr__(self):
        if self.grad_fn:
            return f"MetaTensor({self._tensor}, fake_device='{self.device}', grad_fn={self.grad_fn})"
        return f"MetaTensor({self._tensor}, fake_device='{self.device}')"

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        fake_device = None

        def unwrap(x):
            nonlocal fake_device
            if isinstance(x, MetaTensor):
                fake_device = x.device
                x = x._tensor
            elif isinstance(x, torch.Tensor):
                fake_device = x.device
                x = x.to(torch.device('meta'))
            return x

        if 'device' in kwargs:
            fake_device = kwargs['device']
            kwargs['device'] = torch.device('meta')

        args = tree_map(unwrap, args)
        kwargs = tree_map(unwrap, kwargs)

        # run aten for backend=CPU but actually on backend=Meta
        out = func(*args, **kwargs)

        # here we keep the uuid of input because ALIAS_ATEN do not generate a physical copy
        # of the input
        if func in ALIAS_ATEN:
            out.data_ptr = args[0].data_ptr

        # Now, we want to continue propagating this tensor, so we rewrap Tensors in
        # our custom tensor subclass
        def wrap(x):
            if isinstance(x, torch.Tensor):
                nonlocal fake_device
                if not x.is_meta:
                    x = x.to(torch.device('meta'))
            return MetaTensor(x, fake_device=fake_device) if isinstance(x, torch.Tensor) else x

        return tree_map(wrap, out)

    def to(self, *args, **kwargs) -> torch.Tensor:
        """An extension of `torch.Tensor.to()` to MetaTensor

        Returns:
            result (MetaTensor): MetaTensor

        Usage:
            >>> tensor = MetaTensor(torch.rand(10), fake_device='cuda:100')
            >>> tensor.to(torch.uint8)
            MetaTensor(tensor(..., device='meta', size=(10,), dtype=torch.uint8), fake_device='cuda:100')
            >>> tensor.to(torch.device('cuda:42'))
            MetaTensor(tensor(..., device='meta', size=(10,)), fake_device='cuda:42')
            >>> tensor.to('vulkan')
            MetaTensor(tensor(..., device='meta', size=(10,)), fake_device='vulkan')
        """
        # this imitates c++ function in the way of @overload
        device = None
        for arg in args:
            if isinstance(arg, str) or isinstance(arg, _device):
                device = arg
        if 'device' in kwargs:
            device = kwargs['device']
        result = super().to(*args, **kwargs)
        if device is not None:
            result = MetaTensor(result, fake_device=device)
        return result

    def cpu(self, *args, **kwargs):
        if self.device.type == 'cpu':
            return self.to(*args, **kwargs)
        return self.to(*args, device='cpu', **kwargs)

    def cuda(self, *args, **kwargs):
        if self.device.type == 'cuda':
            return self.to(*args, **kwargs)
        return self.to(*args, device='cuda', **kwargs)
```
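A sketch of the intended effect: because `__torch_dispatch__` moves plain tensors (e.g. module parameters) to `meta` on the fly, a whole forward pass can run without allocating activation memory. The torchvision model here is an assumption for illustration; any `nn.Module` should behave the same:

```python
# Sketch: run a real model's forward entirely on meta storage.
import torch
import torchvision.models as tm

from colossalai.fx.profiler import MetaTensor

model = tm.resnet18()                        # parameters stay on CPU
x = MetaTensor(torch.rand(8, 3, 224, 224))   # input wrapped as MetaTensor
y = model(x)                                 # dispatch runs every aten op on meta
print(y.shape)                               # torch.Size([8, 1000]), no real activations allocated
```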
colossalai/fx/proxy.py (new file, mode 100644)

```python
import operator

import torch
from torch.fx.proxy import Proxy, Attribute
from typing import List, Union, Any
from colossalai.fx.tracer.meta_patch import meta_patched_function

__all__ = ['ColoProxy']


class ColoProxy(Proxy):
    """
    ColoProxy is a proxy class which uses meta tensor to handle data-dependent control flow. The original
    torch.fx proxy cannot be used to infer the condition statement; with this proxy, torch.fx can still run
    even with if statements.

    Example::
        proxy = tracer.create_proxy(...)
        proxy.meta_data = torch.empty(4, 2, device='meta')
        print(len(proxy))    # expect output 4
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.node._meta_data = None

    @property
    def meta_data(self):
        return self.node._meta_data

    @meta_data.setter
    def meta_data(self, data: Any):
        self.node._meta_data = data

    @property
    def has_meta_data(self):
        return self._meta_data is not None

    def _assert_meta_data_is_tensor(self):
        assert torch.is_tensor(self._meta_data) and self._meta_data.is_meta, \
            f'Meta data is not a meta tensor for {self.node.name}'

    def _assert_has_meta_data(self):
        assert self._meta_data is not None, f'Meta data is not set for {self.node.name}'

    def __len__(self):
        self._assert_has_meta_data()
        return len(self.meta_data)

    def __int__(self):
        self._assert_has_meta_data()
        return int(self.meta_data)

    def __float__(self):
        self._assert_has_meta_data()
        return float(self.meta_data)

    def __bool__(self):
        self._assert_has_meta_data()
        return self.meta_data

    def __getattr__(self, k):
        return ColoAttribute(self, k)

    def __contains__(self, key):
        if self.node.op == "placeholder":
            # this is used to handle like
            # if x in kwargs
            # we don't handle this case for now
            return False
        return super().__contains__(key)


def extract_meta(*args, **kwargs):
    """
    This function is copied from _tracer_utils.py to avoid a circular import issue.
    """

    def _convert(val):
        if isinstance(val, ColoProxy):
            return val.meta_data
        elif isinstance(val, (list, tuple)):
            return type(val)([_convert(ele) for ele in val])
        return val

    new_args = [_convert(val) for val in args]
    new_kwargs = {k: _convert(v) for k, v in kwargs.items()}
    return new_args, new_kwargs


class ColoAttribute(ColoProxy):

    def __init__(self, root, attr: str):
        self.root = root
        self.attr = attr
        self.tracer = root.tracer
        self._node = None

    @property
    def node(self):
        if self._node is None:
            proxy = self.tracer.create_proxy("call_function", getattr, (self.root, self.attr), {})
            if not isinstance(proxy, ColoProxy):
                meta_args, meta_kwargs = extract_meta(*(self.root, self.attr))
                meta_out = getattr(*meta_args, **meta_kwargs)
                proxy = ColoProxy(proxy.node)
                proxy.meta_data = meta_out
            self._node = proxy.node
        return self._node

    def __call__(self, *args, **kwargs):
        proxy = self.tracer.create_proxy("call_method", self.attr, (self.root,) + args, kwargs)
        if not isinstance(proxy, ColoProxy):
            meta_args, meta_kwargs = extract_meta(*((self.root,) + args), **kwargs)
            method = getattr(meta_args[0].__class__, self.attr)
            if meta_patched_function.has(method):
                meta_target = meta_patched_function.get(method)
            elif meta_patched_function.has(method.__name__):
                meta_target = meta_patched_function.get(method.__name__)
            else:
                meta_target = method
            meta_out = meta_target(*meta_args, **meta_kwargs)
            proxy = ColoProxy(proxy.node)
            proxy.meta_data = meta_out
        return proxy
```
colossalai/fx/tracer/__init__.py (new file, mode 100644)

```python
from colossalai.fx.tracer.meta_patch.patched_function.python_ops import operator_getitem

from ._meta_trace import meta_trace
from ._symbolic_trace import symbolic_trace
from .tracer import ColoTracer
```
colossalai/fx/tracer/_meta_trace.py (new file, mode 100644)

```python
import torch
from torch.fx import Graph, Node
from torch.utils._pytree import tree_map


def normalize_tuple(x):
    if not isinstance(x, tuple):
        return (x,)
    return x


def is_autogradable(x):
    return isinstance(x, torch.Tensor) and x.is_floating_point()


def meta_trace(module: torch.nn.Module, fake_device=None, *args, **kwargs) -> Graph:
    """Trace forward and backward graph with MetaTensor

    Args:
        module (torch.nn.Module): The target module for tracing.

    Returns:
        graph (torch.fx.Graph): The computation graph.

    Usage:
        >>> import torchvision.models as tm
        >>> model = tm.alexnet()
        >>> graph = meta_trace(model, torch.rand(1000, 3, 224, 224))
        >>> graph.print_tabular()
    """
    graph = Graph()
    namespace = graph._graph_namespace

    class MetaProxy(torch.Tensor):
        """
        A wrapping tensor that hacks `torch.autograd` without patching more `torch.ops.aten` ops.
        """

        _tensor: torch.Tensor
        _node: Node

        __slots__ = ['_tensor', '_node']

        @staticmethod
        def __new__(cls, tensor, fake_device=None, placeholder=False, name=None):
            r = torch.Tensor._make_wrapper_subclass(
                cls,
                tensor.size(),
                strides=tensor.stride(),
                storage_offset=tensor.storage_offset(),
                dtype=tensor.dtype,
                layout=tensor.layout,
                device=fake_device if fake_device is not None else tensor.device,
                requires_grad=tensor.requires_grad)    # deceive the frontend for aten selections
            r._tensor = tensor
            if placeholder:
                if name is None:
                    name = 'input'
                r._node = graph.create_node('placeholder',
                                            'placeholder', (graph._root,),
                                            name=namespace.create_name(name, tensor))
            # ...the real tensor is held as an element on the tensor.
            if not r._tensor.is_meta:
                r._tensor = r._tensor.to(torch.device('meta'))
            return r

        @classmethod
        def __torch_dispatch__(cls, func, types, args=(), kwargs=None):

            def unwrap(x):
                nonlocal fake_device
                if isinstance(x, MetaProxy):
                    fake_device = x.device
                    x = x._tensor
                    # assert not isinstance(x, MetaProxy)
                elif isinstance(x, torch.Tensor):
                    fake_device = x.device
                    x = x.to(torch.device('meta'))
                return x

            def get_node(x):
                if isinstance(x, torch.Tensor) and not hasattr(x, '_node'):
                    x = MetaProxy(x, placeholder=True, name='weight')
                return x if not hasattr(x, '_node') else x._node

            args_node = tree_map(get_node, args)
            kwargs_node = tree_map(get_node, kwargs)
            node = graph.create_node('call_function', func, args_node, kwargs_node)

            if 'device' in kwargs:
                fake_device = kwargs['device']
                kwargs['device'] = torch.device('meta')

            args = tree_map(unwrap, args)
            kwargs = tree_map(unwrap, kwargs)

            # run aten for backend=CPU but actually on backend=Meta
            out = func(*args, **kwargs)

            # Now, we want to continue propagating this tensor, so we rewrap Tensors in
            # our custom tensor subclass
            def wrap(x):
                if isinstance(x, torch.Tensor):
                    nonlocal fake_device
                    if not x.is_meta:
                        x = x.to(torch.device('meta'))
                return MetaProxy(x, fake_device=fake_device) if isinstance(x, torch.Tensor) \
                    and not hasattr(x, '_tensor') else x

            def set_node(x):
                x._node = node

            out = tree_map(wrap, out)
            tree_map(set_node, out)

            return out

    def wrap(x):
        return MetaProxy(x, fake_device=fake_device, placeholder=True) if isinstance(x, torch.Tensor) else x

    args = tree_map(wrap, args)
    kwargs = tree_map(wrap, kwargs)

    out = module(*args, **kwargs)

    for tensor in normalize_tuple(out):
        if is_autogradable(tensor) and tensor.requires_grad:
            grad = torch.empty_like(tensor._tensor, device=torch.device('meta')) if isinstance(
                tensor, MetaProxy) else torch.empty_like(tensor, device=torch.device('meta'))
            torch.autograd.backward(tensor,
                                    MetaProxy(grad, fake_device=tensor.device, placeholder=True),
                                    retain_graph=True)
    return graph
```
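One caveat on the docstring's usage line: `fake_device` is the second positional parameter, so an example input passed directly after the module binds to `fake_device` rather than to `*args`. A call sketch that matches the signature as written (torchvision model assumed purely for illustration):

```python
import torch
import torchvision.models as tm

from colossalai.fx.tracer import meta_trace

model = tm.alexnet()
# Pass fake_device explicitly so the example input lands in *args.
graph = meta_trace(model, torch.device('meta'), torch.rand(8, 3, 224, 224, device='meta'))
graph.print_tabular()
```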
colossalai/fx/tracer/_symbolic_trace.py (new file, mode 100644)

```python
from typing import Any, Callable, Dict, Optional, Union

import torch

from colossalai.fx import ColoGraphModule
from colossalai.fx._compatibility import compatibility

from .tracer import ColoTracer


@compatibility(is_backward_compatible=True)
def symbolic_trace(
    root: Union[torch.nn.Module, Callable[..., Any]],
    concrete_args: Optional[Dict[str, Any]] = None,
    meta_args: Optional[Dict[str, Any]] = None,
) -> ColoGraphModule:
    """
    Symbolic tracing API

    Given an ``nn.Module`` or function instance ``root``, this function will return a ``ColoGraphModule``
    constructed by recording operations seen while tracing through ``root``.

    With ``meta_args``, we can trace models that are otherwise untraceable due to control flow. If specified
    using ``meta_args`` only, the tracing can be done ahead of time.

    Note that ``meta_args`` is a dict of kwargs, mapping argument names to argument values.

    Uses:
        >>> model = ...

        # if this works
        >>> gm = symbolic_trace(model, concrete_args=concrete_args)

        # else try this
        >>> gm = symbolic_trace(model, concrete_args=concrete_args, meta_args={'x': torch.rand(1, 3, 224, 224, device='meta')})

    Args:
        root (Union[torch.nn.Module, Callable[..., Any]]): Module or function to be traced and converted
            into a Graph representation.
        concrete_args (Optional[Dict[str, Any]], optional): Concrete arguments to be used for tracing.
        meta_args (Optional[Dict[str, Any]], optional): Inputs to be partially specialized, special for ``ColoTracer``.
            Defaults to None.

    Returns:
        ColoGraphModule: A ``ColoGraphModule`` created from the recorded operations from ``root``.

    Warnings:
        This API is still under development and may contain bugs. Feel free to report any bugs to the Colossal-AI team.
    """
    graph = ColoTracer().trace(root, concrete_args=concrete_args, meta_args=meta_args)
    name = root.__class__.__name__ if isinstance(root, torch.nn.Module) else root.__name__
    return ColoGraphModule(root, graph, name)
```
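A sketch of the control-flow case that `meta_args` is meant to unblock. The module below is made up for illustration; vanilla `torch.fx` would raise a trace error on the `if`, while the meta tensor lets the condition be evaluated concretely:

```python
import torch

from colossalai.fx.tracer import symbolic_trace


class Branchy(torch.nn.Module):

    def forward(self, x):
        # data-shape-dependent branch: resolved against the meta tensor
        if x.shape[0] > 1:
            return x.sum(dim=0)
        return x


gm = symbolic_trace(Branchy(), meta_args={'x': torch.rand(4, 16, device='meta')})
print(gm.code)    # the traced graph specializes on the taken branch
```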
colossalai/fx/tracer/_tracer_utils.py (new file, mode 100644)

```python
from typing import List, Union, Any

from ..proxy import ColoProxy, ColoAttribute
import torch
from .meta_patch import meta_patched_function, meta_patched_module

__all__ = ['is_element_in_list', 'extract_meta']


def is_element_in_list(elements: Union[List[Any], Any], list_: List[Any]):
    if isinstance(elements, (tuple, list, set)):
        for ele in elements:
            if ele not in list_:
                return False, ele
    else:
        if elements not in list_:
            return False, elements
    return True, None


def extract_meta(*args, **kwargs):

    def _convert(val):
        if isinstance(val, ColoProxy):
            return val.meta_data
        elif isinstance(val, (list, tuple)):
            return type(val)([_convert(ele) for ele in val])
        return val

    new_args = [_convert(val) for val in args]
    new_kwargs = {k: _convert(v) for k, v in kwargs.items()}
    return new_args, new_kwargs


def compute_meta_data_for_functions_proxy(target, args, kwargs):
    args_metas, kwargs_metas = extract_meta(*args, **kwargs)

    # fetch patched function
    if meta_patched_function.has(target):
        meta_target = meta_patched_function.get(target)
    elif meta_patched_function.has(target.__name__):
        meta_target = meta_patched_function.get(target.__name__)
    else:
        meta_target = target
    meta_out = meta_target(*args_metas, **kwargs_metas)
    if isinstance(meta_out, torch.Tensor):
        meta_out = meta_out.to(device="meta")

    return meta_out
```
colossalai/fx/tracer/bias_addition_patch/__init__.py (new file, mode 100644)

```python
from .patched_bias_addition_function import *
from .patched_bias_addition_module import *
```
colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/__init__.py (new file, mode 100644)

```python
from .addbmm import Addbmm
from .addmm import Addmm
from .bias_addition_function import BiasAdditionFunc, LinearBasedBiasFunc, func_to_func_dict, method_to_func_dict
from .linear import Linear
```
colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/addbmm.py (new file, mode 100644)

```python
import operator

import torch
import torch.nn.functional as F

from ...registry import bias_addition_function, bias_addition_method
from .bias_addition_function import LinearBasedBiasFunc


@bias_addition_method.register(torch.Tensor.addbmm)
@bias_addition_function.register(torch.addbmm)
class Addbmm(LinearBasedBiasFunc):

    def extract_kwargs_from_origin_func(self):
        kwargs = {}
        if 'beta' in self.kwargs:
            kwargs['beta'] = self.kwargs['beta']
        if 'alpha' in self.kwargs:
            kwargs['alpha'] = self.kwargs['alpha']
        return kwargs

    def create_non_bias_func_proxy(self, input_proxy, other_proxy):
        """
        This method is used to create the non_bias_func proxy, the node created by this proxy will
        compute the main computation, such as convolution, with bias option banned.
        """
        assert self.substitute_func == torch.bmm
        node_kind = 'call_function'
        node_target = self.substitute_func

        node_args = (input_proxy, other_proxy)
        # torch.bmm does not have any kwargs
        node_kwargs = {}
        non_bias_func_proxy = self.tracer.create_proxy(node_kind, node_target, node_args, node_kwargs)
        return non_bias_func_proxy

    def insert_sum_node(self, input_proxy, sum_dims=0):
        '''
        This method is used to sum the input_proxy through the sum_dims.
        '''
        node_kind = 'call_function'
        node_target = torch.sum
        node_args = (input_proxy, sum_dims)
        node_kwargs = {}
        sum_proxy = self.tracer.create_proxy(node_kind, node_target, node_args, node_kwargs)
        return sum_proxy

    def generate(self):
        # The formula for addbmm is output = beta * input + alpha * (torch.bmm(b1, b2))

        # doing the non-bias computation (temp_0 = torch.bmm(b1, b2))
        non_bias_linear_func_proxy = self.create_non_bias_func_proxy(self.args[1], self.args[2])

        # doing sum on the batch dimension (temp_1 = torch.sum(temp_0, 0))
        sum_proxy = self.insert_sum_node(non_bias_linear_func_proxy)
        kwargs = self.extract_kwargs_from_origin_func()

        if 'beta' in kwargs:
            beta = kwargs['beta']
            # doing the multiplication with beta if it exists (temp_2 = beta * input)
            beta_proxy = self.create_mul_node(self.args[0], beta)
        else:
            beta_proxy = self.args[0]

        if 'alpha' in kwargs:
            alpha = kwargs['alpha']
            # doing the multiplication with alpha if it exists (temp_3 = alpha * temp_1)
            alpha_proxy = self.create_mul_node(alpha, sum_proxy)
        else:
            alpha_proxy = sum_proxy

        # doing the addition (temp_4 = temp_2 + temp_3)
        bias_addition_proxy = self.create_bias_addition_proxy(alpha_proxy, beta_proxy)

        return bias_addition_proxy
```
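As a sanity check, the decomposition that `Addbmm.generate` emits can be verified numerically with plain eager tensors (shapes chosen arbitrarily for illustration):

```python
# Sketch verifying the restructured graph's identity:
# addbmm(input, b1, b2, beta=b, alpha=a) == b * input + a * sum(bmm(b1, b2), 0)
import torch

inp = torch.rand(3, 5)
b1, b2 = torch.rand(10, 3, 4), torch.rand(10, 4, 5)

ref = torch.addbmm(inp, b1, b2, beta=2.0, alpha=0.5)
restructured = 2.0 * inp + 0.5 * torch.sum(torch.bmm(b1, b2), 0)
assert torch.allclose(ref, restructured, atol=1e-6)
```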
colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/addmm.py (new file, mode 100644)

```python
import operator

import torch
import torch.nn.functional as F

from ...registry import bias_addition_function, bias_addition_method
from .bias_addition_function import LinearBasedBiasFunc


@bias_addition_method.register(torch.Tensor.addmm)
@bias_addition_function.register(torch.addmm)
class Addmm(LinearBasedBiasFunc):

    def extract_kwargs_from_origin_func(self):
        kwargs = {}
        if 'beta' in self.kwargs:
            kwargs['beta'] = self.kwargs['beta']
        if 'alpha' in self.kwargs:
            kwargs['alpha'] = self.kwargs['alpha']
        return kwargs

    def transpose_other_operand_for_linear(self, other_proxy):
        '''
        This method is used to transpose the other operand for the linear function.
        For example:
            input = torch.rand(3, 4)
            m1 = torch.rand(3, 5)
            m2 = torch.rand(5, 4)

            original_output = torch.addmm(input, m1, m2)
            # To keep the computation graph consistent with the origin computation graph, we need to transpose the m2
            # before we call the linear function.
            new_output = F.linear(m1, m2.transpose(0, 1)) + input
        '''
        node_kind = 'call_function'
        node_target = torch.transpose
        node_args = (other_proxy, 0, 1)
        node_kwargs = {}
        transpose_proxy = self.tracer.create_proxy(node_kind, node_target, node_args, node_kwargs)
        return transpose_proxy

    def generate(self):
        transpose_proxy = self.transpose_other_operand_for_linear(self.args[2])
        non_bias_linear_func_proxy = self.create_non_bias_func_proxy(self.args[1], transpose_proxy)
        kwargs = self.extract_kwargs_from_origin_func()

        if 'beta' in kwargs:
            beta = kwargs['beta']
            beta_proxy = self.create_mul_node(self.args[0], beta)
        else:
            beta_proxy = self.args[0]

        if 'alpha' in kwargs:
            alpha = kwargs['alpha']
            alpha_proxy = self.create_mul_node(alpha, non_bias_linear_func_proxy)
        else:
            alpha_proxy = non_bias_linear_func_proxy

        bias_addition_proxy = self.create_bias_addition_proxy(alpha_proxy, beta_proxy)

        return bias_addition_proxy
```
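The transpose trick from the docstring can likewise be checked in eager mode (arbitrary shapes):

```python
# Sketch verifying the identity behind `Addmm.generate`:
# addmm(input, m1, m2) == F.linear(m1, m2.transpose(0, 1)) + input
import torch
import torch.nn.functional as F

inp, m1, m2 = torch.rand(3, 4), torch.rand(3, 5), torch.rand(5, 4)

ref = torch.addmm(inp, m1, m2)
# F.linear(x, w) computes x @ w.T, so transposing m2 recovers m1 @ m2.
restructured = F.linear(m1, m2.transpose(0, 1)) + inp
assert torch.allclose(ref, restructured, atol=1e-6)
```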
colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/bias_addition_function.py (new file, mode 100644)

```python
import operator
from abc import ABC, abstractmethod

import torch
import torch.nn.functional as F


class BiasAdditionFunc(ABC):
    """
    This class is used to construct the restructured computation graph for
    call_func nodes with bias addition inside.
    """

    def __init__(self, tracer, target, args, kwargs, substitute_func):
        self.tracer = tracer
        self.target = target
        self.args = args
        self.kwargs = kwargs
        self.substitute_func = substitute_func

    @abstractmethod
    def extract_kwargs_from_origin_func(self):
        """
        This method is used to extract the kwargs for further graph transform.

        For example:
            The formula for torch.addmm is out = beta * input + alpha * (m1 @ m2)
            The kwargs for the addmm function are {beta=1, alpha=1, output=None}, so we need
            to insert two more operator.mul nodes for the computation graph to compute the
            final result.
        """
        pass

    @abstractmethod
    def generate(self):
        """
        This method is used to construct the whole restructured computation graph for a call_func node with bias
        addition inside.

        A whole restructured computation graph will contain a weight node, a bias node, a non-bias addition
        computation node, a bias reshape node if needed and a bias addition node.

        Use torch.addmm as an example:
        The origin node is:
            %addmm: call_func[target=torch.addmm](args = (%input_1, m1, m2), kwargs = {beta=1, alpha=1})
        Restructured graph is:
            %transpose : [#users=1] = call_function[target=torch.transpose](args = (%m2, 0, 1), kwargs = {})
            %linear : [#users=1] = call_function[target=torch._C._nn.linear](args = (%m1, %transpose), kwargs = {})
            %mul : [#users=1] = call_function[target=operator.mul](args = (%input_1, 3), kwargs = {})
            %mul_1 : [#users=1] = call_function[target=operator.mul](args = (2, %linear), kwargs = {})
            %add : [#users=1] = call_function[target=operator.add](args = (%mul_1, %mul), kwargs = {})
        """
        pass

    def create_mul_node(self, input_proxy, coefficent):
        """
        This method is used to create a coefficient node for numerical correctness.
        The formula for torch.addmm is out = beta * input + alpha * (m1 @ m2),
        so we need this method to insert two more operator.mul nodes for
        the computation graph to compute the final result.
        """
        node_kind = 'call_function'
        node_target = operator.mul
        node_args = (
            input_proxy,
            coefficent,
        )
        node_kwargs = {}
        mul_proxy = self.tracer.create_proxy(node_kind, node_target, node_args, node_kwargs)
        return mul_proxy


class LinearBasedBiasFunc(BiasAdditionFunc):
    """
    This class is used to construct the restructured computation graph for
    call_func nodes based on F.linear.
    """

    def create_non_bias_func_proxy(self, input_proxy, other_proxy):
        """
        This method is used to create the non_bias_func proxy, the node created by this proxy will
        compute the main computation, such as convolution, with bias option banned.
        """
        assert self.substitute_func == torch.nn.functional.linear
        node_kind = 'call_function'
        node_target = self.substitute_func

        node_args = (input_proxy, other_proxy)
        # non-bias linear does not have any kwargs
        node_kwargs = {}
        non_bias_func_proxy = self.tracer.create_proxy(node_kind, node_target, node_args, node_kwargs)
        return non_bias_func_proxy

    def create_bias_addition_proxy(self, non_bias_func_proxy, bias_proxy):
        """
        This method is used to create the bias_addition_proxy, the node created by this proxy will
        compute the sum of non_bias_func result and bias with some reshape operation if needed.
        """
        bias_add_node_kind = 'call_function'
        bias_add_node_target = operator.add
        bias_add_args = (non_bias_func_proxy, bias_proxy)
        bias_add_proxy = self.tracer.create_proxy(bias_add_node_kind, bias_add_node_target, tuple(bias_add_args), {})
        return bias_add_proxy


func_to_func_dict = {
    torch.addmm: F.linear,
    torch.addbmm: torch.bmm,
    F.linear: F.linear,
}

method_to_func_dict = {
    torch.Tensor.addmm: F.linear,
    torch.Tensor.addbmm: torch.bmm,
}
```
colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/linear.py (new file, mode 100644)

```python
import operator

import torch
import torch.nn.functional as F

from ...registry import bias_addition_function
from .bias_addition_function import LinearBasedBiasFunc


@bias_addition_function.register(F.linear)
class Linear(LinearBasedBiasFunc):

    def extract_kwargs_from_origin_func(self):
        assert 'bias' in self.kwargs
        kwargs = {}
        if 'bias' in self.kwargs:
            kwargs['bias'] = self.kwargs['bias']
        return kwargs

    def generate(self):
        non_bias_linear_func_proxy = self.create_non_bias_func_proxy(self.args[0], self.args[1])
        kwargs = self.extract_kwargs_from_origin_func()
        bias_addition_proxy = self.create_bias_addition_proxy(non_bias_linear_func_proxy, kwargs['bias'])
        return bias_addition_proxy
```
colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_module/__init__.py (new file, mode 100644)

```python
from .bias_addition_module import *
from .conv import *
from .linear import *
```
colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_module/bias_addition_module.py (new file, mode 100644)

```python
import operator
from abc import ABC, abstractmethod

import torch
import torch.nn.functional as F


class BiasAdditionModule(ABC):
    """
    This class is used to construct the restructured computation graph for
    call_module nodes with bias addition inside.
    """

    def __init__(self, tracer, target, args, kwargs, substitute_func):
        self.tracer = tracer
        self.target = target
        self.args = args
        self.kwargs = kwargs
        self.substitute_func = substitute_func
        self.weight_proxy = self._create_weight_proxy()
        self.bias_proxy = self._create_bias_proxy()

    def _create_weight_proxy(self):
        """
        Create weight proxy, the node created by this proxy contains module weight.

        Note: this function will be invoked during module initializing,
              you should never call this function.
        """
        weight_node_kind = 'get_attr'
        weight_node_target = self.target + '.weight'
        weight_proxy = self.tracer.create_proxy(weight_node_kind, weight_node_target, (), {})
        return weight_proxy

    def _create_bias_proxy(self):
        """
        Create bias proxy, the node created by this proxy contains module bias.

        Note: this function will be invoked during module initializing,
              you should never call this function.
        """
        bias_node_kind = 'get_attr'
        bias_node_target = self.target + '.bias'
        bias_proxy = self.tracer.create_proxy(bias_node_kind, bias_node_target, (), {})
        return bias_proxy

    @abstractmethod
    def extract_kwargs_from_mod(self):
        """
        This method is used to extract the kwargs for non-bias computation.

        For example:
            The kwargs for a conv2d module is {} because attributes like 'padding' or 'groups' are
            handled during module initializing. However, we need to pass those attributes as kwargs
            to F.conv2d.
        """
        pass

    def create_non_bias_func_proxy(self, input_proxy=None):
        """
        This method is used to create the non_bias_func proxy, the node created by this proxy will
        compute the main computation, such as convolution, with bias option banned.
        """
        node_kind = 'call_function'
        node_target = self.substitute_func
        if input_proxy is None:
            input_proxy = self.args[0]
        node_args = (input_proxy, self.weight_proxy)
        node_kwargs = self.extract_kwargs_from_mod()
        non_bias_func_proxy = self.tracer.create_proxy(node_kind, node_target, node_args, node_kwargs)
        return non_bias_func_proxy

    def create_bias_addition_proxy(self, non_bias_func_proxy, bias_proxy):
        """
        This method is used to create the bias_addition_proxy, the node created by this proxy will
        compute the sum of non_bias_func result and bias with some reshape operation if needed.
        """
        bias_add_node_kind = 'call_function'
        bias_add_node_target = operator.add
        bias_add_args = (non_bias_func_proxy, bias_proxy)
        bias_add_proxy = self.tracer.create_proxy(bias_add_node_kind, bias_add_node_target, tuple(bias_add_args), {})
        return bias_add_proxy

    @abstractmethod
    def generate(self):
        """
        This method is used to construct the whole restructured computation graph for a call_module node with bias
        addition inside.

        A whole restructured computation graph will contain a weight node, a bias node, a non-bias addition
        computation node, a bias reshape node if needed and a bias addition node.

        Use Conv2d module as an example:
        The origin node is:
            %conv: call_module[target=conv](args = (%x,), kwargs = {})
        Restructured graph is:
            %conv_weight : [#users=1] = get_attr[target=conv.weight]
            %conv_bias : [#users=1] = get_attr[target=conv.bias]
            %conv2d : [#users=1] = call_function[target=torch.conv2d](args = (%x, %conv_weight), kwargs = {})
            %view : [#users=1] = call_method[target=view](args = (%conv_bias, [1, -1, 1, 1]), kwargs = {})
            %add : [#users=1] = call_function[target=operator.add](args = (%conv2d, %view), kwargs = {})
        """
        pass


module_to_func_dict = {
    torch.nn.Linear: F.linear,
    torch.nn.Conv1d: F.conv1d,
    torch.nn.Conv2d: F.conv2d,
    torch.nn.Conv3d: F.conv3d,
}
```
colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_module/conv.py (new file, mode 100644)

```python
import torch
import torch.nn.functional as F
from torch.nn.modules.utils import _pair, _reverse_repeat_tuple, _single, _triple

from ...registry import bias_addition_module
from .bias_addition_module import BiasAdditionModule


@bias_addition_module.register(torch.nn.Conv1d)
@bias_addition_module.register(torch.nn.Conv2d)
@bias_addition_module.register(torch.nn.Conv3d)
class BiasAdditionConv(BiasAdditionModule):

    def extract_kwargs_from_mod(self):
        root = self.tracer.root
        conv_module = root.get_submodule(self.target)
        kwarg_attributes = ['groups', 'dilation', 'stride']
        non_bias_kwargs = {}
        for attr_name in kwarg_attributes:
            if hasattr(conv_module, attr_name):
                non_bias_kwargs[attr_name] = getattr(conv_module, attr_name)
        if conv_module.padding_mode != "zeros":
            # TODO: non-zeros mode requires some extra processing for input
            conv_type = type(conv_module)
            if conv_type == torch.nn.Conv1d:
                padding_element = _single(0)
            elif conv_type == torch.nn.Conv2d:
                padding_element = _pair(0)
            elif conv_type == torch.nn.Conv3d:
                padding_element = _triple(0)
            non_bias_kwargs['padding'] = padding_element
        else:
            non_bias_kwargs['padding'] = getattr(conv_module, 'padding')

        return non_bias_kwargs

    def create_bias_reshape_proxy(self, dimensions):
        """
        This method is used to reshape the bias node in order to make bias and
        output of non-bias convolution broadcastable.
        """
        bias_shape = [1] * (dimensions - 1)
        bias_shape[0] = -1
        bias_reshape_node_kind = 'call_method'
        bias_reshape_node_target = 'view'
        bias_reshape_node_args = (self.bias_proxy, torch.Size(bias_shape))
        bias_reshape_proxy = self.tracer.create_proxy(bias_reshape_node_kind, bias_reshape_node_target,
                                                      bias_reshape_node_args, {})
        return bias_reshape_proxy

    def generate(self):
        non_bias_conv_func_proxy = self.create_non_bias_func_proxy()
        output_dims = non_bias_conv_func_proxy.meta_data.dim()
        bias_reshape_proxy = self.create_bias_reshape_proxy(output_dims)
        bias_addition_proxy = self.create_bias_addition_proxy(non_bias_conv_func_proxy, bias_reshape_proxy)
        return bias_addition_proxy
```
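A quick eager-mode sketch of the identity behind `create_bias_reshape_proxy` for a 4-D convolution output (layer sizes are arbitrary): the bias of shape `(C,)` is viewed as `(-1, 1, 1)` so that adding it to the unbiased convolution broadcasts over batch and spatial dimensions.

```python
import torch
import torch.nn.functional as F

conv = torch.nn.Conv2d(3, 8, 3)
x = torch.rand(2, 3, 16, 16)

ref = conv(x)
# Same convolution with the bias banned, as in create_non_bias_func_proxy.
unbiased = F.conv2d(x, conv.weight, stride=conv.stride, padding=conv.padding,
                    dilation=conv.dilation, groups=conv.groups)
# output.dim() == 4, so bias_shape == [-1, 1, 1].
restructured = unbiased + conv.bias.view(torch.Size([-1, 1, 1]))
assert torch.allclose(ref, restructured, atol=1e-6)
```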
colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_module/linear.py (new file, mode 100644)

```python
import torch
import torch.nn.functional as F

from ...registry import bias_addition_module
from .bias_addition_module import BiasAdditionModule


@bias_addition_module.register(torch.nn.Linear)
class BiasAdditionLinear(BiasAdditionModule):

    def extract_kwargs_from_mod(self):
        return {}

    def generate(self):
        non_bias_linear_func_proxy = self.create_non_bias_func_proxy()
        bias_addition_proxy = self.create_bias_addition_proxy(non_bias_linear_func_proxy, self.bias_proxy)
        return bias_addition_proxy
```
colossalai/fx/tracer/experimental.py
0 → 100644
View file @
08f2920e
import
enum
import
functools
import
inspect
from
typing
import
Any
,
Callable
,
Dict
,
List
,
Optional
,
Set
,
Tuple
,
Union
import
torch
from
torch.fx
import
Graph
,
Node
,
Proxy
,
Tracer
from
torch.utils._pytree
import
tree_map
from
colossalai.fx
import
ColoGraphModule
,
compatibility
,
is_compatible_with_meta
if
is_compatible_with_meta
():
from
colossalai.fx.profiler
import
MetaTensor
Target
=
Union
[
Callable
[...,
Any
],
str
]
Argument
=
Optional
[
Union
[
Tuple
[
Any
,
...],
# actually Argument, but mypy can't represent recursive types
List
[
Any
],
# actually Argument
Dict
[
str
,
Any
],
# actually Argument
slice
,
# Slice[Argument, Argument, Argument], but slice is not a templated type in typing
'Node'
,]]
_CScriptMethod
=
[
'add'
,
'mul'
,
'sub'
,
'div'
]
_TorchNewMethod
=
[
"arange"
,
"zeros"
,
"zeros_like"
,
"ones"
,
"ones_like"
,
"full"
,
"full_like"
,
"empty"
,
"empty_like"
,
"eye"
,
"tensor"
,
"finfo"
]
_TensorPropertyMethod
=
[
"dtype"
,
"shape"
,
"device"
,
"requires_grad"
,
"grad"
,
"grad_fn"
,
"data"
]
def
_truncate_suffix
(
s
:
str
):
import
re
return
re
.
sub
(
r
'_\d+$'
,
''
,
s
)
def
is_element_in_list
(
elements
:
Union
[
List
[
Any
],
Any
],
list_
:
List
[
Any
]):
if
isinstance
(
elements
,
(
tuple
,
list
,
set
)):
for
ele
in
elements
:
if
ele
not
in
list_
:
return
False
,
ele
else
:
if
elements
not
in
list_
:
return
False
,
elements
return
True
,
None
def
default_device
():
return
torch
.
device
(
'cuda:0'
)
if
torch
.
cuda
.
is_available
()
else
torch
.
device
(
'cpu'
)
@compatibility(is_backward_compatible=False)
class ColoProxy(Proxy):

    def __init__(self, *args, data=None, **kwargs):
        super().__init__(*args, **kwargs)
        self._data = data

    @property
    def data(self):
        return self._data

    @data.setter
    def data(self, args):
        wrap_fn = lambda x: MetaTensor(x) if isinstance(x, torch.Tensor) else x
        self._data = tree_map(wrap_fn, args)

    @classmethod
    def __torch_function__(cls, orig_method, types, args=(), kwargs=None):
        proxy = cls.from_torch_proxy(super().__torch_function__(orig_method, types, args, kwargs))
        unwrap_fn = lambda p: p.data if isinstance(p, ColoProxy) else p
        kwargs = {} if kwargs is None else kwargs
        if proxy.data is None:
            proxy.data = orig_method(*tree_map(unwrap_fn, args), **tree_map(unwrap_fn, kwargs))
        return proxy

    @classmethod
    def from_torch_proxy(cls, proxy: Proxy):
        return cls(proxy.node, proxy.tracer)

    def __repr__(self):
        return f"ColoProxy({self.node.name}, data={self.data})"

    def __len__(self):
        return len(self.data)

    def __int__(self):
        return int(self.data)

    def __index__(self):
        try:
            return int(self.data)
        except:
            return torch.zeros(self.data.shape, dtype=torch.bool).numpy().__index__()

    def __float__(self):
        return float(self.data)

    def __bool__(self):
        return self.data

    def __getattr__(self, k):
        return ColoAttribute(self, k, getattr(self._data, k, None))

    def __contains__(self, key):
        if self.node.op == "placeholder":
            # this is used to handle like
            # if x in kwargs
            # we don't handle this case for now
            return False
        return super().__contains__(key)

    def __isinstancecheck__(self, type):
        return isinstance(self.data, type)

    @property
    def shape(self):
        return self.data.shape

    @property
    def ndim(self):
        return self.data.ndim

    @property
    def device(self):
        proxy = self.tracer.create_proxy('call_function', getattr, (self, 'device'), {})
        proxy.data = self.data.device
        return proxy

    @property
    def dtype(self):
        proxy = self.tracer.create_proxy('call_function', getattr, (self, 'dtype'), {})
        proxy.data = self.data.dtype
        return proxy

    def to(self, *args, **kwargs):
        return self.tracer.create_proxy('call_method', 'to', (self, *args), {**kwargs})

    def cpu(self, *args, **kwargs):
        return self.tracer.create_proxy('call_method', 'cpu', (self, *args), {**kwargs})

    def cuda(self, *args, **kwargs):
        return self.tracer.create_proxy('call_method', 'cuda', (self, *args), {**kwargs})
@compatibility(is_backward_compatible=False)
class ColoAttribute(ColoProxy):

    def __init__(self, root, attr: str, data=None):
        self.root = root
        self.attr = attr
        self.tracer = root.tracer
        self._data = data
        self._node: Optional[Node] = None

    @property
    def node(self):
        # the node for attributes is added lazily, since most will just be method calls
        # which do not rely on the getitem call
        if self._node is None:
            self._node = self.tracer.create_proxy('call_function', getattr, (self.root, self.attr), {}).node
        return self._node

    def __call__(self, *args, **kwargs):
        return self.tracer.create_proxy('call_method', self.attr, (self.root,) + args, kwargs)

    def __repr__(self):
        return f"ColoAttribute({self.node.name}, attr={self.attr})"
@compatibility(is_backward_compatible=False)
class ColoTracer(Tracer):

    def __init__(self, trace_act_ckpt: bool = False, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._disable_module_getattr = False
        self.proxy_buffer_attributes = True

    def proxy(self, node: Node) -> 'ColoProxy':
        return ColoProxy(node, self)

    def create_proxy(self,
                     kind: str,
                     target: Target,
                     args: Tuple[Any, ...],
                     kwargs: Dict[str, Any],
                     name: Optional[str] = None,
                     type_expr: Optional[Any] = None,
                     proxy_factory_fn: Callable[[Node], 'Proxy'] = None):
        proxy: ColoProxy = super().create_proxy(kind, target, args, kwargs, name, type_expr, proxy_factory_fn)
        unwrap_fn = lambda p: p.data if isinstance(p, ColoProxy) else p
        if kind == 'placeholder':
            proxy.data = self.meta_args[target] if target in self.meta_args else self.concrete_args.get(
                _truncate_suffix(target), None)
        elif kind == 'get_attr':
            self._disable_module_getattr = True
            try:
                attr_itr = self.root
                atoms = target.split(".")
                for atom in atoms:
                    attr_itr = getattr(attr_itr, atom)
                proxy.data = attr_itr
            finally:
                self._disable_module_getattr = False
        elif kind == 'call_function':
            proxy.data = target(*tree_map(unwrap_fn, args), **tree_map(unwrap_fn, kwargs))
        elif kind == 'call_method':
            self._disable_module_getattr = True
            try:
                if target == '__call__':
                    proxy.data = unwrap_fn(args[0])(*tree_map(unwrap_fn, args[1:]), **tree_map(unwrap_fn, kwargs))
                else:
                    if target not in _TensorPropertyMethod:
                        proxy._data = getattr(unwrap_fn(args[0]), target)(*tree_map(unwrap_fn, args[1:]),
                                                                          **tree_map(unwrap_fn, kwargs))
            finally:
                self._disable_module_getattr = False
        elif kind == 'call_module':
            mod = self.root.get_submodule(target)
            unwrap_fn = lambda p: p.data if isinstance(p, ColoProxy) else p
            self._disable_module_getattr = True
            try:
                proxy.data = mod.forward(*tree_map(unwrap_fn, args), **tree_map(unwrap_fn, kwargs))
            finally:
                # re-enable module getattr once the submodule's forward has run
                self._disable_module_getattr = False
        return proxy
    def trace(self,
              root: torch.nn.Module,
              concrete_args: Optional[Dict[str, torch.Tensor]] = None,
              meta_args: Optional[Dict[str, torch.Tensor]] = None) -> Graph:
        if meta_args is None:
            meta_args = {}

        if concrete_args is None:
            concrete_args = {}

        # check concrete and meta args have valid names
        sig = inspect.signature(root.forward)
        sig_names = set(sig.parameters.keys())
        meta_arg_names = set(meta_args.keys())

        # update concrete args with default values
        non_meta_arg_names = sig_names - meta_arg_names
        for k, v in sig.parameters.items():
            if k in non_meta_arg_names and \
                    k not in concrete_args and \
                    v.default is not inspect.Parameter.empty:
                concrete_args[k] = v.default

        # get non concrete arg names
        concrete_arg_names = set(concrete_args.keys())
        non_concrete_arg_names = sig_names - concrete_arg_names

        def _check_arg_name_valid(names):
            success, element = is_element_in_list(names, sig_names)
            if not success:
                raise KeyError(
                    f"argument {element} is not found in the signature of {root.__class__.__name__}'s forward function"
                )

        _check_arg_name_valid(meta_arg_names)
        _check_arg_name_valid(concrete_arg_names)

        self.concrete_args = concrete_args
        self.meta_args = meta_args

        with _TorchTensorOverride(self):
            self.graph = super().trace(root, concrete_args=concrete_args)
        self.graph.lint()
        return self.graph
    def _post_check(self, non_concrete_arg_names: Set[str]):
        # This is necessary because concrete args are added as input to the traced module since
        # https://github.com/pytorch/pytorch/pull/55888.
        for node in self.graph.nodes:
            if node.op == "placeholder":
                # Removing default values for inputs as the forward pass will fail with them.
                if node.target in non_concrete_arg_names:
                    node.args = ()
                    # Without this, torch.jit.script fails because the inputs type is Optional[torch.Tensor].
                    # It cannot infer on the attributes and methods the input should have, and fails.
                    node.type = torch.Tensor
                # It is a concrete arg so it is not used and should be removed.
                else:
                    if hasattr(torch.fx._symbolic_trace, "_assert_is_none"):
                        # Newer versions of torch.fx emit an assert statement
                        # for concrete arguments; delete those before we delete
                        # the concrete arg.
                        to_delete = []
                        for user in node.users:
                            if user.target == torch.fx._symbolic_trace._assert_is_none:
                                to_delete.append(user)
                        for user in to_delete:
                            self.graph.erase_node(user)
                    self.graph.erase_node(node)
            # TODO: solves GraphModule creation.
            # Without this, return type annotation "Tuple" is causing code execution failure.
            if node.op == "output":
                node.type = None
        self.graph.lint()
    def _module_getattr(self, attr, attr_val, parameter_proxy_cache):
        if getattr(self, "_disable_module_getattr", False):
            return attr_val

        def maybe_get_proxy_for_attr(attr_val, collection_to_search, parameter_proxy_cache):
            for n, p in collection_to_search:
                if attr_val is p:
                    if n not in parameter_proxy_cache:
                        kwargs = {}
                        if 'proxy_factory_fn' in inspect.signature(self.create_proxy).parameters:
                            kwargs['proxy_factory_fn'] = (None if not self.param_shapes_constant else
                                                          lambda node: ColoProxy(self, node, n, attr_val))
                        val_proxy = self.create_proxy('get_attr', n, (), {}, **kwargs)    # type: ignore[arg-type]
                        parameter_proxy_cache[n] = val_proxy
                    return parameter_proxy_cache[n]
            return None

        if self.proxy_buffer_attributes and isinstance(attr_val, torch.Tensor):
            maybe_buffer_proxy = maybe_get_proxy_for_attr(attr_val, self.root.named_buffers(), parameter_proxy_cache)
            if maybe_buffer_proxy is not None:
                return maybe_buffer_proxy

        if isinstance(attr_val, torch.nn.Parameter):
            maybe_parameter_proxy = maybe_get_proxy_for_attr(attr_val, self.root.named_parameters(),
                                                             parameter_proxy_cache)
            if maybe_parameter_proxy is not None:
                return maybe_parameter_proxy

        return attr_val
@compatibility(is_backward_compatible=True)
def symbolic_trace(
    root: Union[torch.nn.Module, Callable[..., Any]],
    concrete_args: Optional[Dict[str, Any]] = None,
    meta_args: Optional[Dict[str, Any]] = None,
) -> ColoGraphModule:
    if is_compatible_with_meta():
        if meta_args is not None:
            root.to(default_device())
            wrap_fn = lambda x: MetaTensor(x, fake_device=default_device()) if isinstance(x, torch.Tensor) else x
            graph = ColoTracer().trace(root, concrete_args=concrete_args, meta_args=tree_map(wrap_fn, meta_args))
            root.cpu()
        else:
            graph = Tracer().trace(root, concrete_args=concrete_args)
    else:
        from .tracer import ColoTracer as OrigColoTracer
        graph = OrigColoTracer().trace(root, concrete_args=concrete_args, meta_args=meta_args)
    name = root.__class__.__name__ if isinstance(root, torch.nn.Module) else root.__name__
    return ColoGraphModule(root, graph, name)
@compatibility(is_backward_compatible=False)
class _TorchTensorOverride(object):

    def __init__(self, tracer: Tracer):
        self.overrides = {}
        self.tracer = tracer

    def __enter__(self):

        def wrap_tensor_method(target):

            @functools.wraps(target)
            def wrapper(*args, **kwargs):
                is_proxy = any(isinstance(p, ColoProxy) for p in args) | any(
                    isinstance(p, ColoProxy) for p in kwargs.values())
                if is_proxy:
                    # if the arg is a proxy, then need to record this function called on this proxy
                    # e.g. torch.ones(size) where size is an input proxy
                    self.tracer._disable_module_getattr = True
                    try:
                        proxy = self.tracer.create_proxy('call_function', target, args, kwargs)
                    finally:
                        self.tracer._disable_module_getattr = False
                    return proxy
                else:
                    return target(*args, **kwargs)

            return wrapper, target

        self.overrides = {
            target: wrap_tensor_method(getattr(torch, target))
            for target in _TorchNewMethod
            if callable(getattr(torch, target))
        }
        for name, (wrapper, orig) in self.overrides.items():
            setattr(torch, name, wrapper)

    def __exit__(self, exc_type, exc_val, exc_tb):
        for name, (wrapper, orig) in self.overrides.items():
            setattr(torch, name, orig)
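A usage sketch for the tracer above (assumed, not part of the commit: the import path follows this file's location, and the MLP module is illustrative). Passing meta_args lets each ColoProxy carry a MetaTensor in its .data slot, so shapes propagate through the trace without allocating real activations:

import torch
from colossalai.fx.tracer.experimental import symbolic_trace

class MLP(torch.nn.Module):

    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(8, 16)

    def forward(self, x):
        return self.fc(x).relu()

gm = symbolic_trace(MLP(), meta_args={'x': torch.empty(4, 8)})
gm.graph.print_tabular()    # placeholder -> call_module fc -> call_method relu -> output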
colossalai/fx/tracer/meta_patch/__init__.py
0 → 100644
View file @ 08f2920e
from .patched_function import *
from .patched_module import *