OpenDAS / ColossalAI · Commits

Commit e532679c
authored Jan 10, 2023 by oahzxl

Merge branch 'main' of https://github.com/oahzxl/ColossalAI into chunk

parents: c1492e50, 7d5640b9

Changes: 461 changed files in this merge; this page shows 20 changed files with 1426 additions and 205 deletions (+1426, -205).
Files shown on this page:

colossalai/fx/profiler/memory_utils.py (+71, -0)
colossalai/fx/profiler/opcount.py (+123, -117)
colossalai/fx/profiler/profiler.py (+16, -12)
colossalai/fx/profiler/shard_utils.py (+30, -65)
colossalai/fx/profiler/tensor.py (+16, -5)
colossalai/fx/tracer/__init__.py (+5, -2)
colossalai/fx/tracer/_meta_trace.py (+1, -3)
colossalai/fx/tracer/_symbolic_trace.py (+54, -0)
colossalai/fx/tracer/bias_addition_patch/__init__.py (+2, -0)
colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/__init__.py (+4, -0)
colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/addbmm.py (+75, -0)
colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/addmm.py (+60, -0)
colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/bias_addition_function.py (+115, -0)
colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/linear.py (+25, -0)
colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_module/__init__.py (+3, -0)
colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_module/bias_addition_module.py (+111, -0)
colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_module/conv.py (+56, -0)
colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_module/linear.py (+17, -0)
colossalai/fx/tracer/experimental.py (+642, -0)
colossalai/fx/tracer/meta_patch/__init__.py (+0, -1)

Too many changes to show. To preserve performance only 461 of 461+ files are displayed.
colossalai/fx/profiler/memory_utils.py (new file)

from typing import Dict, List, Tuple, Union

import torch
from torch.fx import GraphModule, Node

from .._compatibility import compatibility, is_compatible_with_meta

__all__ = ['activation_size', 'parameter_size', 'is_inplace']


@compatibility(is_backward_compatible=True)
def activation_size(out: Union[torch.Tensor, Dict, List, Tuple, int]) -> int:
    """Calculate activation size of a node.

    Args:
        out (Union[torch.Tensor, Dict, List, Tuple, int]): The activation of a `torch.nn.Module` or `torch.nn.functional`.

    Returns:
        int: The activation size, unit is byte.
    """
    act_size = 0
    if isinstance(out, torch.Tensor):
        if out.is_quantized:
            act_size += out.numel() * torch._empty_affine_quantized([], dtype=out.dtype).element_size()
        else:
            act_size += out.numel() * torch.tensor([], dtype=out.dtype).element_size()
    elif isinstance(out, dict):
        value_list = [v for _, v in out.items()]
        act_size += activation_size(value_list)
    elif isinstance(out, tuple) or isinstance(out, list) or isinstance(out, set):
        for element in out:
            act_size += activation_size(element)
    return act_size


@compatibility(is_backward_compatible=True)
def parameter_size(mod: torch.nn.Module) -> int:
    """Calculate parameter size of a node.

    Args:
        mod (torch.nn.Module): The target `torch.nn.Module`.

    Returns:
        int: The parameter size, unit is byte.
    """
    param_size = 0
    for param in mod.parameters():
        param_size += param.numel() * torch.tensor([], dtype=param.dtype).element_size()
    return param_size


def is_inplace(n: Node):
    """Get the inplace argument from a torch.fx.Node.

    Args:
        n (Node): torch.fx.Node

    Returns:
        bool: indicates whether this op is inplace
    """
    inplace = False
    if n.op == "call_function":
        inplace = n.kwargs.get("inplace", False)
        if is_compatible_with_meta():
            from .constants import ALIAS_ATEN
            if n.target in ALIAS_ATEN:
                inplace = True
    elif n.op == "call_module":
        inplace = getattr(n.graph.owning_module.get_submodule(n.target), "inplace", False)

    return inplace
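A quick, illustrative check of the byte-counting helpers above (not part of this commit; the module and shapes are arbitrary):

import torch
from colossalai.fx.profiler.memory_utils import activation_size, parameter_size

mod = torch.nn.Linear(128, 256)
out = mod(torch.randn(4, 128))

# 4 * 256 float32 activations -> 1024 elements * 4 bytes
print(activation_size(out))     # 4096
# (256 * 128 + 256) float32 parameters -> 33024 elements * 4 bytes
print(parameter_size(mod))      # 132096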
colossalai/fx/profiler/opcount.py

@@ -7,6 +7,7 @@ from numbers import Number
 from typing import Any, Callable, List

 import torch
+from packaging import version

 aten = torch.ops.aten
...
@@ -32,7 +33,7 @@ def addmm_flop_jit(inputs: List[Any], outputs: List[Any]) -> Number:
     # inputs is a list of length 3.
     input_shapes = [v.shape for v in inputs[1:3]]
     # input_shapes[0]: [batch size, input feature dimension]
-    # input_shapes[1]: [batch size, output feature dimension]
+    # input_shapes[1]: [input feature dimension, output feature dimension]
     assert len(input_shapes[0]) == 2, input_shapes[0]
     assert len(input_shapes[1]) == 2, input_shapes[1]
     batch_size, input_dim = input_shapes[0]
...
@@ -188,7 +189,8 @@ def zero_flop_jit(*args):
     return 0

-flop_mapping = {
+if version.parse(torch.__version__) >= version.parse('1.12.0'):
+    flop_mapping = {
     # gemm
     aten.mm.default: matmul_flop_jit,
     aten.matmul.default: matmul_flop_jit,
...
@@ -228,9 +230,9 @@ flop_mapping = {
     aten._adaptive_avg_pool3d_backward.default: elementwise_flop_counter(0, 1),
     aten.embedding_dense_backward.default: elementwise_flop_counter(0, 1),
     aten.embedding.default: elementwise_flop_counter(1, 0),
-}
+    }

-elementwise_flop_aten = [
+    elementwise_flop_aten = [
     # basic op
     aten.add.Tensor,
     aten.add_.Tensor,
...
@@ -275,13 +277,12 @@ elementwise_flop_aten = [
     # dropout
     aten.native_dropout.default,
     aten.native_dropout_backward.default,
-]
+    ]

-for op in elementwise_flop_aten:
-    flop_mapping[op] = elementwise_flop_counter(1, 0)
+    for op in elementwise_flop_aten:
+        flop_mapping[op] = elementwise_flop_counter(1, 0)

-# TODO: this will be removed in future
-zero_flop_aten = [
+    # TODO: this will be removed in future
+    zero_flop_aten = [
     aten.as_strided.default,
     aten.as_strided_.default,
     aten.bernoulli_.float,
...
@@ -312,7 +313,12 @@ zero_flop_aten = [
     aten.where.self,
     aten.zero_.default,
     aten.zeros_like.default,
-]
+    ]

-for op in zero_flop_aten:
-    flop_mapping[op] = zero_flop_jit
+    for op in zero_flop_aten:
+        flop_mapping[op] = zero_flop_jit
+else:
+    flop_mapping = {}
+    elementwise_flop_aten = {}
+    zero_flop_aten = {}
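A hedged sketch (not part of this commit) of how the version-guarded table above is typically consumed: flop_mapping maps aten overload packets to counter functions, and on PyTorch < 1.12 it is left empty, so a missing key should be treated as "no count available".

import torch
from colossalai.fx.profiler.opcount import flop_mapping

aten = torch.ops.aten
counter = flop_mapping.get(aten.mm.default)
if counter is not None:
    # counters follow the (inputs, outputs) convention of the *_flop_jit helpers in opcount.py
    print("aten.mm.default is covered by the FLOP table")
else:
    print("no FLOP counter registered (e.g. PyTorch < 1.12)")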
colossalai/fx/profiler/profiler.py

@@ -11,7 +11,7 @@ from torch.utils._pytree import tree_map
 from .._compatibility import compatibility
 from .constants import ALIAS_ATEN, OUTPUT_SAVED_MOD, OUTPUT_SAVED_OPS
 from .dataflow import GraphInfo, Phase, autograd_graph_analysis, is_phase
-from .memory import activation_size, parameter_size
+from .memory_utils import activation_size, parameter_size
 from .opcount import flop_mapping
 from .tensor import MetaTensor
...
@@ -232,12 +232,12 @@ def _profile_meta(target: Callable, *args, **kwargs) -> Tuple[Tuple[Any, ...], G
     def pack(x):
         global cache, do_not_cache
-        if isinstance(x, FlopTensor) and not x._tensor.uuid in cache:
+        if isinstance(x, FlopTensor) and not x._tensor.data_ptr() in cache:
             tensor = x._tensor.detach()
-            tensor.uuid = x._tensor.uuid
+            tensor.data_ptr = x._tensor.data_ptr
             x._node.meta['saved_tensor'] += [tensor]
             if not do_not_cache:
-                cache.add(x._tensor.uuid)
+                cache.add(x._tensor.data_ptr())
         return x

     def unpack(x):
...
@@ -270,7 +270,7 @@ def _profile_meta(target: Callable, *args, **kwargs) -> Tuple[Tuple[Any, ...], G
     def extract_tensor(x: Any):
         if isinstance(x, MetaTensor):
             tensor = x._tensor.detach()
-            tensor.uuid = x._tensor.uuid
+            tensor.data_ptr = x._tensor.data_ptr
             return tensor
         if not isinstance(x, torch.finfo):
             return x
...
@@ -328,6 +328,8 @@ def profile_function(target: 'Target', device: str = 'meta') -> Callable:
             out, meta = _profile_concrete(func, *args, **kwargs)
             if inplace:
                 kwargs['inplace'] = True
+                meta.bwd_mem_tmp = 0
+                meta.bwd_mem_out = 0
             do_not_cache = False

         meta.bwd_mem_out -= param_size
...
@@ -394,6 +396,8 @@ def profile_module(module: torch.nn.Module, device: str = 'meta') -> Callable:
             out, meta = _profile_concrete(func, *args, **kwargs)
             if inplace:
                 module.inplace = True
+                meta.bwd_mem_tmp = 0
+                meta.bwd_mem_out = 0
             do_not_cache = False

         # grad for param will not be counted
...
colossalai/fx/profiler/memory.py → colossalai/fx/profiler/shard_utils.py (renamed, with changes)

-from typing import Dict, List, Tuple, Union
-
 import torch
-from torch.fx import GraphModule, Node
+from torch.fx import Node

 from .._compatibility import compatibility, is_compatible_with_meta
+from .memory_utils import activation_size

 if is_compatible_with_meta():
     from .constants import OUTPUT_SAVED_MOD, OUTPUT_SAVED_OPS

-__all__ = ['activation_size', 'parameter_size', 'is_inplace', "calculate_fwd_in", "calculate_fwd_tmp", "calculate_fwd_out"]
-
-
-@compatibility(is_backward_compatible=True)
-def activation_size(out: Union[torch.Tensor, Dict, List, Tuple, int]) -> int:
-    """Calculate activation size of a node.
-
-    Args:
-        activation (Union[torch.Tensor, Dict, List, Tuple, int]): The activation of a `torch.nn.Module` or `torch.nn.functional`
-
-    Returns:
-        int: The activation size
-    """
-    act_size = 0
-    if isinstance(out, torch.Tensor):
-        act_size += out.numel() * torch.tensor([], dtype=out.dtype).element_size()
-    elif isinstance(out, dict):
-        value_list = [v for _, v in out.items()]
-        act_size += activation_size(value_list)
-    elif isinstance(out, tuple) or isinstance(out, list) or isinstance(out, set):
-        for element in out:
-            act_size += activation_size(element)
-    return act_size
-
-
-@compatibility(is_backward_compatible=True)
-def parameter_size(mod: torch.nn.Module) -> int:
-    """Calculate parameter size of a node.
-
-    Args:
-        mod (torch.nn.Module): The target `torch.nn.Module`
-
-    Returns:
-        int: The parameter size
-    """
-    param_size = 0
-    for param in mod.parameters():
-        param_size += param.numel() * torch.tensor([], dtype=param.dtype).element_size()
-    return param_size
+__all__ = ["calculate_fwd_in", "calculate_fwd_tmp", "calculate_fwd_out"]


 @compatibility(is_backward_compatible=False)
 def calculate_fwd_in(n: Node) -> int:
-    """A helper function to calculate `fwd_in`
+    """A helper function to calculate `fwd_in` (with sharding spec)

     Args:
         n (Node): a node from the graph
...
@@ -60,11 +20,13 @@ def calculate_fwd_in(n: Node) -> int:
     Returns:
         fwd_in (int): the result of `fwd_in`
     """
+    # TODO(super-dainiu): should divide the memory by sharding spec
     return activation_size(n.meta["fwd_in"])


 @compatibility(is_backward_compatible=False)
 def calculate_fwd_tmp(n: Node) -> int:
-    """A helper function to calculate `fwd_tmp`
+    """A helper function to calculate `fwd_tmp` (with sharding spec)
     Currently, `torch.nn.ReLU` behaves weirdly, so we have to patch it for accuracy.

     Args:
...
@@ -74,6 +36,7 @@ def calculate_fwd_tmp(n: Node) -> int:
         fwd_tmp (int): the result of `fwd_tmp`
     """
+    # TODO(super-dainiu): should divide the memory by sharding spec

     def is_relu_like_node(n: Node) -> bool:
         """Check if a node is a ReLU-like node.
         ReLU-like nodes have the following properties:
...
@@ -107,8 +70,9 @@ def calculate_fwd_tmp(n: Node) -> int:
     return 0


 @compatibility(is_backward_compatible=False)
 def calculate_fwd_out(n: Node) -> int:
-    """A helper function to calculate `fwd_out`
+    """A helper function to calculate `fwd_out` (with sharding spec)

     Args:
         n (Node): a node from the graph
...
@@ -117,33 +81,34 @@ def calculate_fwd_out(n: Node) -> int:
         fwd_out (int): the result of `fwd_out`
     """
+    # TODO(super-dainiu): should divide the memory by sharding spec

     def intersect(a, b):
         return {k: a[k] for k in a if k in b}

     fwd_in = dict()
     for u in n.users:
-        fwd_in.update({x.uuid: x for x in u.meta["fwd_in"] if isinstance(x, torch.Tensor) and hasattr(x, 'uuid')})
-    fwd_out = {x.uuid: x for x in n.meta["fwd_out"] if isinstance(x, torch.Tensor) and hasattr(x, 'uuid')}
+        fwd_in.update({x.data_ptr(): x for x in u.meta["fwd_in"] if isinstance(x, torch.Tensor)})
+    fwd_out = {x.data_ptr(): x for x in n.meta["fwd_out"] if isinstance(x, torch.Tensor)}
     return activation_size(intersect(fwd_in, fwd_out))


-def is_inplace(n: Node):
-    """Get the inplace argument from torch.fx.Node
+def calculate_fwd_time(n: Node) -> float:
+    """A helper function to calculate `fwd_time` (with sharding spec)

     Args:
-        node (Node): torch.fx.Node
+        n (Node): a node from the graph

     Returns:
-        bool: indicates whether this op is inplace
+        fwd_time (float): the result of `fwd_time`
     """
-    inplace = False
-    if n.op == "call_function":
-        inplace = n.kwargs.get("inplace", False)
-        if is_compatible_with_meta():
-            from .constants import ALIAS_ATEN
-            if n.target in ALIAS_ATEN:
-                inplace = True
-    elif n.op == "call_module":
-        inplace = getattr(n.graph.owning_module.get_submodule(n.target), "inplace", False)
-
-    return inplace
+    # TODO(super-dainiu): should divide the time by the number of GPUs as well as TFLOPs
+    return n.meta["fwd_time"]
+
+
+def calculate_bwd_time(n: Node) -> float:
+    """A helper function to calculate `bwd_time` (with sharding spec)
+
+    Args:
+        n (Node): a node from the graph
+
+    Returns:
+        bwd_time (float): the result of `bwd_time`
+    """
+    # TODO(super-dainiu): should divide the time by the number of GPUs as well as TFLOPs
+    return n.meta["bwd_time"]
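A minimal sketch (not part of this commit) of the bookkeeping calculate_fwd_out now performs: a node's fwd_out only counts tensors that at least one user also reports in its fwd_in, keyed by data_ptr() after this change (previously by a synthetic uuid attribute).

def intersect(a, b):
    return {k: a[k] for k in a if k in b}

# hypothetical pointer-keyed dictionaries, as built inside calculate_fwd_out
fwd_in = {0x1: "t1", 0x2: "t2"}
fwd_out = {0x2: "t2", 0x3: "t3"}
print(intersect(fwd_in, fwd_out))    # {2: 't2'} -> only t2 would be counted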
colossalai/fx/profiler/tensor.py

@@ -12,10 +12,11 @@ from .constants import ALIAS_ATEN
 __all__ = ['MetaTensor']


-def set_uuid(x):
+def set_data_ptr(x):
     if isinstance(x, torch.Tensor):
-        if not hasattr(x, 'uuid'):
-            setattr(x, 'uuid', uuid.uuid4())
+        if not x.data_ptr():
+            data_ptr = uuid.uuid4()
+            x.data_ptr = lambda: data_ptr


 @compatibility(is_backward_compatible=False)
...
@@ -53,7 +54,7 @@ class MetaTensor(torch.Tensor):
         if not r._tensor.is_meta:
             r._tensor = r._tensor.to(torch.device('meta'))
         # only tensor not on `meta` should be copied to `meta`
-        set_uuid(r._tensor)
+        set_data_ptr(r._tensor)
         return r

     def __repr__(self):
...
@@ -88,7 +89,7 @@ class MetaTensor(torch.Tensor):
         # here we keep the uuid of input because ALIAS_ATEN do not generate a physical copy
         # of the input
         if func in ALIAS_ATEN:
-            setattr(out, 'uuid', args[0].uuid)
+            out.data_ptr = args[0].data_ptr

         # Now, we want to continue propagating this tensor, so we rewrap Tensors in
         # our custom tensor subclass
...
@@ -127,3 +128,13 @@
         if device is not None:
             result = MetaTensor(result, fake_device=device)
         return result
+
+    def cpu(self, *args, **kwargs):
+        if self.device.type == 'cpu':
+            return self.to(*args, **kwargs)
+        return self.to(*args, device='cpu', **kwargs)
+
+    def cuda(self, *args, **kwargs):
+        if self.device.type == 'cuda':
+            return self.to(*args, **kwargs)
+        return self.to(*args, device='cuda', **kwargs)
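An illustrative note (not part of this commit) on why set_data_ptr exists: tensors on the 'meta' device report data_ptr() == 0, so the profiler attaches a fake but unique pointer to use as an identity key, replacing the previous ad-hoc uuid attribute.

import uuid
import torch

t = torch.empty(2, 2, device='meta')
print(t.data_ptr())    # 0 for meta tensors

if not t.data_ptr():
    fake_ptr = uuid.uuid4()
    t.data_ptr = lambda: fake_ptr    # the same trick set_data_ptr uses above
print(t.data_ptr())    # now a stable, unique value for this tensor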
colossalai/fx/tracer/__init__.py

-from .tracer import ColoTracer
 from colossalai.fx.tracer.meta_patch.patched_function.python_ops import operator_getitem

 from ._meta_trace import meta_trace
+from ._symbolic_trace import symbolic_trace
+from .tracer import ColoTracer
colossalai/fx/tracer/_meta_trace.py

-from colossalai.fx.profiler.memory import activation_size
 import torch
-from torch.fx import Node, Graph
+from torch.fx import Graph, Node
 from torch.fx.graph import _Namespace
 from torch.utils._pytree import tree_map
...
colossalai/fx/tracer/_symbolic_trace.py (new file)

from typing import Any, Callable, Dict, Optional, Union

import torch

from colossalai.fx import ColoGraphModule
from colossalai.fx._compatibility import compatibility

from .tracer import ColoTracer


@compatibility(is_backward_compatible=True)
def symbolic_trace(
    root: Union[torch.nn.Module, Callable[..., Any]],
    concrete_args: Optional[Dict[str, Any]] = None,
    meta_args: Optional[Dict[str, Any]] = None,
) -> ColoGraphModule:
    """
    Symbolic tracing API

    Given an ``nn.Module`` or function instance ``root``, this function will return a ``ColoGraphModule``
    constructed by recording operations seen while tracing through ``root``.

    With ``meta_args``, we can trace models that are otherwise untraceable because of control flow. If only
    ``meta_args`` is specified, the tracing can be done ahead of time.

    Note that ``meta_args`` are kwargs, which contain the keys of the argument names and the values of the
    argument values.

    Uses:
        >>> model = ...

        # if this works
        >>> gm = symbolic_trace(model, concrete_args=concrete_args)

        # else try this
        >>> gm = symbolic_trace(model, concrete_args=concrete_args, meta_args={'x': torch.rand(1, 3, 224, 224, device='meta')})

    Args:
        root (Union[torch.nn.Module, Callable[..., Any]]): Module or function to be traced and converted
            into a Graph representation.
        concrete_args (Optional[Dict[str, Any]], optional): Concrete arguments to be used for tracing.
        meta_args (Optional[Dict[str, Any]], optional): Inputs to be partially specialized, special for ``ColoTracer``.
            Defaults to None.

    Returns:
        ColoGraphModule: A ``ColoGraphModule`` created from the recorded operations from ``root``.

    Warnings:
        This API is still under development and can incur some bugs. Feel free to report any bugs to the Colossal-AI team.
    """
    graph = ColoTracer().trace(root, concrete_args=concrete_args, meta_args=meta_args)
    name = root.__class__.__name__ if isinstance(root, torch.nn.Module) else root.__name__
    return ColoGraphModule(root, graph, name)
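A hedged usage sketch for the new symbolic_trace wrapper, mirroring the docstring above (the toy module and input shape are illustrative only):

import torch
from colossalai.fx.tracer import symbolic_trace

class TinyNet(torch.nn.Module):

    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(16, 4)

    def forward(self, x):
        return self.linear(x).relu()

# tracing ahead of time with meta inputs avoids allocating real activation memory
gm = symbolic_trace(TinyNet(), meta_args={'x': torch.rand(2, 16, device='meta')})
print(gm.graph)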
colossalai/fx/tracer/bias_addition_patch/__init__.py (new file)

from .patched_bias_addition_function import *
from .patched_bias_addition_module import *
colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/__init__.py (new file)

from .addbmm import Addbmm
from .addmm import Addmm
from .bias_addition_function import BiasAdditionFunc, LinearBasedBiasFunc, func_to_func_dict, method_to_func_dict
from .linear import Linear
colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/addbmm.py (new file)

import operator

import torch
import torch.nn.functional as F

from ...registry import bias_addition_function, bias_addition_method
from .bias_addition_function import LinearBasedBiasFunc


@bias_addition_method.register(torch.Tensor.addbmm)
@bias_addition_function.register(torch.addbmm)
class Addbmm(LinearBasedBiasFunc):

    def extract_kwargs_from_origin_func(self):
        kwargs = {}
        if 'beta' in self.kwargs:
            kwargs['beta'] = self.kwargs['beta']
        if 'alpha' in self.kwargs:
            kwargs['alpha'] = self.kwargs['alpha']
        return kwargs

    def create_non_bias_func_proxy(self, input_proxy, other_proxy):
        """
        This method is used to create the non_bias_func proxy, the node created by this proxy will
        compute the main computation, such as convolution, with bias option banned.
        """
        assert self.substitute_func == torch.bmm
        node_kind = 'call_function'
        node_target = self.substitute_func

        node_args = (input_proxy, other_proxy)
        # torch.bmm does not have any kwargs
        node_kwargs = {}
        non_bias_func_proxy = self.tracer.create_proxy(node_kind, node_target, node_args, node_kwargs)
        return non_bias_func_proxy

    def insert_sum_node(self, input_proxy, sum_dims=0):
        '''
        This method is used to sum the input_proxy through the sum_dims.
        '''
        node_kind = 'call_function'
        node_target = torch.sum
        node_args = (input_proxy, sum_dims)
        node_kwargs = {}
        sum_proxy = self.tracer.create_proxy(node_kind, node_target, node_args, node_kwargs)
        return sum_proxy

    def generate(self):
        # The formula for addbmm is output = beta * input + alpha * (torch.bmm(b1, b2))

        # doing the non-bias computation (temp_0 = torch.bmm(b1, b2))
        non_bias_linear_func_proxy = self.create_non_bias_func_proxy(self.args[1], self.args[2])

        # doing sum on the batch dimension (temp_1 = torch.sum(temp_0, 0))
        sum_proxy = self.insert_sum_node(non_bias_linear_func_proxy)
        kwargs = self.extract_kwargs_from_origin_func()

        if 'beta' in kwargs:
            beta = kwargs['beta']
            # doing the multiplication with beta if it exists (temp_2 = beta * input)
            beta_proxy = self.create_mul_node(self.args[0], beta)
        else:
            beta_proxy = self.args[0]

        if 'alpha' in kwargs:
            alpha = kwargs['alpha']
            # doing the multiplication with alpha if it exists (temp_3 = alpha * temp_1)
            alpha_proxy = self.create_mul_node(alpha, sum_proxy)
        else:
            alpha_proxy = sum_proxy

        # doing the addition (temp_4 = temp_2 + temp_3)
        bias_addition_proxy = self.create_bias_addition_proxy(alpha_proxy, beta_proxy)

        return bias_addition_proxy
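A small numerical check (not part of this commit) of the decomposition Addbmm.generate() builds above, with arbitrary shapes: addbmm(input, b1, b2, beta, alpha) == beta * input + alpha * torch.bmm(b1, b2).sum(0).

import torch

inp = torch.rand(3, 4)
b1 = torch.rand(10, 3, 5)
b2 = torch.rand(10, 5, 4)
beta, alpha = 3, 2

original = torch.addbmm(inp, b1, b2, beta=beta, alpha=alpha)
restructured = beta * inp + alpha * torch.bmm(b1, b2).sum(0)
assert torch.allclose(original, restructured, atol=1e-5)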
colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/addmm.py (new file)

import operator

import torch
import torch.nn.functional as F

from ...registry import bias_addition_function, bias_addition_method
from .bias_addition_function import LinearBasedBiasFunc


@bias_addition_method.register(torch.Tensor.addmm)
@bias_addition_function.register(torch.addmm)
class Addmm(LinearBasedBiasFunc):

    def extract_kwargs_from_origin_func(self):
        kwargs = {}
        if 'beta' in self.kwargs:
            kwargs['beta'] = self.kwargs['beta']
        if 'alpha' in self.kwargs:
            kwargs['alpha'] = self.kwargs['alpha']
        return kwargs

    def transpose_other_operand_for_linear(self, other_proxy):
        '''
        This method is used to transpose the other operand for the linear function.
        For example:
            input = torch.rand(3, 4)
            m1 = torch.rand(3, 5)
            m2 = torch.rand(5, 4)
            original_output = torch.addmm(input, m1, m2)
            # To keep the computation graph consistent with the origin computation graph, we need to transpose the m2
            # before we call the linear function.
            new_output = torch.linear(m1, m2.transpose(0, 1)) + input
        '''
        node_kind = 'call_function'
        node_target = torch.transpose
        node_args = (other_proxy, 0, 1)
        node_kwargs = {}
        transpose_proxy = self.tracer.create_proxy(node_kind, node_target, node_args, node_kwargs)
        return transpose_proxy

    def generate(self):
        transpose_proxy = self.transpose_other_operand_for_linear(self.args[2])
        non_bias_linear_func_proxy = self.create_non_bias_func_proxy(self.args[1], transpose_proxy)
        kwargs = self.extract_kwargs_from_origin_func()

        if 'beta' in kwargs:
            beta = kwargs['beta']
            beta_proxy = self.create_mul_node(self.args[0], beta)
        else:
            beta_proxy = self.args[0]

        if 'alpha' in kwargs:
            alpha = kwargs['alpha']
            alpha_proxy = self.create_mul_node(alpha, non_bias_linear_func_proxy)
        else:
            alpha_proxy = non_bias_linear_func_proxy

        bias_addition_proxy = self.create_bias_addition_proxy(alpha_proxy, beta_proxy)
        return bias_addition_proxy
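A small numerical check (not part of this commit) of the identity Addmm.generate() relies on, with arbitrary shapes: addmm(input, m1, m2, beta, alpha) == beta * input + alpha * F.linear(m1, m2.transpose(0, 1)).

import torch
import torch.nn.functional as F

inp = torch.rand(3, 4)
m1 = torch.rand(3, 5)
m2 = torch.rand(5, 4)
beta, alpha = 3, 2

original = torch.addmm(inp, m1, m2, beta=beta, alpha=alpha)
# F.linear(x, w) computes x @ w.T, so passing m2.transpose(0, 1) recovers m1 @ m2
restructured = beta * inp + alpha * F.linear(m1, m2.transpose(0, 1))
assert torch.allclose(original, restructured, atol=1e-5)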
colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/bias_addition_function.py (new file)

import operator
from abc import ABC, abstractmethod

import torch
import torch.nn.functional as F


class BiasAdditionFunc(ABC):
    """
    This class is used to construct the restructured computation graph for
    a call_func node with bias addition inside.
    """

    def __init__(self, tracer, target, args, kwargs, substitute_func):
        self.tracer = tracer
        self.target = target
        self.args = args
        self.kwargs = kwargs
        self.substitute_func = substitute_func

    @abstractmethod
    def extract_kwargs_from_origin_func(self):
        """
        This method is used to extract the kwargs for further graph transform.

        For example:
            The formula for torch.addmm is out = beta * input + alpha * (m1 @ m2)
            The kwargs for the addmm function are {beta=1, alpha=1, output=None}, so we need
            to insert two more operator.mul nodes for the computation graph to compute the
            final result.
        """
        pass

    @abstractmethod
    def generate(self):
        """
        This method is used to construct the whole restructured computation graph for a call_func node with bias
        addition inside.

        A whole restructured computation graph will contain a weight node, a bias node, a non-bias addition computation node,
        a bias reshape node if needed and a bias addition node.

        Use torch.addmm as an example:
        The origin node is:
            %addmm: call_func[target=torch.addmm](args = (%input_1, m1, m2), kwargs = {beta=1, alpha=1})
        Restructured graph is:
            %transpose : [#users=1] = call_function[target=torch.transpose](args = (%m2, 0, 1), kwargs = {})
            %linear : [#users=1] = call_function[target=torch._C._nn.linear](args = (%m1, %transpose), kwargs = {})
            %mul : [#users=1] = call_function[target=operator.mul](args = (%input_1, 3), kwargs = {})
            %mul_1 : [#users=1] = call_function[target=operator.mul](args = (2, %linear), kwargs = {})
            %add : [#users=1] = call_function[target=operator.add](args = (%mul_1, %mul), kwargs = {})
        """
        pass

    def create_mul_node(self, input_proxy, coefficient):
        """
        This method is used to create a coefficient node for numerical correctness.
        The formula for torch.addmm is out = beta * input + alpha * (m1 @ m2),
        therefore we need this method to insert two more operator.mul nodes for
        the computation graph to compute the final result.
        """
        node_kind = 'call_function'
        node_target = operator.mul
        node_args = (
            input_proxy,
            coefficient,
        )
        node_kwargs = {}
        mul_proxy = self.tracer.create_proxy(node_kind, node_target, node_args, node_kwargs)
        return mul_proxy


class LinearBasedBiasFunc(BiasAdditionFunc):
    """
    This class is used to construct the restructured computation graph for
    call_func nodes based on F.linear.
    """

    def create_non_bias_func_proxy(self, input_proxy, other_proxy):
        """
        This method is used to create the non_bias_func proxy, the node created by this proxy will
        compute the main computation, such as convolution, with bias option banned.
        """
        assert self.substitute_func == torch.nn.functional.linear
        node_kind = 'call_function'
        node_target = self.substitute_func

        node_args = (input_proxy, other_proxy)
        # non-bias linear does not have any kwargs
        node_kwargs = {}
        non_bias_func_proxy = self.tracer.create_proxy(node_kind, node_target, node_args, node_kwargs)
        return non_bias_func_proxy

    def create_bias_addition_proxy(self, non_bias_func_proxy, bias_proxy):
        """
        This method is used to create the bias_addition_proxy, the node created by this proxy will
        compute the sum of the non_bias_func result and bias with some reshape operation if needed.
        """
        bias_add_node_kind = 'call_function'
        bias_add_node_target = operator.add
        bias_add_args = (non_bias_func_proxy, bias_proxy)
        bias_add_proxy = self.tracer.create_proxy(bias_add_node_kind, bias_add_node_target, tuple(bias_add_args), {})
        return bias_add_proxy


func_to_func_dict = {
    torch.addmm: F.linear,
    torch.addbmm: torch.bmm,
    F.linear: F.linear,
}

method_to_func_dict = {
    torch.Tensor.addmm: F.linear,
    torch.Tensor.addbmm: torch.bmm,
}
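An illustrative lookup (not part of this commit) of the substitution tables defined above: each bias-carrying call is rewritten in terms of its bias-free substitute.

import torch
import torch.nn.functional as F
from colossalai.fx.tracer.bias_addition_patch import func_to_func_dict, method_to_func_dict

print(func_to_func_dict[torch.addmm] is F.linear)              # True
print(func_to_func_dict[torch.addbmm] is torch.bmm)            # True
print(method_to_func_dict[torch.Tensor.addmm] is F.linear)     # True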
colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/linear.py (new file)

import operator

import torch
import torch.nn.functional as F

from ...registry import bias_addition_function
from .bias_addition_function import LinearBasedBiasFunc


@bias_addition_function.register(F.linear)
class Linear(LinearBasedBiasFunc):

    def extract_kwargs_from_origin_func(self):
        assert 'bias' in self.kwargs
        kwargs = {}
        if 'bias' in self.kwargs:
            kwargs['bias'] = self.kwargs['bias']
        return kwargs

    def generate(self):
        non_bias_linear_func_proxy = self.create_non_bias_func_proxy(self.args[0], self.args[1])
        kwargs = self.extract_kwargs_from_origin_func()
        bias_addition_proxy = self.create_bias_addition_proxy(non_bias_linear_func_proxy, kwargs['bias'])
        return bias_addition_proxy
colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_module/__init__.py (new file)

from .bias_addition_module import *
from .conv import *
from .linear import *
colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_module/bias_addition_module.py (new file)

import operator
from abc import ABC, abstractmethod

import torch
import torch.nn.functional as F


class BiasAdditionModule(ABC):
    """
    This class is used to construct the restructured computation graph for
    a call_module node with bias addition inside.
    """

    def __init__(self, tracer, target, args, kwargs, substitute_func):
        self.tracer = tracer
        self.target = target
        self.args = args
        self.kwargs = kwargs
        self.substitute_func = substitute_func
        self.weight_proxy = self._create_weight_proxy()
        self.bias_proxy = self._create_bias_proxy()

    def _create_weight_proxy(self):
        """
        Create weight proxy, the node created by this proxy contains the module weight.

        Note: this function will be invoked during module initializing,
              you should never call this function.
        """
        weight_node_kind = 'get_attr'
        weight_node_target = self.target + '.weight'
        weight_proxy = self.tracer.create_proxy(weight_node_kind, weight_node_target, (), {})
        return weight_proxy

    def _create_bias_proxy(self):
        """
        Create bias proxy, the node created by this proxy contains the module bias.

        Note: this function will be invoked during module initializing,
              you should never call this function.
        """
        bias_node_kind = 'get_attr'
        bias_node_target = self.target + '.bias'
        bias_proxy = self.tracer.create_proxy(bias_node_kind, bias_node_target, (), {})
        return bias_proxy

    @abstractmethod
    def extract_kwargs_from_mod(self):
        """
        This method is used to extract the kwargs for non-bias computation.

        For example:
            The kwargs for a conv2d module are {} because attributes like 'padding' or 'groups' are
            considered during module initializing. However, we need to consider those attributes as kwargs
            in F.conv2d.
        """
        pass

    def create_non_bias_func_proxy(self, input_proxy=None):
        """
        This method is used to create the non_bias_func proxy, the node created by this proxy will
        compute the main computation, such as convolution, with bias option banned.
        """
        node_kind = 'call_function'
        node_target = self.substitute_func
        if input_proxy is None:
            input_proxy = self.args[0]
        node_args = (input_proxy, self.weight_proxy)
        node_kwargs = self.extract_kwargs_from_mod()
        non_bias_func_proxy = self.tracer.create_proxy(node_kind, node_target, node_args, node_kwargs)
        return non_bias_func_proxy

    def create_bias_addition_proxy(self, non_bias_func_proxy, bias_proxy):
        """
        This method is used to create the bias_addition_proxy, the node created by this proxy will
        compute the sum of the non_bias_func result and bias with some reshape operation if needed.
        """
        bias_add_node_kind = 'call_function'
        bias_add_node_target = operator.add
        bias_add_args = (non_bias_func_proxy, bias_proxy)
        bias_add_proxy = self.tracer.create_proxy(bias_add_node_kind, bias_add_node_target, tuple(bias_add_args), {})
        return bias_add_proxy

    @abstractmethod
    def generate(self):
        """
        This method is used to construct the whole restructured computation graph for a call_module node with bias
        addition inside.

        A whole restructured computation graph will contain a weight node, a bias node, a non-bias addition computation node,
        a bias reshape node if needed and a bias addition node.

        Use the Conv2d module as an example:
        The origin node is:
            %conv: call_module[target=conv](args = (%x,), kwargs = {})
        Restructured graph is:
            %conv_weight : [#users=1] = get_attr[target=conv.weight]
            %conv_bias : [#users=1] = get_attr[target=conv.bias]
            %conv2d : [#users=1] = call_function[target=torch.conv2d](args = (%x, %conv_weight), kwargs = {})
            %view : [#users=1] = call_method[target=view](args = (%conv_bias, [1, -1, 1, 1]), kwargs = {})
            %add : [#users=1] = call_function[target=operator.add](args = (%conv2d, %view), kwargs = {})
        """
        pass


module_to_func_dict = {
    torch.nn.Linear: F.linear,
    torch.nn.Conv1d: F.conv1d,
    torch.nn.Conv2d: F.conv2d,
    torch.nn.Conv3d: F.conv3d,
}
colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_module/conv.py (new file)

import torch
import torch.nn.functional as F
from torch.nn.modules.utils import _pair, _reverse_repeat_tuple, _single, _triple

from ...registry import bias_addition_module
from .bias_addition_module import BiasAdditionModule


@bias_addition_module.register(torch.nn.Conv1d)
@bias_addition_module.register(torch.nn.Conv2d)
@bias_addition_module.register(torch.nn.Conv3d)
class BiasAdditionConv(BiasAdditionModule):

    def extract_kwargs_from_mod(self):
        root = self.tracer.root
        conv_module = root.get_submodule(self.target)
        kwarg_attributes = ['groups', 'dilation', 'stride']
        non_bias_kwargs = {}
        for attr_name in kwarg_attributes:
            if hasattr(conv_module, attr_name):
                non_bias_kwargs[attr_name] = getattr(conv_module, attr_name)
        if conv_module.padding_mode != "zeros":
            # TODO: non-zeros mode requires some extra processing for input
            conv_type = type(conv_module)
            if conv_type == torch.nn.Conv1d:
                padding_element = _single(0)
            elif conv_type == torch.nn.Conv2d:
                padding_element = _pair(0)
            elif conv_type == torch.nn.Conv3d:
                padding_element = _triple(0)
            non_bias_kwargs['padding'] = padding_element
        else:
            non_bias_kwargs['padding'] = getattr(conv_module, 'padding')

        return non_bias_kwargs

    def create_bias_reshape_proxy(self, dimensions):
        """
        This method is used to reshape the bias node in order to make the bias and
        the output of the non-bias convolution broadcastable.
        """
        bias_shape = [1] * (dimensions - 1)
        bias_shape[0] = -1
        bias_reshape_node_kind = 'call_method'
        bias_reshape_node_target = 'view'
        bias_reshape_node_args = (self.bias_proxy, torch.Size(bias_shape))
        bias_reshape_proxy = self.tracer.create_proxy(bias_reshape_node_kind, bias_reshape_node_target,
                                                      bias_reshape_node_args, {})
        return bias_reshape_proxy

    def generate(self):
        non_bias_conv_func_proxy = self.create_non_bias_func_proxy()
        output_dims = non_bias_conv_func_proxy.meta_data.dim()
        bias_reshape_proxy = self.create_bias_reshape_proxy(output_dims)
        bias_addition_proxy = self.create_bias_addition_proxy(non_bias_conv_func_proxy, bias_reshape_proxy)
        return bias_addition_proxy
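A hedged numerical sketch (not part of this commit) of the restructuring BiasAdditionConv.generate() describes: a Conv2d with bias equals the bias-free convolution plus the bias reshaped for broadcasting.

import torch
import torch.nn.functional as F

conv = torch.nn.Conv2d(3, 8, kernel_size=3, padding=1)
x = torch.rand(2, 3, 16, 16)

with_bias = conv(x)
no_bias = F.conv2d(x, conv.weight, bias=None, stride=conv.stride,
                   padding=conv.padding, dilation=conv.dilation, groups=conv.groups)
rebuilt = no_bias + conv.bias.view(1, -1, 1, 1)    # the reshape the restructured graph inserts
assert torch.allclose(with_bias, rebuilt, atol=1e-6)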
colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_module/linear.py (new file)

import torch
import torch.nn.functional as F

from ...registry import bias_addition_module
from .bias_addition_module import BiasAdditionModule


@bias_addition_module.register(torch.nn.Linear)
class BiasAdditionLinear(BiasAdditionModule):

    def extract_kwargs_from_mod(self):
        return {}

    def generate(self):
        non_bias_linear_func_proxy = self.create_non_bias_func_proxy()
        bias_addition_proxy = self.create_bias_addition_proxy(non_bias_linear_func_proxy, self.bias_proxy)
        return bias_addition_proxy
colossalai/fx/tracer/experimental.py (new file)

import enum
import functools
import operator
import inspect
from contextlib import contextmanager
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union

import torch
from torch.fx import Graph, Node, Proxy, Tracer
from torch.utils._pytree import tree_map

from colossalai.fx import ColoGraphModule, compatibility, is_compatible_with_meta
from colossalai.fx.tracer._tracer_utils import extract_meta, is_element_in_list
from colossalai.fx.tracer.bias_addition_patch import func_to_func_dict, method_to_func_dict, module_to_func_dict
from colossalai.fx.tracer.registry import (
    bias_addition_function,
    bias_addition_method,
    bias_addition_module,
    meta_patched_function,
    meta_patched_module,
)

if is_compatible_with_meta():
    from colossalai.fx.profiler import MetaTensor

Target = Union[Callable[..., Any], str]
Argument = Optional[Union[
    Tuple[Any, ...],    # actually Argument, but mypy can't represent recursive types
    List[Any],    # actually Argument
    Dict[str, Any],    # actually Argument
    slice,    # Slice[Argument, Argument, Argument], but slice is not a templated type in typing
    'Node',
]]
_CScriptMethod = ['add', 'mul', 'sub', 'div']
_TorchNewMethod = [
    "arange", "zeros", "zeros_like", "ones", "ones_like", "full", "full_like", "empty", "empty_like", "eye",
    "tensor", "finfo"
]
_TensorPropertyMethod = ["dtype", "shape", "device", "requires_grad", "grad", "grad_fn", "data"]


def _truncate_suffix(s: str):
    import re
    return re.sub(r'_\d+$', '', s)


def default_device():
    return torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
@compatibility(is_backward_compatible=False)
class ColoProxy(Proxy):

    def __init__(self, *args, data=None, **kwargs):
        super().__init__(*args, **kwargs)
        self._meta_data = data

    @property
    def meta_data(self):
        return self._meta_data

    @meta_data.setter
    def meta_data(self, args):
        wrap_fn = lambda x: MetaTensor(x) if isinstance(x, torch.Tensor) else x
        self._meta_data = tree_map(wrap_fn, args)

    @classmethod
    def __torch_function__(cls, orig_method, types, args=(), kwargs=None):
        proxy = cls.from_torch_proxy(super().__torch_function__(orig_method, types, args, kwargs))
        unwrap_fn = lambda p: p.meta_data if isinstance(p, ColoProxy) else p
        kwargs = {} if kwargs is None else kwargs
        if proxy.meta_data is None:
            proxy.meta_data = orig_method(*tree_map(unwrap_fn, args), **tree_map(unwrap_fn, kwargs))
        return proxy

    @classmethod
    def from_torch_proxy(cls, proxy: Proxy):
        return cls(proxy.node, proxy.tracer)

    def __repr__(self):
        return f"ColoProxy({self.node.name}, meta_data={self.meta_data})"

    def __len__(self):
        return len(self.meta_data)

    def __int__(self):
        return int(self.meta_data)

    def __index__(self):
        try:
            return int(self.meta_data)
        except:
            return torch.zeros(self.meta_data.shape, dtype=torch.bool).numpy().__index__()

    def __float__(self):
        return float(self.meta_data)

    def __bool__(self):
        return self.meta_data

    def __getattr__(self, k):
        return ColoAttribute(self, k, getattr(self._meta_data, k, None))

    def __setitem__(self, key, value):
        proxy = self.tracer.create_proxy('call_function', operator.setitem, (self, key, value), {})
        proxy.meta_data = self._meta_data
        return proxy

    def __contains__(self, key):
        if self.node.op == "placeholder":
            # this is used to handle like
            # if x in kwargs
            # we don't handle this case for now
            return False
        return super().__contains__(key)

    def __isinstancecheck__(self, type):
        return isinstance(self.meta_data, type)

    @property
    def shape(self):
        return self.meta_data.shape

    @property
    def ndim(self):
        return self.meta_data.ndim

    @property
    def device(self):
        proxy = self.tracer.create_proxy('call_function', getattr, (self, 'device'), {})
        proxy.meta_data = self.meta_data.device
        return proxy

    @property
    def dtype(self):
        proxy = self.tracer.create_proxy('call_function', getattr, (self, 'dtype'), {})
        proxy.meta_data = self.meta_data.dtype
        return proxy

    def to(self, *args, **kwargs):
        return self.tracer.create_proxy('call_method', 'to', (self, *args), {**kwargs})

    def cpu(self, *args, **kwargs):
        return self.tracer.create_proxy('call_method', 'cpu', (self, *args), {**kwargs})

    def cuda(self, *args, **kwargs):
        return self.tracer.create_proxy('call_method', 'cuda', (self, *args), {**kwargs})
@compatibility(is_backward_compatible=False)
class ColoAttribute(ColoProxy):

    def __init__(self, root, attr: str, data=None):
        self.root = root
        self.attr = attr
        self.tracer = root.tracer
        self._meta_data = data
        self._node: Optional[Node] = None

    @property
    def node(self):
        # the node for attributes is added lazily, since most will just be method calls
        # which do not rely on the getitem call
        if self._node is None:
            self._node = self.tracer.create_proxy('call_function', getattr, (self.root, self.attr), {}).node
        return self._node

    def __call__(self, *args, **kwargs):
        return self.tracer.create_proxy('call_method', self.attr, (self.root,) + args, kwargs)

    def __repr__(self):
        return f"ColoAttribute({self.node.name}, attr={self.attr})"
@compatibility(is_backward_compatible=False)
class ColoTracer(Tracer):

    def __init__(self, trace_act_ckpt: bool = False, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._disable_module_getattr = False
        self.proxy_buffer_attributes = True

        # whether the tracer will record the usage of torch.utils.checkpoint
        self.trace_act_ckpt = trace_act_ckpt
        # whether the current tracing occurs within the activation checkpoint functions
        self.inside_torch_checkpoint_func = False
        self.act_ckpt_region_count = 0

    def proxy(self, node: Node) -> 'ColoProxy':
        return ColoProxy(node, self)

    def create_proxy(self,
                     kind: str,
                     target: Target,
                     args: Tuple[Any, ...],
                     kwargs: Dict[str, Any],
                     name: Optional[str] = None,
                     type_expr: Optional[Any] = None,
                     proxy_factory_fn: Callable[[Node], 'Proxy'] = None):

        proxy: ColoProxy = super().create_proxy(kind, target, args, kwargs, name, type_expr, proxy_factory_fn)
        unwrap_fn = lambda p: p.meta_data if isinstance(p, ColoProxy) else p
        if kind == 'placeholder':
            proxy.meta_data = self.meta_args[target] if target in self.meta_args else self.concrete_args.get(
                _truncate_suffix(target), None)
        elif kind == 'get_attr':
            self._disable_module_getattr = True
            try:
                attr_itr = self.root
                atoms = target.split(".")
                for atom in atoms:
                    attr_itr = getattr(attr_itr, atom)
                proxy.meta_data = attr_itr
            finally:
                self._disable_module_getattr = False
        elif kind == 'call_function':
            proxy.meta_data = target(*tree_map(unwrap_fn, args), **tree_map(unwrap_fn, kwargs))
        elif kind == 'call_method':
            self._disable_module_getattr = True
            try:
                if target == '__call__':
                    proxy.meta_data = unwrap_fn(args[0])(*tree_map(unwrap_fn, args[1:]), **tree_map(unwrap_fn, kwargs))
                else:
                    if target not in _TensorPropertyMethod:
                        proxy._meta_data = getattr(unwrap_fn(args[0]), target)(*tree_map(unwrap_fn, args[1:]),
                                                                               **tree_map(unwrap_fn, kwargs))
            finally:
                self._disable_module_getattr = False
        elif kind == 'call_module':
            mod = self.root.get_submodule(target)
            self._disable_module_getattr = True
            try:
                proxy.meta_data = mod.forward(*tree_map(unwrap_fn, args), **tree_map(unwrap_fn, kwargs))
            finally:
                self._disable_module_getattr = False
        return proxy

    def create_node(self, *args, **kwargs) -> Node:
        node = super().create_node(*args, **kwargs)
        if self.inside_torch_checkpoint_func:
            # annotate the activation checkpoint module
            node.meta['activation_checkpoint'] = self.act_ckpt_region_count
        return node

    def trace(self,
              root: torch.nn.Module,
              concrete_args: Optional[Dict[str, torch.Tensor]] = None,
              meta_args: Optional[Dict[str, torch.Tensor]] = None) -> Graph:
        if meta_args is None:
            meta_args = {}

        if concrete_args is None:
            concrete_args = {}

        # check concrete and meta args have valid names
        sig = inspect.signature(root.forward)
        sig_names = set(sig.parameters.keys())
        meta_arg_names = set(meta_args.keys())

        # update concrete args with default values
        non_meta_arg_names = sig_names - meta_arg_names
        for k, v in sig.parameters.items():
            if k in non_meta_arg_names and \
                    k not in concrete_args and \
                    v.default is not inspect.Parameter.empty:
                concrete_args[k] = v.default

        # get non concrete arg names
        concrete_arg_names = set(concrete_args.keys())
        non_concrete_arg_names = sig_names - concrete_arg_names

        def _check_arg_name_valid(names):
            success, element = is_element_in_list(names, sig_names)
            if not success:
                raise KeyError(
                    f"argument {element} is not found in the signature of {root.__class__.__name__}'s forward function")

        _check_arg_name_valid(meta_arg_names)
        _check_arg_name_valid(concrete_arg_names)

        self.concrete_args = concrete_args
        self.meta_args = meta_args

        with _TorchTensorOverride(self), self.trace_activation_checkpoint(enabled=self.trace_act_ckpt):
            self.graph = super().trace(root, concrete_args=concrete_args)
        self.graph.lint()
        return self.graph

    @contextmanager
    def trace_activation_checkpoint(self, enabled: bool):
        if enabled:
            orig_ckpt_func = torch.utils.checkpoint.CheckpointFunction

            class PatchedCheckpointFunction(torch.autograd.Function):

                @staticmethod
                def forward(ctx, run_function, preserve_rng_state, *args):
                    # signal that the current tracing occurs within the activation checkpoint part
                    self.inside_torch_checkpoint_func = True
                    out = run_function(*args)
                    self.inside_torch_checkpoint_func = False
                    self.act_ckpt_region_count += 1
                    return out

                @staticmethod
                def backward(ctx: Any, *grad_outputs: Any) -> Any:
                    raise NotImplementedError("We do not implement the backward pass as we only trace the forward pass.")

            # override the checkpoint function
            torch.utils.checkpoint.CheckpointFunction = PatchedCheckpointFunction
        yield

        if enabled:
            # recover the checkpoint function upon exit
            torch.utils.checkpoint.CheckpointFunction = orig_ckpt_func

    def _post_check(self, non_concrete_arg_names: Set[str]):
        # This is necessary because concrete args are added as input to the traced module since
        # https://github.com/pytorch/pytorch/pull/55888.
        for node in self.graph.nodes:
            if node.op == "placeholder":
                # Removing default values for inputs as the forward pass will fail with them.
                if node.target in non_concrete_arg_names:
                    node.args = ()
                    # Without this, torch.jit.script fails because the inputs type is Optional[torch.Tensor].
                    # It cannot infer on the attributes and methods the input should have, and fails.
                    node.type = torch.Tensor
                # It is a concrete arg so it is not used and should be removed.
                else:
                    if hasattr(torch.fx._symbolic_trace, "_assert_is_none"):
                        # Newer versions of torch.fx emit an assert statement
                        # for concrete arguments; delete those before we delete
                        # the concrete arg.
                        to_delete = []
                        for user in node.users:
                            if user.target == torch.fx._symbolic_trace._assert_is_none:
                                to_delete.append(user)
                        for user in to_delete:
                            self.graph.erase_node(user)

                    self.graph.erase_node(node)

            # TODO: solves GraphModule creation.
            # Without this, return type annotation "Tuple" is causing code execution failure.
            if node.op == "output":
                node.type = None
        self.graph.lint()

    def _module_getattr(self, attr, attr_val, parameter_proxy_cache):
        if getattr(self, "_disable_module_getattr", False):
            return attr_val

        def maybe_get_proxy_for_attr(attr_val, collection_to_search, parameter_proxy_cache):
            for n, p in collection_to_search:
                if attr_val is p:
                    if n not in parameter_proxy_cache:
                        kwargs = {}
                        if 'proxy_factory_fn' in inspect.signature(self.create_proxy).parameters:
                            kwargs['proxy_factory_fn'] = (None if not self.param_shapes_constant else
                                                          lambda node: ColoProxy(self, node, n, attr_val))
                        val_proxy = self.create_proxy('get_attr', n, (), {}, **kwargs)    # type: ignore[arg-type]
                        parameter_proxy_cache[n] = val_proxy
                    return parameter_proxy_cache[n]
            return None

        if self.proxy_buffer_attributes and isinstance(attr_val, torch.Tensor):
            maybe_buffer_proxy = maybe_get_proxy_for_attr(attr_val, self.root.named_buffers(), parameter_proxy_cache)
            if maybe_buffer_proxy is not None:
                return maybe_buffer_proxy

        if isinstance(attr_val, torch.nn.Parameter):
            maybe_parameter_proxy = maybe_get_proxy_for_attr(attr_val, self.root.named_parameters(),
                                                             parameter_proxy_cache)
            if maybe_parameter_proxy is not None:
                return maybe_parameter_proxy

        return attr_val
@compatibility(is_backward_compatible=True)
def symbolic_trace(
    root: Union[torch.nn.Module, Callable[..., Any]],
    concrete_args: Optional[Dict[str, Any]] = None,
    meta_args: Optional[Dict[str, Any]] = None,
) -> ColoGraphModule:
    if is_compatible_with_meta():
        if meta_args is not None:
            root.to(default_device())
            wrap_fn = lambda x: MetaTensor(x, fake_device=default_device()) if isinstance(x, torch.Tensor) else x
            graph = ColoTracer().trace(root, concrete_args=concrete_args, meta_args=tree_map(wrap_fn, meta_args))
            root.cpu()
        else:
            graph = Tracer().trace(root, concrete_args=concrete_args)
    else:
        from .tracer import ColoTracer as OrigColoTracer
        graph = OrigColoTracer().trace(root, concrete_args=concrete_args, meta_args=meta_args)
    name = root.__class__.__name__ if isinstance(root, torch.nn.Module) else root.__name__
    return ColoGraphModule(root, graph, name)
@compatibility(is_backward_compatible=False)
class _TorchTensorOverride(object):

    def __init__(self, tracer: Tracer):
        self.overrides = {}
        self.tracer = tracer

    def __enter__(self):

        def wrap_tensor_method(target):

            @functools.wraps(target)
            def wrapper(*args, **kwargs):
                is_proxy = any(isinstance(p, ColoProxy) for p in args) | any(
                    isinstance(p, ColoProxy) for p in kwargs.values())
                if is_proxy:
                    # if the arg is a proxy, then need to record this function called on this proxy
                    # e.g. torch.ones(size) where size is an input proxy
                    self.tracer._disable_module_getattr = True
                    try:
                        proxy = self.tracer.create_proxy('call_function', target, args, kwargs)
                    finally:
                        self.tracer._disable_module_getattr = False
                    return proxy
                else:
                    return target(*args, **kwargs)

            return wrapper, target

        self.overrides = {
            target: wrap_tensor_method(getattr(torch, target))
            for target in _TorchNewMethod
            if callable(getattr(torch, target))
        }
        for name, (wrapper, orig) in self.overrides.items():
            setattr(torch, name, wrapper)

    def __exit__(self, exc_type, exc_val, exc_tb):
        for name, (wrapper, orig) in self.overrides.items():
            setattr(torch, name, orig)
def meta_prop_pass(gm: ColoGraphModule,
                   root: torch.nn.Module,
                   meta_args: Optional[Dict[str, Any]] = None,
                   concrete_args: Optional[Dict[str, torch.Tensor]] = None):
    if meta_args is None:
        meta_args = {}

    if concrete_args is None:
        concrete_args = {}

    # check concrete and meta args have valid names
    sig = inspect.signature(root.forward)
    sig_names = set(sig.parameters.keys())
    meta_arg_names = set(meta_args.keys())

    # update concrete args with default values
    non_meta_arg_names = sig_names - meta_arg_names
    for k, v in sig.parameters.items():
        if k in non_meta_arg_names and \
                k not in concrete_args and \
                v.default is not inspect.Parameter.empty:
            concrete_args[k] = v.default

    for node in gm.graph.nodes:
        node._meta_data = _meta_data_computing(meta_args, concrete_args, root, node.op, node.target, node.args,
                                               node.kwargs)
def _meta_data_computing(meta_args, concrete_args, root, kind, target, args, kwargs):
    unwrap_fn = lambda n: n._meta_data if isinstance(n, Node) else n

    if kind == 'placeholder':
        meta_out = meta_args[target] if target in meta_args else concrete_args.get(_truncate_suffix(target), None)
    elif kind == 'get_attr':
        attr_itr = root
        atoms = target.split(".")
        for atom in atoms:
            attr_itr = getattr(attr_itr, atom)
        meta_out = attr_itr
    elif kind == 'call_function':
        meta_out = target(*tree_map(unwrap_fn, args), **tree_map(unwrap_fn, kwargs))
    elif kind == 'call_method':
        if target == '__call__':
            meta_out = unwrap_fn(args[0])(*tree_map(unwrap_fn, args[1:]), **tree_map(unwrap_fn, kwargs))
        else:
            if target not in _TensorPropertyMethod:
                meta_out = getattr(unwrap_fn(args[0]), target)(*tree_map(unwrap_fn, args[1:]),
                                                               **tree_map(unwrap_fn, kwargs))
    elif kind == 'call_module':
        mod = root.get_submodule(target)
        meta_out = mod.forward(*tree_map(unwrap_fn, args), **tree_map(unwrap_fn, kwargs))
    else:
        meta_out = None
    return meta_out
def _meta_data_computing_v0(meta_args, root, kind, target, args, kwargs):
    if kind == "placeholder" and target in meta_args and meta_args[target].is_meta:
        meta_out = meta_args[target]
        return meta_out

    if target in [getattr(torch, torch_func) for torch_func in _TorchNewMethod]:
        # NOTE: tensor constructors in PyTorch define the `device` argument as
        # *kwargs-only*. That is why this works. If you add methods to
        # _TORCH_METHODS_TO_PATCH that do not define `device` as kwarg-only,
        # this will break and you will likely see issues where we cannot infer
        # the size of the output.
        if "device" in kwargs:
            kwargs["device"] = "meta"

    try:
        unwrap_fn = lambda n: n._meta_data if isinstance(n, Node) else n
        args_metas = tree_map(unwrap_fn, args)
        kwargs_metas = tree_map(unwrap_fn, kwargs)

        if kind == "call_function":
            # fetch patched function
            if meta_patched_function.has(target):
                meta_target = meta_patched_function.get(target)
            elif meta_patched_function.has(target.__name__):
                # use name for some builtin op like @ (matmul)
                meta_target = meta_patched_function.get(target.__name__)
            else:
                meta_target = target

            meta_out = meta_target(*args_metas, **kwargs_metas)

            if isinstance(meta_out, torch.Tensor):
                meta_out = meta_out.to(device="meta")
        elif kind == "call_method":
            method = getattr(args_metas[0].__class__, target)

            # fetch patched method
            if meta_patched_function.has(method):
                meta_target = meta_patched_function.get(method)
            else:
                meta_target = method

            meta_out = meta_target(*args_metas, **kwargs_metas)
        elif kind == "call_module":
            mod = root.get_submodule(target)
            mod_type = type(mod)
            if meta_patched_module.has(mod_type):
                meta_out = meta_patched_module.get(mod_type)(mod, *args_metas, **kwargs_metas)
            else:
                meta_out = mod(*args_metas, **kwargs_metas)
        elif kind == "get_attr":
            attr_itr = root
            atoms = target.split(".")
            for atom in atoms:
                attr_itr = getattr(attr_itr, atom)
            if isinstance(attr_itr, torch.nn.parameter.Parameter):
                meta_out = torch.nn.Parameter(attr_itr.to(device="meta"))
            elif isinstance(attr_itr, torch.Tensor):
                meta_out = attr_itr.to(device="meta")
            else:
                meta_out = attr_itr
        else:
            return None
    except Exception as e:
        raise RuntimeError(f"Could not compute metadata for {kind} target {target}: {e}")

    return meta_out
def bias_addition_pass(gm: ColoGraphModule, root_model: torch.nn.Module, meta_args: Optional[Dict[str, Any]] = None):
    result_graph = Graph()
    value_remap = {}
    unwrap_fn = lambda n: n._meta_data if isinstance(n, Node) else n

    for orig_node in gm.graph.nodes:
        assert hasattr(orig_node, "_meta_data")
        kind = orig_node.op
        target = orig_node.target
        args = orig_node.args
        kwargs = orig_node.kwargs

        args_metas = tree_map(unwrap_fn, args)
        tracer = ColoTracer()
        tracer.graph = Graph(tracer_cls=ColoTracer)
        tracer.root = root_model

        def wrap_fn(n):
            if isinstance(n, Node):
                proxy = ColoProxy(n, tracer)
                proxy.meta_data = n._meta_data
                return proxy
            return n

        args_proxy = tree_map(wrap_fn, args)
        kwargs_proxy = tree_map(wrap_fn, kwargs)

        handle = None
        if kind == "call_function":
            if bias_addition_function.has(target):
                if target == torch.nn.functional.linear:
                    if 'bias' in kwargs and kwargs['bias'] is not None:
                        function_to_substitute = func_to_func_dict[target]
                        handle = bias_addition_function.get(target)(tracer, target, args_proxy, kwargs_proxy,
                                                                    function_to_substitute)
                else:
                    function_to_substitute = func_to_func_dict[target]
                    handle = bias_addition_function.get(target)(tracer, target, args_proxy, kwargs_proxy,
                                                                function_to_substitute)
            elif bias_addition_function.has(target.__name__):
                # use name for some builtin op like @ (matmul)
                function_to_substitute = func_to_func_dict[target]
                handle = bias_addition_function.get(target.__name__)(tracer, target, args_proxy, kwargs_proxy,
                                                                     function_to_substitute)
        elif kind == "call_method":
            method = getattr(args_metas[0].__class__, target)
            if bias_addition_method.has(method):
                function_to_substitute = method_to_func_dict[method]
                handle = bias_addition_method.get(method)(tracer, target, args_proxy, kwargs_proxy,
                                                          function_to_substitute)
        elif kind == "call_module":
            # if not hasattr(self, "orig_forward"):
            #     raise AttributeError(f"{self} does not have an attribute called orig_forward")
            mod = gm.get_submodule(target)
            mod_type = type(mod)
            if bias_addition_module.has(mod_type) and mod.bias is not None:
                function_to_substitute = module_to_func_dict[mod_type]
                handle = bias_addition_module.get(mod_type)(tracer, target, args_proxy, kwargs_proxy,
                                                            function_to_substitute)

        if handle is not None:
            handle.generate()
            for node_inserted in tracer.graph.nodes:
                value_remap[node_inserted] = result_graph.node_copy(node_inserted, lambda n: value_remap[n])
                last_node = value_remap[node_inserted]
            value_remap[orig_node] = last_node
        else:
            value_remap[orig_node] = result_graph.node_copy(orig_node, lambda n: value_remap[n])

        del tracer

    gm.graph = result_graph
    gm.recompile()
    meta_prop_pass(gm, root_model, meta_args)
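A hedged sketch (not part of this commit) of how the experimental passes above are meant to compose: trace first, then propagate metadata, then rewrite bias additions. The toy model and input are illustrative and may need adjustment.

import torch
from colossalai.fx.tracer.experimental import symbolic_trace, meta_prop_pass, bias_addition_pass

model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU())
meta_args = {'input': torch.rand(2, 8, device='meta')}

gm = symbolic_trace(model, meta_args=meta_args)
meta_prop_pass(gm, model, meta_args)        # attaches _meta_data to every node
bias_addition_pass(gm, model, meta_args)    # splits Linear into an explicit matmul + bias add
print(gm.graph)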
colossalai/fx/tracer/meta_patch/__init__.py

-from .registry import *
 from .patched_function import *
 from .patched_module import *