OpenDAS / ColossalAI / Commits

Commit 9e768b59
authored Oct 10, 2023 by zhuwenwen

Merge branch 'main' of https://github.com/hpcaitech/ColossalAI

parents 7bc5a8e3 8aed02b9

Changes: 442. Showing 20 changed files with 574 additions and 601 deletions (+574 -601).
colossalai/__init__.py                                    +5   -10
colossalai/_analyzer/__init__.py                          +0   -0
colossalai/_analyzer/_subclasses/_meta_registration.py    +108 -53
colossalai/_analyzer/_subclasses/_monkey_patch.py         +1   -2
colossalai/_analyzer/_subclasses/flop_tensor.py           +34  -43
colossalai/_analyzer/_subclasses/meta_tensor.py           +22  -24
colossalai/_analyzer/fx/codegen.py                        +97  -84
colossalai/_analyzer/fx/graph_module.py                   +30  -24
colossalai/_analyzer/fx/node_util.py                      +28  -26
colossalai/_analyzer/fx/passes/graph_profile.py           +56  -48
colossalai/_analyzer/fx/passes/shape_prop.py              +20  -16
colossalai/_analyzer/fx/symbolic_profile.py               +0   -4
colossalai/_analyzer/fx/tracer/bias_addition.py           +102 -88
colossalai/_analyzer/fx/tracer/custom_leaf_module.py      +1   -0
colossalai/_analyzer/fx/tracer/proxy.py                   +6   -9
colossalai/_analyzer/fx/tracer/symbolic_trace.py          +6   -6
colossalai/_analyzer/fx/tracer/tracer.py                  +57  -49
colossalai/amp/__init__.py                                +0   -54
colossalai/amp/naive_amp/__init__.py                      +0   -60
colossalai/amp/naive_amp/grad_scaler/__init__.py          +1   -1
Too many changes to show. To preserve performance only 442 of 442+ files are displayed.
colossalai/__init__.py  [view file @ 9e768b59]

-from .initialize import (
-    get_default_parser,
-    initialize,
-    launch,
-    launch_from_openmpi,
-    launch_from_slurm,
-    launch_from_torch,
-)
+from .initialize import launch, launch_from_openmpi, launch_from_slurm, launch_from_torch

 try:
     # .version will be created by setup.py
...
@@ -13,5 +6,7 @@ try:
 except ModuleNotFoundError:
     # this will only happen if the user did not run `pip install`
     # and directly set PYTHONPATH to use Colossal-AI which is a bad practice
-    __version__ = '0.0.0'
-    print('please install Colossal-AI from https://www.colossalai.org/download or from source')
+    __version__ = "0.0.0"
+    print("please install Colossal-AI from https://www.colossalai.org/download or from source")
+
+__all__ = ["launch", "launch_from_openmpi", "launch_from_slurm", "launch_from_torch", "__version__"]
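
With get_default_parser and initialize dropped from the import, the module's public surface is reduced to the launchers plus the version string. A minimal sketch of the surviving API, hedged because the exact launch signature has shifted across ColossalAI releases (the config argument follows the convention of this era):

    import colossalai

    # run under `torchrun`; rank, world size and master address are read
    # from the environment variables that torchrun sets up
    colossalai.launch_from_torch(config={})
    print(colossalai.__version__)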
tests/test_layers/test_2d/checks_2d/__init__.py → colossalai/_analyzer/__init__.py  [view file @ 9e768b59]

File moved.
colossalai/_analyzer/_subclasses/_meta_registration.py  [view file @ 9e768b59]

...
@@ -3,7 +3,7 @@
 # refer to https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/native_functions.yaml
 # for more meta_registrations

-from typing import Callable, List, Optional, Tuple, Union
+from typing import List, Optional, Union

 import torch
 from packaging import version
...
@@ -24,25 +24,23 @@ orig_empty_like = torch.empty_like
 def new(*args, **kwargs):
-    return orig_empty(*args, **kwargs, device=torch.device('meta'))
+    return orig_empty(*args, **kwargs, device=torch.device("meta"))

 def new_strided(*args, **kwargs):
-    return orig_empty_strided(*args, **kwargs, device=torch.device('meta'))
+    return orig_empty_strided(*args, **kwargs, device=torch.device("meta"))

 def new_like(*args, **kwargs):
-    return orig_empty_like(*args, **kwargs, device=torch.device('meta'))
+    return orig_empty_like(*args, **kwargs, device=torch.device("meta"))

 def register_meta(op, register_dispatcher=True):
     def wrapper(f):
         def add_func(op):
             meta_table[op] = f
             if register_dispatcher:
-                name = (op.__name__ if op._overloadname != "default" else op.overloadpacket.__name__)
+                name = op.__name__ if op._overloadname != "default" else op.overloadpacket.__name__
                 try:
                     meta_lib.impl(name, f)
                 except:
...
@@ -54,7 +52,7 @@ def register_meta(op, register_dispatcher=True):
     return wrapper

-if version.parse(torch.__version__) >= version.parse('1.12.0'):
+if version.parse(torch.__version__) >= version.parse("1.12.0"):
     # ============================== Convolutions ======================================
     # https://github.com/pytorch/pytorch/pull/79834
     @register_meta(aten.convolution.default)
...
@@ -69,7 +67,6 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'):
         output_padding: List[int],
         groups: int,
     ):
-
         def _formula(ln: int, p: int, d: int, k: int, s: int) -> int:
             """
             Formula to apply to calculate the length of some dimension of the output
...
@@ -146,7 +143,8 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'):
                         kernel_size[i],
                         stride[i],
                         output_padding_list[i],
-                ))
+                    )
+                )
             else:
                 ret_shape.append(_formula(dims[i], padding[i], dilation[i], kernel_size[i], stride[i]))
         return ret_shape
...
@@ -184,15 +182,35 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'):
         return out

     @register_meta(aten._convolution.default)
-    def meta__conv(input_tensor: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, stride: List[int],
-                   padding: List[int], dilation: List[int], is_transposed: bool, output_padding: List[int],
-                   groups: int, *extra_args):
+    def meta__conv(
+        input_tensor: torch.Tensor,
+        weight: torch.Tensor,
+        bias: torch.Tensor,
+        stride: List[int],
+        padding: List[int],
+        dilation: List[int],
+        is_transposed: bool,
+        output_padding: List[int],
+        groups: int,
+        *extra_args,
+    ):
         out = meta_conv(input_tensor, weight, bias, stride, padding, dilation, is_transposed, output_padding, groups)
         return out

     @register_meta(aten.convolution_backward.default)
-    def meta_conv_backward(grad_output: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, bias_sizes, stride,
-                           padding, dilation, transposed, output_padding, groups, output_mask):
+    def meta_conv_backward(
+        grad_output: torch.Tensor,
+        input: torch.Tensor,
+        weight: torch.Tensor,
+        bias_sizes,
+        stride,
+        padding,
+        dilation,
+        transposed,
+        output_padding,
+        groups,
+        output_mask,
+    ):
         return new_like(input), new_like(weight), new((bias_sizes))

     # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/AdaptiveAveragePooling.cpp
...
@@ -224,7 +242,6 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'):
         batch_sizes,
         dropout_state,
     ):
-
         is_input_packed = len(batch_sizes) != 0
         if is_input_packed:
             seq_length = len(batch_sizes)
...
@@ -240,8 +257,11 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'):
         if is_input_packed:
             out_shape = [batch_sizes_sum, out_size * num_directions]
         else:
-            out_shape = ([mini_batch, seq_length, out_size * num_directions]
-                         if batch_first else [seq_length, mini_batch, out_size * num_directions])
+            out_shape = (
+                [mini_batch, seq_length, out_size * num_directions]
+                if batch_first
+                else [seq_length, mini_batch, out_size * num_directions]
+            )

         output = input.new_empty(out_shape)

         cell_shape = [num_layers * num_directions, mini_batch, hidden_size]
...
@@ -257,15 +277,21 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'):
     # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/RNN.cpp
     @register_meta(aten._cudnn_rnn_backward.default)
-    def meta_cudnn_rnn_backward(input: torch.Tensor,
-                                weight: torch.Tensor,
-                                weight_stride0: int,
-                                hx: torch.Tensor,
-                                cx: Optional[torch.Tensor] = None,
-                                *args,
-                                **kwargs):
-        return new_like(input), new_like(weight), new_like(hx), new_like(cx) if cx is not None else new(())  # (grad_input, grad_weight, grad_hx, grad_cx)
+    def meta_cudnn_rnn_backward(
+        input: torch.Tensor,
+        weight: torch.Tensor,
+        weight_stride0: int,
+        hx: torch.Tensor,
+        cx: Optional[torch.Tensor] = None,
+        *args,
+        **kwargs,
+    ):
+        return (
+            new_like(input),
+            new_like(weight),
+            new_like(hx),
+            new_like(cx) if cx is not None else new(()),
+        )  # (grad_input, grad_weight, grad_hx, grad_cx)

     # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Activation.cpp
     # ============================== Activations =======================================
...
@@ -278,7 +304,7 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'):
         aten.hardtanh_backward.default,
     ]

-    if version.parse(torch.__version__) < version.parse('2.0.0'):
+    if version.parse(torch.__version__) < version.parse("2.0.0"):
         _unregistered_ewise += [
             aten.prelu_backward.default,
         ]
...
@@ -296,24 +322,47 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'):
     # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/BatchNorm.cpp
     @register_meta(aten.native_batch_norm_backward.default)
-    def meta_bn_backward(dY: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, running_mean, running_var,
-                         save_mean, save_invstd, train, eps, output_mask):
+    def meta_bn_backward(
+        dY: torch.Tensor,
+        input: torch.Tensor,
+        weight: torch.Tensor,
+        running_mean,
+        running_var,
+        save_mean,
+        save_invstd,
+        train,
+        eps,
+        output_mask,
+    ):
         return new_like(input), new_like(weight), new_like(weight)  # (dX, dgamma, dbeta)

     # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/BatchNorm.cpp
     @register_meta(aten.cudnn_batch_norm.default)
     def meta_cudnn_bn(input: torch.Tensor, weight, bias, running_mean, running_var, training, momentum, eps):
         n_input = input.size(1)
-        return new_like(input), new((n_input)), new((n_input)), new(
-            (0), dtype=torch.uint8)  # (output, running_mean, running_var, reserve)
+        return (
+            new_like(input),
+            new((n_input)),
+            new((n_input)),
+            new((0), dtype=torch.uint8),
+        )  # (output, running_mean, running_var, reserve)

     # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/BatchNorm.cpp
     # NB: CuDNN only implements the backward algorithm for batchnorm
     # in training mode (evaluation mode batchnorm has a different algorithm),
     # which is why this doesn't accept a 'training' parameter.
     @register_meta(aten.cudnn_batch_norm_backward.default)
-    def meta_cudnn_bn_backward(dY: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, running_mean, running_var,
-                               save_mean, save_invstd, eps, reserve):
+    def meta_cudnn_bn_backward(
+        dY: torch.Tensor,
+        input: torch.Tensor,
+        weight: torch.Tensor,
+        running_mean,
+        running_var,
+        save_mean,
+        save_invstd,
+        eps,
+        reserve,
+    ):
         return new_like(input), new_like(weight), new_like(weight)  # (dX, dgamma, dbeta)

     # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/layer_norm.cpp
...
@@ -324,8 +373,9 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'):
     # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/layer_norm.cpp
     @register_meta(aten.native_layer_norm_backward.default)
-    def meta_ln_backward(dY: torch.Tensor, input: torch.Tensor, normalized_shape, mean, rstd, weight, bias,
-                         grad_input_mask):
+    def meta_ln_backward(
+        dY: torch.Tensor, input: torch.Tensor, normalized_shape, mean, rstd, weight, bias, grad_input_mask
+    ):
         return new_like(input), new_like(weight), new_like(bias)  # (dX, dgamma, dbeta)

     # ================================== Misc ==========================================
...
@@ -355,8 +405,9 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'):
     # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Embedding.cpp
     @register_meta(aten.embedding_dense_backward.default)
-    def meta_embedding_dense_backward(grad_output: torch.Tensor, indices: torch.Tensor, num_weights, padding_idx,
-                                      scale_grad_by_freq):
+    def meta_embedding_dense_backward(
+        grad_output: torch.Tensor, indices: torch.Tensor, num_weights, padding_idx, scale_grad_by_freq
+    ):
         return new((num_weights, grad_output.size(-1)), dtype=grad_output.dtype, layout=grad_output.layout)

     # ============================== Dropout ===========================================
...
@@ -371,7 +422,7 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'):
     def meta_native_dropout_backward_default(grad: torch.Tensor, mask: torch.Tensor, scale: float):
         return new_like(grad)  # (grad_in)

-    if version.parse(torch.__version__) < version.parse('1.13.0'):
+    if version.parse(torch.__version__) < version.parse("1.13.0"):
         # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/native_functions.yaml
         @register_meta(aten.eye.m_out)
         def meta_eye(n: int, m: int, out: torch.Tensor):
...
@@ -385,24 +436,28 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'):
         result: List[Optional[torch.Tensor]] = []
         for i, index in enumerate(indices):
             if index is not None:
-                assert index.dtype in [torch.long, torch.int8, torch.bool], \
-                    "tensors used as indices must be long, byte or bool tensors"
+                assert index.dtype in [
+                    torch.long,
+                    torch.int8,
+                    torch.bool,
+                ], "tensors used as indices must be long, byte or bool tensors"
                 if index.dtype in [torch.int8, torch.bool]:
                     nonzero = index.nonzero()
                     k = len(result)
                     assert k + index.ndim <= self.ndim, f"too many indices for tensor of dimension {self.ndim}"
                     for j in range(index.ndim):
-                        assert index.shape[j] == self.shape[k + j], f"The shape of the mask {index.shape} at index {i} does not match the shape of the indexed tensor {self.shape} at index {k + j}"
+                        assert (
+                            index.shape[j] == self.shape[k + j]
+                        ), f"The shape of the mask {index.shape} at index {i} does not match the shape of the indexed tensor {self.shape} at index {k + j}"
                         result.append(nonzero.select(1, j))
                 else:
                     result.append(index)
             else:
                 result.append(index)
         indices = result
-        assert len(indices) <= self.ndim, f"too many indices for tensor of dimension {self.ndim} (got {len(indices)})"
+        assert (
+            len(indices) <= self.ndim
+        ), f"too many indices for tensor of dimension {self.ndim} (got {len(indices)})"

         # expand_outplace
         import torch._refs as refs
...
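
For readers unfamiliar with the pattern this file implements: a meta registration maps an aten op to a function that fabricates outputs on the "meta" device, so shapes and dtypes propagate without any real computation. A self-contained sketch of the idea, simplified from the register_meta/meta_table machinery above (aten.relu.default is just a convenient example op):

    import torch

    aten = torch.ops.aten
    meta_table = {}

    def register_meta(op):
        # minimal decorator: remember which function fakes which aten op
        def wrapper(f):
            meta_table[op] = f
            return f
        return wrapper

    @register_meta(aten.relu.default)
    def meta_relu(x: torch.Tensor) -> torch.Tensor:
        # a meta kernel only produces shape/dtype, never real data
        return torch.empty_like(x, device="meta")

    # shape inference with zero compute
    out = meta_table[aten.relu.default](torch.empty(4, 8, device="meta"))
    print(out.shape)  # torch.Size([4, 8])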
colossalai/_analyzer/_subclasses/_monkey_patch.py  [view file @ 9e768b59]

 import torch
 import torch.distributed as dist
-
 from packaging import version

 __all__ = [
...
@@ -48,7 +47,7 @@ _DistCommMethod = [
     "scatter",
 ]

-if version.parse(torch.__version__) >= version.parse('1.12.0'):
+if version.parse(torch.__version__) >= version.parse("1.12.0"):
     aten = torch.ops.aten

     # TODO: dive deep here
     # refer to https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/TensorShape.cpp
...
colossalai/_analyzer/_subclasses/flop_tensor.py  [view file @ 9e768b59]

...
@@ -8,7 +8,7 @@ from contextlib import contextmanager
 from enum import Enum, auto
 from functools import partial, reduce
 from numbers import Number
-from typing import Any, Callable, List, Optional, Union
+from typing import Any, Callable, List, Union

 import torch
 from packaging import version
...
@@ -36,15 +36,15 @@ def _format_flops(flop):
     B = 1e9
     T = 1e12
     if flop < K:
-        return f'{flop:.2f}'
+        return f"{flop:.2f}"
     elif flop < M:
-        return f'{flop / K:.2f}K'
+        return f"{flop / K:.2f}K"
     elif flop < B:
-        return f'{flop / M:.2f}M'
+        return f"{flop / M:.2f}M"
     elif flop < T:
-        return f'{flop / B:.2f}B'
+        return f"{flop / B:.2f}B"
     else:
-        return f'{flop / T:.2f}T'
+        return f"{flop / T:.2f}T"

 def flop_count(module: Union[torch.nn.Module, Callable] = None, *args, verbose: bool = False, **kwargs) -> Number:
...
@@ -59,11 +59,13 @@ def flop_count(module: Union[torch.nn.Module, Callable] = None, *args, verbose:
     Returns:
         Number: The total number of floating point operations (FWD + BWD).
     """
-    maybe_inplace = (getattr(module, 'inplace', False) or kwargs.get('inplace', False)
-                     or getattr(module, '__name__', None) in ('add_', 'mul_', 'div_', 'sub_'))
+    maybe_inplace = (
+        getattr(module, "inplace", False)
+        or kwargs.get("inplace", False)
+        or getattr(module, "__name__", None) in ("add_", "mul_", "div_", "sub_")
+    )

     class DummyModule(torch.nn.Module):
-
         def __init__(self, func):
             super().__init__()
             self.func = func
...
@@ -74,21 +76,20 @@ def flop_count(module: Union[torch.nn.Module, Callable] = None, *args, verbose:
     total_flop_count = {Phase.FWD: 0, Phase.BWD: 0}
     flop_counts = defaultdict(lambda: defaultdict(int))
-    parents = ['Global']
+    parents = ["Global"]
     module = module if isinstance(module, torch.nn.Module) else DummyModule(module)

     class FlopTensor(MetaTensor):
-
         _tensor: torch.Tensor

         def __repr__(self):
-            name = 'FlopParameter' if getattr(self, '_is_param', False) else 'FlopTensor'
+            name = "FlopParameter" if getattr(self, "_is_param", False) else "FlopTensor"
             if self.grad_fn:
                 return f"{name}(..., size={tuple(self.shape)}, device='{self.device}', dtype={self.dtype}, grad_fn={self.grad_fn})"
             return f"{name}(..., size={tuple(self.shape)}, device='{self.device}', dtype={self.dtype})"

         @classmethod
         def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
-
             # no_dispatch is only needed if you use enable_python_mode.
             # It prevents infinite recursion.
             rs = super().__torch_dispatch__(func, types, args, kwargs)
...
@@ -115,9 +116,7 @@ def flop_count(module: Union[torch.nn.Module, Callable] = None, *args, verbose:
         return isinstance(x, torch.Tensor) and x.is_floating_point()

     def create_backwards_push(name):
-
         class PushState(torch.autograd.Function):
-
             @staticmethod
             def forward(ctx, *args):
                 args = tree_map(lambda x: x.clone() if isinstance(x, torch.Tensor) else x, args)
...
@@ -134,9 +133,7 @@ def flop_count(module: Union[torch.nn.Module, Callable] = None, *args, verbose:
         return PushState.apply

     def create_backwards_pop(name):
-
         class PopState(torch.autograd.Function):
-
             @staticmethod
             def forward(ctx, *args):
                 args = tree_map(lambda x: x.clone() if isinstance(x, torch.Tensor) else x, args)
...
@@ -147,14 +144,13 @@ def flop_count(module: Union[torch.nn.Module, Callable] = None, *args, verbose:
         @staticmethod
         def backward(ctx, *grad_outs):
             nonlocal parents
-            assert (parents[-1] == name)
+            assert parents[-1] == name
             parents.pop()
             return grad_outs

         return PopState.apply

     def enter_module(name):
-
         def f(module, inputs):
             nonlocal parents
             parents.append(name)
...
@@ -165,10 +161,9 @@ def flop_count(module: Union[torch.nn.Module, Callable] = None, *args, verbose:
         return f

     def exit_module(name):
-
         def f(module, inputs, outputs):
             nonlocal parents
-            assert (parents[-1] == name)
+            assert parents[-1] == name
             parents.pop()
             outputs = normalize_tuple(outputs)
             return create_backwards_push(name)(*outputs)
...
@@ -189,7 +184,7 @@ def flop_count(module: Union[torch.nn.Module, Callable] = None, *args, verbose:
         for mod in flop_counts.keys():
             print(f"Module: ", mod)
             for k, v in flop_counts[mod].items():
-                print('\t', k, _format_flops(v))
+                print("\t", k, _format_flops(v))
             print()

     def detach_variables(r):
...
@@ -201,7 +196,7 @@ def flop_count(module: Union[torch.nn.Module, Callable] = None, *args, verbose:
     def wrap(r):
         if isinstance(r, torch.Tensor):
-            data_ptr_fn = getattr(r, '_tensor', r).data_ptr
+            data_ptr_fn = getattr(r, "_tensor", r).data_ptr
             r = FlopTensor(detach_variables(r))
             if maybe_inplace:
                 r = r + 0
...
@@ -375,8 +370,11 @@ def norm_flop_counter(affine_arg_index: int, input_arg_index: int) -> Callable:
         # Inputs[0] contains the shape of the input.
         input_shape = inputs[input_arg_index].shape
-        has_affine = inputs[affine_arg_index].shape is not None if hasattr(inputs[affine_arg_index],
-                                                                           'shape') else inputs[affine_arg_index]
+        has_affine = (
+            inputs[affine_arg_index].shape is not None
+            if hasattr(inputs[affine_arg_index], "shape")
+            else inputs[affine_arg_index]
+        )
         assert 2 <= len(input_shape) <= 5, input_shape
         # 5 is just a rough estimate
         flop = reduce(operator.mul, input_shape) * (5 if has_affine else 4)
...
@@ -425,19 +423,17 @@ def zero_flop_jit(*args):
     return 0

-if version.parse(torch.__version__) >= version.parse('1.12.0'):
+if version.parse(torch.__version__) >= version.parse("1.12.0"):
     flop_mapping = {
         # gemm
         aten.mm.default: matmul_flop_jit,
         aten.matmul.default: matmul_flop_jit,
         aten.addmm.default: addmm_flop_jit,
         aten.bmm.default: bmm_flop_jit,
-
         # convolution
         aten.convolution.default: conv_flop_jit,
         aten._convolution.default: conv_flop_jit,
         aten.convolution_backward.default: conv_backward_flop_jit,
-
         # normalization
         aten.native_batch_norm.default: batchnorm_flop_jit,
         aten.native_batch_norm_backward.default: batchnorm_flop_jit,
...
@@ -445,7 +441,6 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'):
         aten.cudnn_batch_norm_backward.default: partial(batchnorm_flop_jit, training=True),
         aten.native_layer_norm.default: norm_flop_counter(2, 0),
         aten.native_layer_norm_backward.default: norm_flop_counter(2, 0),
-
         # pooling
         aten.avg_pool1d.default: ewise_flop_counter(1, 0),
         aten.avg_pool2d.default: ewise_flop_counter(1, 0),
...
@@ -485,7 +480,6 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'):
         aten.sum.default,
         aten.sum.dim_IntList,
         aten.mean.dim,
-
         # activation op
         aten.hardswish.default,
         aten.hardswish_.default,
...
@@ -509,14 +503,11 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'):
         aten.tanh.default,
         aten.tanh_backward.default,
         aten.threshold_backward.default,
-
         # dropout
         aten.native_dropout.default,
         aten.native_dropout_backward.default,
-
         # distribution
         aten.bernoulli_.float,
-
         # where
         aten.where.self,
     ]
...
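
The flop_count signature shown above (module or callable first, inputs as *args, a verbose flag) suggests usage along these lines; a sketch assuming the import path matches this file's location and a torch recent enough to pass the version gate (>= 1.12.0):

    import torch
    import torch.nn as nn
    from colossalai._analyzer._subclasses.flop_tensor import flop_count

    model = nn.Linear(128, 256)
    x = torch.randn(8, 128, requires_grad=True)
    total = flop_count(model, x, verbose=True)  # total FWD + BWD FLOPs
    print(total)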
colossalai/_analyzer/_subclasses/meta_tensor.py  [view file @ 9e768b59]

...
@@ -3,12 +3,12 @@ from functools import partial
 import torch
 import torch.distributed as dist
-from torch.types import _bool, _device, _dtype
-from torch.utils._pytree import tree_flatten, tree_map
+from torch.types import _device
+from torch.utils._pytree import tree_map

 from ._monkey_patch import _AliasATen, _DistCommMethod, _InplaceATen, _MaybeInplaceATen, _TorchOverrideableFactoryMethod

-__all__ = ['MetaTensor', 'MetaTensorMode']
+__all__ = ["MetaTensor", "MetaTensorMode"]

 def register_storage(r, data_ptr_fn=None):
...
@@ -28,8 +28,7 @@ def _normalize_tuple(x):
 # a hack of inplace execution in PyTorch
 def _assert_alias(func):
-    return func in (_AliasATen + _InplaceATen + _MaybeInplaceATen    # TODO: check if should be this aggressive
-                    )
+    return func in (_AliasATen + _InplaceATen + _MaybeInplaceATen)  # TODO: check if should be this aggressive

 class MetaTensor(torch.Tensor):
...
@@ -65,14 +64,15 @@ class MetaTensor(torch.Tensor):
             storage_offset=elem.storage_offset(),
             dtype=elem.dtype,
             layout=elem.layout,
-            device=device or (elem.device if elem.device.type != 'meta' else torch.device('cpu')),
-            requires_grad=requires_grad)  # deceive the frontend for aten selections
+            device=device or (elem.device if elem.device.type != "meta" else torch.device("cpu")),
+            requires_grad=requires_grad,
+        )  # deceive the frontend for aten selections
         r._tensor = elem
         # ...the real tensor is held as an element on the tensor.
         if not r._tensor.is_meta:
             val = elem.data_ptr()
             data_ptr_fn = lambda: val
-            r._tensor = r._tensor.to(torch.device('meta'))
+            r._tensor = r._tensor.to(torch.device("meta"))

         # only tensor not on `meta` should be copied to `meta`
         register_storage(r._tensor, data_ptr_fn)
...
@@ -81,7 +81,7 @@ class MetaTensor(torch.Tensor):
         return r

     def __repr__(self):
-        name = 'MetaParameter' if getattr(self, '_is_param', False) else 'MetaTensor'
+        name = "MetaParameter" if getattr(self, "_is_param", False) else "MetaTensor"
         if self.grad_fn:
             return f"{name}(..., size={tuple(self.shape)}, device='{self.device}', dtype={self.dtype}, grad_fn={self.grad_fn})"
         return f"{name}(..., size={tuple(self.shape)}, device='{self.device}', dtype={self.dtype})"
...
@@ -97,15 +97,15 @@ class MetaTensor(torch.Tensor):
                 x = x._tensor
             elif isinstance(x, torch.Tensor):
                 device = x.device
-                x = x.to(torch.device('meta'))
+                x = x.to(torch.device("meta"))
             return x

         args = tree_map(unwrap, args)
         kwargs = tree_map(unwrap, kwargs)

-        if 'device' in kwargs:
-            device = kwargs['device']
-            kwargs['device'] = torch.device('meta')
+        if "device" in kwargs:
+            device = kwargs["device"]
+            kwargs["device"] = torch.device("meta")

         # run aten for backend=CPU but actually on backend=Meta
         # here we detect whether or not the execution generates a physical copy
...
@@ -143,21 +143,21 @@ class MetaTensor(torch.Tensor):
             nonlocal device
             if isinstance(x, str) or isinstance(x, _device):
                 device = x
-                return torch.device('meta')
+                return torch.device("meta")
             return x

         elem = self._tensor.to(*tree_map(replace, args), **tree_map(replace, kwargs))
         return MetaTensor(elem, device=device)

     def cpu(self, *args, **kwargs):
-        if self.device.type == 'cpu':
+        if self.device.type == "cpu":
             return self.to(*args, **kwargs)
-        return self.to(*args, device='cpu', **kwargs)
+        return self.to(*args, device="cpu", **kwargs)

     def cuda(self, device=None, non_blocking=False):
         if device is not None:
             return self.to(device=device, non_blocking=non_blocking)
-        return self.to(device='cuda:0', non_blocking=non_blocking)
+        return self.to(device="cuda:0", non_blocking=non_blocking)

     def data_ptr(self):
         return self._tensor.data_ptr()
...
@@ -181,15 +181,13 @@ class MetaTensorMode(object):
         self.dist_overrides = {}  # override torch.distributed.xxx

     def __enter__(self):
-
         def _dummy(*args, **kwargs):
             pass

         def _new(*args, orig_new=torch.empty, **kwargs):
-            return MetaTensor(orig_new(*args, **{
-                **kwargs, 'device': 'meta'
-            }), device=kwargs.get('device', torch.device('cpu')))
+            return MetaTensor(
+                orig_new(*args, **{**kwargs, "device": "meta"}), device=kwargs.get("device", torch.device("cpu"))
+            )

         for func in _TorchOverrideableFactoryMethod:
             self.torch_overrides[func] = getattr(torch, func)
...
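
MetaTensorMode's __enter__ swaps torch factory functions for the _new wrapper above, so tensors "materialize" on the meta device while reporting a fake frontend device. A sketch of the intended use, assuming the import path of this file:

    import torch
    from colossalai._analyzer._subclasses.meta_tensor import MetaTensorMode

    with MetaTensorMode():
        x = torch.empty(64, 1024)  # intercepted factory -> MetaTensor, no real storage

    print(type(x))   # MetaTensor
    print(x.device)  # reported device is faked (cpu by default); data lives on 'meta'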
colossalai/_analyzer/fx/codegen.py  [view file @ 9e768b59]

-from typing import Any, Callable, Dict, Iterable, List, Tuple
+from typing import Any, Dict, List, Tuple

 import torch
...
@@ -22,7 +22,7 @@ from torch.fx.node import Argument, Node, _get_qualified_name, _type_repr, map_a
 import colossalai
 from colossalai.fx._compatibility import compatibility

-_register_custom_builtin('colossalai', 'import colossalai', colossalai)
+_register_custom_builtin("colossalai", "import colossalai", colossalai)

 def _gen_ckpt_fn_def(label, free_vars: List[str]) -> str:
...
@@ -43,17 +43,17 @@ def _gen_ckpt_usage(label, input_vars, output_vars, use_reentrant=True):
     """
     Generate the checkpoint function call code text
     """
-    outputs = ', '.join(output_vars)
-    inputs = ', '.join(input_vars)
-    return f'{outputs} = torch.utils.checkpoint.checkpoint(self.checkpoint_{label}, {inputs}, use_reentrant={use_reentrant})'
+    outputs = ", ".join(output_vars)
+    inputs = ", ".join(input_vars)
+    return f"{outputs} = torch.utils.checkpoint.checkpoint(self.checkpoint_{label}, {inputs}, use_reentrant={use_reentrant})"

 def _end_of_ckpt(node: Node, ckpt_level: int) -> bool:
     """
     Check if the node could end the ckpt region at `ckpt_level`
     """
-    if len(node.meta['info'].activation_checkpoint) > ckpt_level:
-        return node.meta['info'].activation_checkpoint[ckpt_level] is not None
+    if len(node.meta["info"].activation_checkpoint) > ckpt_level:
+        return node.meta["info"].activation_checkpoint[ckpt_level] is not None
     return True
...
@@ -94,8 +94,8 @@ def _find_nested_ckpt_regions(node_list: List[Node], ckpt_level: int = 0):
     current_region = None

     for idx, node in enumerate(node_list):
-        if len(node.meta['info'].activation_checkpoint) > ckpt_level:
-            act_ckpt_label = node.meta['info'].activation_checkpoint[ckpt_level]
+        if len(node.meta["info"].activation_checkpoint) > ckpt_level:
+            act_ckpt_label = node.meta["info"].activation_checkpoint[ckpt_level]

             # this activation checkpoint label is not set yet
             # meaning this is the first node of the activation ckpt region
...
@@ -131,13 +131,9 @@ def _find_nested_ckpt_regions(node_list: List[Node], ckpt_level: int = 0):
     return ckpt_regions

-def emit_ckpt_func(body,
-                   ckpt_func,
-                   node_list: List[Node],
-                   emit_node_func,
-                   delete_unused_value_func,
-                   ckpt_level=0,
-                   in_ckpt=False):
+def emit_ckpt_func(
+    body, ckpt_func, node_list: List[Node], emit_node_func, delete_unused_value_func, ckpt_level=0, in_ckpt=False
+):
     """Emit ckpt function in nested way

     Args:
...
@@ -156,12 +152,12 @@ def emit_ckpt_func(body,
     # label given by each layer, e.g. if you are currently at level (0, 1, 1)
     # the label will be '0_1_1'
-    label = "_".join([str(idx) for idx in node_list[0].meta['info'].activation_checkpoint[:ckpt_level + 1]])
+    label = "_".join([str(idx) for idx in node_list[0].meta["info"].activation_checkpoint[: ckpt_level + 1]])

     ckpt_fn_def = _gen_ckpt_fn_def(label, inputs)
-    ckpt_func.append(f'{ckpt_fn_def}\n')
+    ckpt_func.append(f"{ckpt_fn_def}\n")

     # if there is more level to fetch
-    if ckpt_level + 1 < max(map(lambda node: len(node.meta['info'].activation_checkpoint), node_list)):
+    if ckpt_level + 1 < max(map(lambda node: len(node.meta["info"].activation_checkpoint), node_list)):
         ckpt_regions = _find_nested_ckpt_regions(node_list, ckpt_level + 1)
         start_idx = [item[0] for item in ckpt_regions]
         end_idx = [item[1] for item in ckpt_regions]
...
@@ -174,33 +170,40 @@ def emit_ckpt_func(body,
                 break

             if node_idx in start_idx:
-                ckpt_node_list = node_list[node_idx:end_idx[start_idx.index(node_idx)] + 1]
-                emit_ckpt_func(ckpt_func, ckpt_func_buffer, ckpt_node_list, emit_node_func, delete_unused_value_func,
-                               ckpt_level + 1, True)
+                ckpt_node_list = node_list[node_idx : end_idx[start_idx.index(node_idx)] + 1]
+                emit_ckpt_func(
+                    ckpt_func,
+                    ckpt_func_buffer,
+                    ckpt_node_list,
+                    emit_node_func,
+                    delete_unused_value_func,
+                    ckpt_level + 1,
+                    True,
+                )
                 node_idx += len(ckpt_node_list)

             else:
                 node = node_list[node_idx]
                 emit_node_func(node, ckpt_func)
-                ckpt_func[-1] = '    ' + ckpt_func[-1]
+                ckpt_func[-1] = "    " + ckpt_func[-1]
                 delete_unused_value_func(node, ckpt_func)
                 node_idx += 1

-        ckpt_func.append('    ' + _gen_ckpt_output(outputs) + '\n\n')
+        ckpt_func.append("    " + _gen_ckpt_output(outputs) + "\n\n")
         ckpt_func += ckpt_func_buffer

     # last level
     else:
         for node in node_list:
             emit_node_func(node, ckpt_func)
-            ckpt_func[-1] = '    ' + ckpt_func[-1]
+            ckpt_func[-1] = "    " + ckpt_func[-1]
             delete_unused_value_func(node, ckpt_func)

-        ckpt_func.append('    ' + _gen_ckpt_output(outputs) + '\n\n')
+        ckpt_func.append("    " + _gen_ckpt_output(outputs) + "\n\n")

-    usage = _gen_ckpt_usage(label, inputs, outputs, False) + '\n'
+    usage = _gen_ckpt_usage(label, inputs, outputs, False) + "\n"
     if in_ckpt:
-        usage = '    ' + usage
+        usage = "    " + usage
     body.append(usage)
...
@@ -229,7 +232,7 @@ def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func,
         # process ckpt_regions
         if node_idx in start_idx:
-            ckpt_node_list = node_list[node_idx:end_idx[start_idx.index(node_idx)] + 1]
+            ckpt_node_list = node_list[node_idx : end_idx[start_idx.index(node_idx)] + 1]
             emit_ckpt_func(body, ckpt_func, ckpt_node_list, emit_node_func, delete_unused_value_func)
             node_idx += len(ckpt_node_list)
...
@@ -243,7 +246,6 @@ def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func,
 @compatibility(is_backward_compatible=True)
 class ActivationCheckpointCodeGen(CodeGen):
-
     def _gen_python_code(self, nodes, root_module: str, namespace: _Namespace) -> PythonCode:
         free_vars: List[str] = []
         body: List[str] = []
...
@@ -251,7 +253,7 @@ class ActivationCheckpointCodeGen(CodeGen):
         wrapped_fns: Dict[str, None] = {}

         # Wrap string in list to pass by reference
-        maybe_return_annotation: List[str] = ['']
+        maybe_return_annotation: List[str] = [""]

         def add_global(name_hint: str, obj: Any):
             """Add an obj to be tracked as a global.
...
@@ -281,16 +283,16 @@ class ActivationCheckpointCodeGen(CodeGen):
         def type_repr(o: Any):
             if o == ():
                 # Empty tuple is used for empty tuple type annotation Tuple[()]
-                return '()'
+                return "()"

             typename = _type_repr(o)

-            if hasattr(o, '__origin__'):
+            if hasattr(o, "__origin__"):
                 # This is a generic type, e.g. typing.List[torch.Tensor]
                 origin_type = _origin_type_map.get(o.__origin__, o.__origin__)
                 origin_typename = add_global(_type_repr(origin_type), origin_type)

-                if hasattr(o, '__args__'):
+                if hasattr(o, "__args__"):
                     # Assign global names for each of the inner type variables.
                     args = [type_repr(arg) for arg in o.__args__]
...
@@ -309,19 +311,18 @@ class ActivationCheckpointCodeGen(CodeGen):
             return add_global(typename, o)

         def _format_args(args: Tuple[Argument, ...], kwargs: Dict[str, Argument]) -> str:
-
             def _get_repr(arg):
                 # Handle NamedTuples (if it has `_fields`) via add_global.
-                if isinstance(arg, tuple) and hasattr(arg, '_fields'):
+                if isinstance(arg, tuple) and hasattr(arg, "_fields"):
                     qualified_name = _get_qualified_name(type(arg))
                     global_name = add_global(qualified_name, type(arg))
                     return f"{global_name}{repr(tuple(arg))}"
                 return repr(arg)

-            args_s = ', '.join(_get_repr(a) for a in args)
-            kwargs_s = ', '.join(f'{k} = {_get_repr(v)}' for k, v in kwargs.items())
+            args_s = ", ".join(_get_repr(a) for a in args)
+            kwargs_s = ", ".join(f"{k} = {_get_repr(v)}" for k, v in kwargs.items())
             if args_s and kwargs_s:
-                return f'{args_s}, {kwargs_s}'
+                return f"{args_s}, {kwargs_s}"
             return args_s or kwargs_s

         # Run through reverse nodes and record the first instance of a use
...
@@ -347,82 +348,94 @@ class ActivationCheckpointCodeGen(CodeGen):
             not used in the remainder of the code are freed and the memory usage
             of the code is optimal.
             """
-            if user.op == 'placeholder':
+            if user.op == "placeholder":
                 return
-            if user.op == 'output':
-                body.append('\n')
+            if user.op == "output":
+                body.append("\n")
                 return
             nodes_to_delete = user_to_last_uses.get(user, [])
             if len(nodes_to_delete):
-                to_delete_str = ' = '.join([repr(n) for n in nodes_to_delete] + ['None'])
-                body.append(f';  {to_delete_str}\n')
+                to_delete_str = " = ".join([repr(n) for n in nodes_to_delete] + ["None"])
+                body.append(f";  {to_delete_str}\n")
             else:
-                body.append('\n')
+                body.append("\n")

         # NOTE: we add a variable to distinguish body and ckpt_func
         def emit_node(node: Node, body):
-            maybe_type_annotation = '' if node.type is None else f' : {type_repr(node.type)}'
-            if node.op == 'placeholder':
+            maybe_type_annotation = "" if node.type is None else f" : {type_repr(node.type)}"
+            if node.op == "placeholder":
                 assert isinstance(node.target, str)
-                maybe_default_arg = '' if not node.args else f' = {repr(node.args[0])}'
-                free_vars.append(f'{node.target}{maybe_type_annotation}{maybe_default_arg}')
-                raw_name = node.target.replace('*', '')
+                maybe_default_arg = "" if not node.args else f" = {repr(node.args[0])}"
+                free_vars.append(f"{node.target}{maybe_type_annotation}{maybe_default_arg}")
+                raw_name = node.target.replace("*", "")
                 if raw_name != repr(node):
-                    body.append(f'{repr(node)} = {raw_name}\n')
+                    body.append(f"{repr(node)} = {raw_name}\n")
                 return
-            elif node.op == 'call_method':
+            elif node.op == "call_method":
                 assert isinstance(node.target, str)
-                body.append(f'{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.target)}'
-                            f'({_format_args(node.args[1:], node.kwargs)})')
+                body.append(
+                    f"{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.target)}"
+                    f"({_format_args(node.args[1:], node.kwargs)})"
+                )
                 return
-            elif node.op == 'call_function':
+            elif node.op == "call_function":
                 assert callable(node.target)
                 # pretty print operators
-                if node.target.__module__ == '_operator' and node.target.__name__ in magic_methods:
+                if node.target.__module__ == "_operator" and node.target.__name__ in magic_methods:
                     assert isinstance(node.args, tuple)
-                    body.append(f'{repr(node)}{maybe_type_annotation} = '
-                                f'{magic_methods[node.target.__name__].format(*(repr(a) for a in node.args))}')
+                    body.append(
+                        f"{repr(node)}{maybe_type_annotation} = "
+                        f"{magic_methods[node.target.__name__].format(*(repr(a) for a in node.args))}"
+                    )
                     return

                 # pretty print inplace operators; required for jit.script to work properly
                 # not currently supported in normal FX graphs, but generated by torchdynamo
-                if node.target.__module__ == '_operator' and node.target.__name__ in inplace_methods:
-                    body.append(f'{inplace_methods[node.target.__name__].format(*(repr(a) for a in node.args))};  '
-                                f'{repr(node)}{maybe_type_annotation} = {repr(node.args[0])}')
+                if node.target.__module__ == "_operator" and node.target.__name__ in inplace_methods:
+                    body.append(
+                        f"{inplace_methods[node.target.__name__].format(*(repr(a) for a in node.args))};  "
+                        f"{repr(node)}{maybe_type_annotation} = {repr(node.args[0])}"
+                    )
                     return

                 qualified_name = _get_qualified_name(node.target)
                 global_name = add_global(qualified_name, node.target)
                 # special case for getattr: node.args could be 2-argument or 3-argument
                 # 2-argument: attribute access; 3-argument: fall through to attrib function call with default value
-                if global_name == 'getattr' and \
-                   isinstance(node.args, tuple) and \
-                   isinstance(node.args[1], str) and \
-                   node.args[1].isidentifier() and \
-                   len(node.args) == 2:
-                    body.append(f'{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.args[1])}')
+                if (
+                    global_name == "getattr"
+                    and isinstance(node.args, tuple)
+                    and isinstance(node.args[1], str)
+                    and node.args[1].isidentifier()
+                    and len(node.args) == 2
+                ):
+                    body.append(
+                        f"{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.args[1])}"
+                    )
                     return
-                body.append(f'{repr(node)}{maybe_type_annotation} = {global_name}({_format_args(node.args, node.kwargs)})')
-                if node.meta.get('is_wrapped', False):
+                body.append(
+                    f"{repr(node)}{maybe_type_annotation} = {global_name}({_format_args(node.args, node.kwargs)})"
+                )
+                if node.meta.get("is_wrapped", False):
                     wrapped_fns.setdefault(global_name)
                 return
-            elif node.op == 'call_module':
+            elif node.op == "call_module":
                 assert isinstance(node.target, str)
-                body.append(f'{repr(node)}{maybe_type_annotation} = '
-                            f'{_format_target(root_module, node.target)}({_format_args(node.args, node.kwargs)})')
+                body.append(
+                    f"{repr(node)}{maybe_type_annotation} = "
+                    f"{_format_target(root_module, node.target)}({_format_args(node.args, node.kwargs)})"
+                )
                 return
-            elif node.op == 'get_attr':
+            elif node.op == "get_attr":
                 assert isinstance(node.target, str)
-                body.append(f'{repr(node)}{maybe_type_annotation} = {_format_target(root_module, node.target)}')
+                body.append(f"{repr(node)}{maybe_type_annotation} = {_format_target(root_module, node.target)}")
                 return
-            elif node.op == 'output':
+            elif node.op == "output":
                 if node.type is not None:
                     maybe_return_annotation[0] = f" -> {type_repr(node.type)}"
                 body.append(self.generate_output(node.args[0]))
                 return
-            raise NotImplementedError(f'node: {node.op} {node.target}')
+            raise NotImplementedError(f"node: {node.op} {node.target}")

         # Modified for activation checkpointing
         ckpt_func = []
...
@@ -432,13 +445,13 @@ class ActivationCheckpointCodeGen(CodeGen):
             # If the Graph has no non-placeholder nodes, no lines for the body
             # have been emitted. To continue to have valid Python code, emit a
             # single pass statement
-            body.append('pass\n')
+            body.append("pass\n")

         if len(wrapped_fns) > 0:
-            wrap_name = add_global('wrap', torch.fx.wrap)
-            wrap_stmts = '\n'.join([f'{wrap_name}("{name}")' for name in wrapped_fns])
+            wrap_name = add_global("wrap", torch.fx.wrap)
+            wrap_stmts = "\n".join([f'{wrap_name}("{name}")' for name in wrapped_fns])
         else:
-            wrap_stmts = ''
+            wrap_stmts = ""

         if self._body_transformer:
             body = self._body_transformer(body)
...
@@ -447,11 +460,11 @@ class ActivationCheckpointCodeGen(CodeGen):
             add_global(name, value)

         prologue = self.gen_fn_def(free_vars, maybe_return_annotation[0])
-        prologue = ''.join(ckpt_func) + prologue
+        prologue = "".join(ckpt_func) + prologue
         prologue = prologue

-        code = ''.join(body)
-        code = '\n'.join('    ' + line for line in code.split('\n'))
+        code = "".join(body)
+        code = "\n".join("    " + line for line in code.split("\n"))
         fn_code = f"""
 {wrap_stmts}

 {prologue}
...
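
To make the codegen concrete: _gen_ckpt_usage above emits one torch.utils.checkpoint.checkpoint call per region, and emit_ckpt_func writes a checkpoint_<label> helper holding that region's nodes. A hand-written sketch of the shape of the generated module (names hypothetical, not actual output of this class):

    import torch
    import torch.nn as nn
    import torch.utils.checkpoint

    class Toy(nn.Module):
        def __init__(self):
            super().__init__()
            self.linear1 = nn.Linear(8, 8)
            self.linear2 = nn.Linear(8, 8)

        # corresponds to the emitted checkpoint_<label> function
        def checkpoint_0(self, x):
            return torch.relu(self.linear1(x))

        # corresponds to the emitted forward: re-enter the region via checkpoint
        def forward(self, x):
            relu = torch.utils.checkpoint.checkpoint(self.checkpoint_0, x, use_reentrant=False)
            return self.linear2(relu)

    out = Toy()(torch.randn(2, 8, requires_grad=True))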
colossalai/_analyzer/fx/graph_module.py  [view file @ 9e768b59]

...
@@ -13,6 +13,7 @@ from torch.fx.graph import PythonCode
 try:
     from torch.fx.graph import _PyTreeCodeGen
+
     SUPPORT_PT_CODEGEN = True
 except ImportError:
     SUPPORT_PT_CODEGEN = False
...
@@ -24,7 +25,6 @@ from torch.nn.modules.module import _addindent
 # This is a copy of torch.fx.graph_module._WrappedCall.
 # It should be removed when we stop supporting torch < 1.12.0.
 class _WrappedCall:
-
     def __init__(self, cls, cls_call):
         self.cls = cls
         self.cls_call = cls_call
...
@@ -50,12 +50,14 @@ class _WrappedCall:
         # constituent substrings of the error message
         tb_repr = traceback.format_exc()
-        custom_msg = ("Call using an FX-traced Module, "
-                      f"line {err_lineno} of the traced Module's "
-                      "generated forward function:")
+        custom_msg = (
+            "Call using an FX-traced Module, "
+            f"line {err_lineno} of the traced Module's "
+            "generated forward function:"
+        )
-        before_err = "".join(all_src_lines[err_lineno - 2:err_lineno])
+        before_err = "".join(all_src_lines[err_lineno - 2 : err_lineno])
         marker = "~" * err_line_len + "~~~ <--- HERE"
-        err_and_after_err = "\n".join(all_src_lines[err_lineno:err_lineno + 2])
+        err_and_after_err = "\n".join(all_src_lines[err_lineno : err_lineno + 2])

         # joined message
         return "\n".join([tb_repr, custom_msg, before_err, marker, err_and_after_err])
...
@@ -68,8 +70,11 @@ class _WrappedCall:
             return super(self.cls, obj).__call__(*args, **kwargs)  # type: ignore[misc]
         except Exception as e:
             assert e.__traceback__
-            topmost_framesummary: traceback.FrameSummary = \
-                traceback.StackSummary.extract(traceback.walk_tb(e.__traceback__))[-1]  # type: ignore[arg-type]
+            topmost_framesummary: traceback.FrameSummary = traceback.StackSummary.extract(
+                traceback.walk_tb(e.__traceback__)
+            )[-1]  # type: ignore[arg-type]
             if "eval_with_key" in topmost_framesummary.filename:
                 print(_WrappedCall._generate_error_message(topmost_framesummary), file=sys.stderr)
                 raise e.with_traceback(None)
...
@@ -99,10 +104,9 @@ class ColoGraphModule(torch.fx.GraphModule):
     code.
     """

-    def __init__(self,
-                 root: Union[torch.nn.Module, Dict[str, Any]],
-                 graph: torch.fx.Graph,
-                 class_name: str = 'GraphModule'):
+    def __init__(
+        self, root: Union[torch.nn.Module, Dict[str, Any]], graph: torch.fx.Graph, class_name: str = "GraphModule"
+    ):
         super().__init__(root, graph, class_name)

     def bind(self, ckpt_def, globals):
...
@@ -134,7 +138,7 @@ class ColoGraphModule(torch.fx.GraphModule):
         if SUPPORT_PT_CODEGEN and isinstance(self._graph._codegen, _PyTreeCodeGen):
             self._in_spec = self._graph._codegen.pytree_info.in_spec
             self._out_spec = self._graph._codegen.pytree_info.out_spec
-        python_code = self._graph.python_code(root_module='self')
+        python_code = self._graph.python_code(root_module="self")
         self._code = python_code.src

         # To split ckpt functions code and forward code
...
@@ -157,7 +161,7 @@ class ColoGraphModule(torch.fx.GraphModule):
         # bypass patching of torch.nn.Module.__call__ done while symbolic tracing.
         cls_call = cls.__call__ if "__call__" in vars(cls) else None

-        if '_wrapped_call' not in vars(cls):
+        if "_wrapped_call" not in vars(cls):
             cls._wrapped_call = _WrappedCall(cls, cls_call)  # type: ignore[attr-defined]

         def call_wrapped(self, *args, **kwargs):
...
@@ -182,7 +186,7 @@ class ColoGraphModule(torch.fx.GraphModule):
         """
         folder = Path(folder)
         Path(folder).mkdir(exist_ok=True)
-        torch.save(self.state_dict(), folder / 'state_dict.pt')
+        torch.save(self.state_dict(), folder / "state_dict.pt")
         tab = " " * 4

         # we add import colossalai here
...
@@ -208,10 +212,10 @@ class {module_name}(torch.nn.Module):
         for module_name, module in self.named_children():
             module_str = _gen_model_repr(module_name, module)
             if module_str is None:
-                module_file = folder / f'{module_name}.pt'
+                module_file = folder / f"{module_name}.pt"
                 torch.save(module, module_file)
                 blobified_modules.append(module_name)
-                module_repr = module.__repr__().replace('\r', ' ').replace('\n', ' ')
+                module_repr = module.__repr__().replace("\r", " ").replace("\n", " ")
                 module_str = f"torch.load(r'{module_file}') # {module_repr}"
             model_str += f"{tab*2}self.{module_name} = {module_str}\n"
...
@@ -228,12 +232,14 @@ class {module_name}(torch.nn.Module):
         model_str += f"{tab*2}self.load_state_dict(torch.load(r'{folder}/state_dict.pt'))\n"
         model_str += f"{_addindent(self.code, 4)}\n"

-        module_file = folder / 'module.py'
+        module_file = folder / "module.py"
         module_file.write_text(model_str)

-        init_file = folder / '__init__.py'
-        init_file.write_text('from .module import *')
+        init_file = folder / "__init__.py"
+        init_file.write_text("from .module import *")

         if len(blobified_modules) > 0:
-            warnings.warn("Was not able to save the following children modules as reprs -"
-                          f"saved as pickled files instead: {blobified_modules}")
+            warnings.warn(
+                "Was not able to save the following children modules as reprs -"
+                f"saved as pickled files instead: {blobified_modules}"
+            )
colossalai/_analyzer/fx/node_util.py  [view file @ 9e768b59]

 from dataclasses import dataclass, field
-from typing import Callable, ClassVar, Dict, List, Optional, Tuple, Union
+from typing import Dict, List, Optional, Tuple, Union

 import torch
-from torch.autograd.profiler_util import _format_memory, _format_time
-from torch.fx import Graph, GraphModule, Node
+from torch.autograd.profiler_util import _format_memory
+from torch.fx import Node

 from colossalai._analyzer.envs import MeshConfig
...
@@ -85,7 +85,7 @@ class MetaInfo:
     node: Node

     # directory
-    mod_dir: str = ''
+    mod_dir: str = ""

     # ctx[data_ptr] = Tensor
     # mark the storage for ctx.save_for_backward
...
@@ -114,27 +114,27 @@ class MetaInfo:
     # ============================= Invariant ==================================
     activation_checkpoint: Tuple[torch.Tensor] = ()  # (region_0, region_1, ...) support nested codegen
     to_offload: Optional[bool] = False
-    sharding_spec: str = 'RR'
+    sharding_spec: str = "RR"

     def __new__(cls, node: Node, **kwargs):
         orig_init = cls.__init__

         # if initialized, return the existing one
         # should disable the __init__ function
-        if node.meta.get('info', None) is not None:
+        if node.meta.get("info", None) is not None:

             def _dummy(self, *args, **kwargs):
-                if getattr(self, '_is_init', False):
+                if getattr(self, "_is_init", False):
                     self._is_init = True
                     orig_init(self, *args, **kwargs)
                 cls.__init__ = orig_init

             cls.__init__ = _dummy
-            return node.meta['info']
+            return node.meta["info"]
         return super().__new__(cls)

     def __post_init__(self):
-        self.node.meta['info'] = self
+        self.node.meta["info"] = self

     @property
     def fwd_time(self, tflops: float = MeshConfig.TFLOPS, bandwidth: float = MeshConfig.BANDWIDTH):
...
@@ -188,24 +188,26 @@ class MetaInfo:
         return compute_size_in_bytes(self.inputs)

     def __repr__(self):
-        s = f'Node {self.node.name}'
+        s = f"Node {self.node.name}"
         if self.parameters:
-            s += f'\n\thas parameter of size {_format_memory(self.param_size)}'
+            s += f"\n\thas parameter of size {_format_memory(self.param_size)}"
         if self.buffers:
-            s += f'\n\thas buffer of size {_format_memory(self.buffer_size)}'
+            s += f"\n\thas buffer of size {_format_memory(self.buffer_size)}"
         if self.output_size:
-            s += f'\n\thas output activation of size {_format_memory(self.output_size)}'
+            s += f"\n\thas output activation of size {_format_memory(self.output_size)}"
         # if self.total_size:
         #     s += f'\n\thas total activation of size {_format_memory(self.total_size)}'
         if self.temp_size:
-            s += f'\n\thas temp activation of size {_format_memory(self.temp_size)}'
+            s += f"\n\thas temp activation of size {_format_memory(self.temp_size)}"
         if self.backward_size:
-            s += f'\n\thas backward activation of size {_format_memory(self.backward_size)}'
-        s += f'\n\tfwd_flop = {self.fwd_flop}' \
-             f'\n\tbwd_flop = {self.bwd_flop}' \
-             f'\n\tfwd_comm = {self.fwd_comm}' \
-             f'\n\tbwd_comm = {self.bwd_comm}' \
-             f'\n\tto_recompute = {self.to_recompute}' \
-             f'\n\tto_offload = {self.to_offload}' \
-             f'\n\tsharding_spec = {self.sharding_spec}'
+            s += f"\n\thas backward activation of size {_format_memory(self.backward_size)}"
+        s += (
+            f"\n\tfwd_flop = {self.fwd_flop}"
+            f"\n\tbwd_flop = {self.bwd_flop}"
+            f"\n\tfwd_comm = {self.fwd_comm}"
+            f"\n\tbwd_comm = {self.bwd_comm}"
+            f"\n\tto_recompute = {self.to_recompute}"
+            f"\n\tto_offload = {self.to_offload}"
+            f"\n\tsharding_spec = {self.sharding_spec}"
+        )
         return s
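
The __new__ override above makes MetaInfo a per-node singleton: constructing it again for the same FX node hands back the cached node.meta["info"]. A small sketch of that contract, assuming the import path of this file:

    import torch
    from torch.fx import symbolic_trace
    from colossalai._analyzer.fx.node_util import MetaInfo

    gm = symbolic_trace(torch.nn.Linear(4, 4))
    node = next(iter(gm.graph.nodes))

    a = MetaInfo(node)
    b = MetaInfo(node)  # returns the cached node.meta["info"]
    assert a is b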
colossalai/_analyzer/fx/passes/graph_profile.py
View file @
9e768b59
from
typing
import
Any
,
Dict
,
Iterator
,
List
,
Optional
,
Tuple
,
Union
from
typing
import
Any
,
Dict
,
Iterator
,
List
,
Optional
,
Tuple
import
torch
import
torch.fx
from
torch.autograd.profiler_util
import
_format_memory
,
_format_time
from
torch.autograd.profiler_util
import
_format_memory
from
torch.fx
import
GraphModule
from
torch.fx.node
import
Argument
,
Node
,
Target
...
...
@@ -13,14 +13,14 @@ from colossalai._analyzer.fx.node_util import MetaInfo
def
_format_flops
(
flops
:
float
)
->
str
:
"""Returns a formatted FLOP size string"""
if
flops
>
1e12
:
return
f
'
{
flops
/
1e12
:.
2
f
}
TFLOPs
'
return
f
"
{
flops
/
1e12
:.
2
f
}
TFLOPs
"
elif
flops
>
1e9
:
return
f
'
{
flops
/
1e9
:.
2
f
}
GFLOPs
'
return
f
"
{
flops
/
1e9
:.
2
f
}
GFLOPs
"
elif
flops
>
1e6
:
return
f
'
{
flops
/
1e6
:.
2
f
}
MFLOPs
'
return
f
"
{
flops
/
1e6
:.
2
f
}
MFLOPs
"
elif
flops
>
1e3
:
return
f
'
{
flops
/
1e3
:.
2
f
}
kFLOPs
'
return
f
'
{
flops
}
FLOPs
'
return
f
"
{
flops
/
1e3
:.
2
f
}
kFLOPs
"
return
f
"
{
flops
}
FLOPs
"
def
_denormalize_tuple
(
t
:
Tuple
[
int
,
...])
->
Tuple
[
int
,
...]:
...
...
@@ -42,10 +42,11 @@ class GraphProfiler(torch.fx.Interpreter):
Fetch shape argument from ``ShapeProp`` without re-executing
the ``GraphModule`` from scratch.
"""
_profileable
=
[
'
call_function
'
,
'
call_module
'
,
'
call_method
'
,
"
call_function
"
,
"
call_module
"
,
"
call_method
"
,
]
def
__init__
(
self
,
module
:
GraphModule
,
garbage_collect_values
:
bool
=
True
):
...
...
@@ -77,14 +78,13 @@ class GraphProfiler(torch.fx.Interpreter):
self
.
args_iter
:
Iterator
[
Any
]
=
iter
(
args
)
for
node
in
self
.
module
.
graph
.
nodes
:
self
.
run_node
(
node
)
# No need to store.
if
self
.
garbage_collect_values
:
for
to_delete
in
self
.
user_to_last_uses
.
get
(
node
,
[]):
del
self
.
env
[
to_delete
]
if
node
.
op
==
'
output
'
:
if
node
.
op
==
"
output
"
:
output_val
=
self
.
env
[
node
]
return
self
.
module
.
graph
.
process_outputs
(
output_val
)
if
enable_io_processing
else
output_val
...
...
@@ -133,9 +133,11 @@ class GraphProfiler(torch.fx.Interpreter):
try
:
from
tabulate
import
tabulate
except
ImportError
:
print
(
"`summary` relies on the library `tabulate`, "
print
(
"`summary` relies on the library `tabulate`, "
"which could not be found on this machine. Run `pip "
"install tabulate` to install the library."
)
"install tabulate` to install the library."
)
# Build up a list of summary information for each node
node_summaries
:
List
[
List
[
Any
]]
=
[]
...
...
@@ -145,7 +147,8 @@ class GraphProfiler(torch.fx.Interpreter):
node
:
Node
n_info
=
MetaInfo
(
node
)
last_n_info
=
last_n_info
or
n_info
node_summaries
.
append
([
node_summaries
.
append
(
[
node
.
op
,
str
(
node
),
_format_memory
(
n_info
.
accumulate_size
),
...
...
@@ -156,25 +159,26 @@ class GraphProfiler(torch.fx.Interpreter):
_format_memory
(
n_info
.
backward_size
),
_format_flops
(
n_info
.
fwd_flop
),
_format_flops
(
n_info
.
bwd_flop
),
])
]
)
last_n_info
=
n_info
# Use the ``tabulate`` library to create a well-formatted table
# presenting our summary information
headers
:
List
[
str
]
=
[
'
Op type
'
,
'
Op
'
,
'
Accumulate size
'
,
'
Incremental size
'
,
'
Output size
'
,
'
Temp size
'
,
'
Param size
'
,
'
Backward size
'
,
'
Fwd FLOPs
'
,
'
Bwd FLOPs
'
,
"
Op type
"
,
"
Op
"
,
"
Accumulate size
"
,
"
Incremental size
"
,
"
Output size
"
,
"
Temp size
"
,
"
Param size
"
,
"
Backward size
"
,
"
Fwd FLOPs
"
,
"
Bwd FLOPs
"
,
]
return
tabulate
(
node_summaries
,
headers
=
headers
,
stralign
=
'
right
'
)
return
tabulate
(
node_summaries
,
headers
=
headers
,
stralign
=
"
right
"
)
class
CommunicationProfiler
(
GraphProfiler
):
...
...
@@ -222,6 +226,7 @@ class FlopProfiler(GraphProfiler):
>>> def my_fn_flop_count_impl(*args, **kwargs):
>>> return 0, 0
"""
_custom_flop_count_impl
=
{}
def
run_node
(
self
,
n
:
torch
.
fx
.
Node
)
->
Any
:
...
...
@@ -246,11 +251,13 @@ class FlopProfiler(GraphProfiler):
                 (
                     n_info.fwd_flop,
                     n_info.bwd_flop,
-                ) = getattr(self, n.op)(n.target, args, kwargs)
+                ) = getattr(self, n.op)(n.target, args, kwargs)
             except Exception as e:
                 raise RuntimeError(
-                    f'Error {str(e)} occurred when profiling node {n}, node.target = {n.target}. '
-                    f'Please refer to function\'s docstring to register the relevant profile_impl for this node!'
+                    f"Error {str(e)} occurred when profiling node {n}, node.target = {n.target}. "
+                    f"Please refer to function's docstring to register the relevant profile_impl for this node!"
                 ) from e

         # retain the autograd graph
...
@@ -259,7 +266,7 @@ class FlopProfiler(GraphProfiler):

         return _denormalize_tuple(n_info.outputs)

-    def call_function(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
+    def call_function(self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
         """
         Execute a ``call_function`` node and return the profiling result.
         Dispatch to ``_custom_flop_count_impl`` if ``call_function`` should be
...
@@ -283,7 +290,7 @@ class FlopProfiler(GraphProfiler):
         else:
             return flop_count(target, *args, **kwargs)

-    def call_method(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
+    def call_method(self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
         """
         Execute a ``call_method`` node and return the profiling result.
...
@@ -301,7 +308,7 @@ class FlopProfiler(GraphProfiler):
         assert isinstance(target, str)
         return flop_count(getattr(torch.Tensor, target), *args, **kwargs)

-    def call_module(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
+    def call_module(self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
         """
         Execute a ``call_module`` node and return the profiling result.
...
@@ -336,7 +343,8 @@ def graph_profile_pass(module: GraphModule, *args, verbose=False) -> GraphModule
     Returns:
         GraphModule: The same GraphModule with profiling information
     """
-    for profiler_cls in (FlopProfiler,
+    for profiler_cls in (
+        FlopProfiler,
         # CommunicationProfiler,  # TODO: add communication profiling
     ):
         profiler = profiler_cls(module)
...
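A minimal sketch of how these passes are meant to be chained; the import paths and call conventions below are assumptions pieced together from the imports visible elsewhere in this commit, not verbatim API:

    import torch
    from colossalai._analyzer.fx.tracer.symbolic_trace import symbolic_trace
    from colossalai._analyzer.fx.passes import shape_prop_pass, graph_profile_pass

    model = torch.nn.Linear(8, 8)
    data = torch.rand(4, 8)
    gm = symbolic_trace(model, meta_args={"input": data})  # trace without real compute
    shape_prop_pass(gm, data)     # annotate every node with shape info first
    graph_profile_pass(gm, data)  # then attach fwd/bwd FLOP counts to each node's MetaInfo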
colossalai/_analyzer/fx/passes/shape_prop.py View file @ 9e768b59
...
@@ -54,7 +54,7 @@ def _current_device(module):
     try:
         return next(module.parameters()).device
     except StopIteration:
-        return torch.device('cpu')
+        return torch.device("cpu")


 @compatibility(is_backward_compatible=False)
...
@@ -90,6 +90,7 @@ class ShapeProp(torch.fx.Interpreter):
         >>>     # do something here
         >>>     return torch.empty(output_shape, device=output_device)
     """
+
     _custom_dispatch_func = {}
     _mode = MetaTensorMode()
...
@@ -115,15 +116,14 @@ class ShapeProp(torch.fx.Interpreter):
             r = getattr(self, n.op)(n.target, args, kwargs)

         def unwrap_fn(elem):
             def _convert_meta(t: torch.Tensor):
-                if t.device == 'meta':
+                if t.device == "meta":
                     return t
                 else:
-                    return t.to('meta')
+                    return t.to("meta")

             if isinstance(elem, MetaTensor):
-                if getattr(self, '_is_param', False):
+                if getattr(self, "_is_param", False):
                     return torch.nn.Parameter(_convert_meta(elem._tensor))
                 return _convert_meta(elem._tensor)
...
@@ -139,21 +139,24 @@ class ShapeProp(torch.fx.Interpreter):
         n_info = MetaInfo(n)
         n_info.outputs = _normalize_tuple(r)

-        if n.op == 'call_module':
+        if n.op == "call_module":
             submod = self.fetch_attr(n.target)
             n_info.parameters.update({k: MetaTensor(v) for k, v in submod.named_parameters()})
             n_info.buffers.update({k: MetaTensor(v) for k, v in submod.named_buffers()})

         else:
-            n_info.parameters.update({
-                k.name: MetaTensor(v) for k, v in zip(n.args, args)
-                if isinstance(k, torch.fx.Node) and isinstance(v, torch.nn.Parameter)
-            })
+            n_info.parameters.update(
+                {
+                    k.name: MetaTensor(v)
+                    for k, v in zip(n.args, args)
+                    if isinstance(k, torch.fx.Node) and isinstance(v, torch.nn.Parameter)
+                }
+            )
             n_info.parameters.update({k: MetaTensor(v) for k, v in kwargs.items() if isinstance(v, torch.nn.Parameter)})

-        n_info.inputs = tuple(v for v in args if is_pure_tensor(v)) + \
-            tuple(v for v in kwargs.values() if is_pure_tensor(v))
+        n_info.inputs = tuple(v for v in args if is_pure_tensor(v)) + tuple(
+            v for v in kwargs.values() if is_pure_tensor(v)
+        )

         # align with SPMD
         if isinstance(r, (tuple, list)):
...
@@ -168,7 +171,7 @@ class ShapeProp(torch.fx.Interpreter):
             n_info.is_alias = _normalize_tuple(tree_map(crit, n_info.outputs))
         return r

-    def call_function(self, target: 'Target', args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> Any:
+    def call_function(self, target: "Target", args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> Any:
         """
         Execute a ``call_function`` node and return the result.
         If the target of ``Node`` is registered with ``@register_shape_impl``,
...
@@ -197,7 +200,7 @@ class ShapeProp(torch.fx.Interpreter):
         else:
             return res

-    def call_method(self, target: 'Target', args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> Any:
+    def call_method(self, target: "Target", args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> Any:
         """
         Execute a ``call_method`` node and return the result.
...
@@ -218,7 +221,8 @@ class ShapeProp(torch.fx.Interpreter):
         convert_to_parameter = False
-        if target_method in (torch.Tensor.view, torch.Tensor.transpose) and isinstance(args[0], torch.nn.parameter.Parameter):
+        if target_method in (torch.Tensor.view, torch.Tensor.transpose) and isinstance(
+            args[0], torch.nn.parameter.Parameter
+        ):
             convert_to_parameter = True
         # Execute the method and return the result
         assert isinstance(target, str)
...
colossalai/_analyzer/fx/symbolic_profile.py View file @ 9e768b59

-import torch
-import torch.fx
 from torch.fx import GraphModule

 from .passes import ShapeProp, graph_profile_pass, shape_prop_pass
...
@@ -7,7 +5,6 @@ from .passes.graph_profile import FlopProfiler
 def register_flop_count_impl(func):
     def wrapper(impl):
         FlopProfiler._custom_flop_count_impl[func] = impl
         return impl
...
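As a hypothetical illustration of the decorator above (the target op and its zero-cost counts are invented; the returned `(fwd_flop, bwd_flop)` pair follows the pattern quoted in the `FlopProfiler` docstring earlier in this diff):

    import torch
    from colossalai._analyzer.fx.symbolic_profile import register_flop_count_impl

    @register_flop_count_impl(torch.sigmoid)
    def sigmoid_flop_count_impl(*args, **kwargs):
        # pretend the op is free: return (fwd_flop, bwd_flop)
        return 0, 0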
@@ -16,7 +13,6 @@ def register_flop_count_impl(func):
 def register_shape_impl(func):
     def wrapper(impl):
         ShapeProp._custom_dispatch_func[func] = impl
         return impl
...
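Similarly, a sketch of a custom shape dispatch, following the `torch.empty(output_shape, device=output_device)` pattern quoted in the `ShapeProp` docstring earlier in this diff (the target op is illustrative):

    import torch
    import torch.nn.functional as F
    from colossalai._analyzer.fx.symbolic_profile import register_shape_impl

    @register_shape_impl(F.relu)
    def relu_shape_impl(x, inplace=False):
        # an elementwise op preserves shape and dtype; stay on the meta device
        return torch.empty_like(x, device="meta")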
colossalai/_analyzer/fx/tracer/bias_addition.py View file @ 9e768b59
...
@@ -12,7 +12,7 @@ from .tracer import register_tracer_impl
 __all__ = []


-@register_tracer_impl(F.linear, name='_bias_addition_impl')
+@register_tracer_impl(F.linear, name="_bias_addition_impl")
 def linear_impl(input, weight, bias=None):
     if bias is None:
         return F.linear(input, weight)
...
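The substitution is the trivial identity `F.linear(x, w, b) == F.linear(x, w) + b`; its point is that, under `bias_addition_split`, the bias addition shows up as its own node in the traced graph instead of being fused inside the linear call. A quick numeric check:

    import torch
    import torch.nn.functional as F

    x, w, b = torch.rand(4, 8), torch.rand(3, 8), torch.rand(3)
    assert torch.allclose(F.linear(x, w, b), F.linear(x, w) + b)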
@@ -20,116 +20,130 @@ def linear_impl(input, weight, bias=None):
         return F.linear(input, weight) + bias


-@register_tracer_impl(F.conv1d, name='_bias_addition_impl')
+@register_tracer_impl(F.conv1d, name="_bias_addition_impl")
 def conv1d_impl(input, weight, bias=None, stride=_single(1), padding=_single(0), dilation=_single(1), groups=1):
     if bias is None:
         return F.conv1d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups)
     else:
-        return F.conv1d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups) + bias.reshape((-1, 1))
+        return F.conv1d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups) + bias.reshape(
+            (-1, 1)
+        )


-@register_tracer_impl(F.conv2d, name='_bias_addition_impl')
+@register_tracer_impl(F.conv2d, name="_bias_addition_impl")
 def conv2d_impl(input, weight, bias=None, stride=_pair(1), padding=_pair(0), dilation=_pair(1), groups=1):
     if bias is None:
         return F.conv2d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups)
     else:
-        return F.conv2d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups) + bias.reshape((-1, 1, 1))
+        return F.conv2d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups) + bias.reshape(
+            (-1, 1, 1)
+        )


-@register_tracer_impl(F.conv3d, name='_bias_addition_impl')
+@register_tracer_impl(F.conv3d, name="_bias_addition_impl")
 def conv3d_impl(input, weight, bias=None, stride=_triple(1), padding=_triple(0), dilation=_triple(1), groups=1):
     if bias is None:
         return F.conv3d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups)
     else:
-        return F.conv3d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups) + bias.reshape((-1, 1, 1, 1))
+        return F.conv3d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups) + bias.reshape(
+            (-1, 1, 1, 1)
+        )


-@register_tracer_impl(F.conv_transpose1d, name='_bias_addition_impl')
-def conv_transpose1d_impl(input,
-                          weight,
-                          bias=None,
-                          stride=_single(1),
-                          padding=_single(0),
-                          output_padding=_single(0),
-                          groups=1,
-                          dilation=_single(1)):
+@register_tracer_impl(F.conv_transpose1d, name="_bias_addition_impl")
+def conv_transpose1d_impl(
+    input,
+    weight,
+    bias=None,
+    stride=_single(1),
+    padding=_single(0),
+    output_padding=_single(0),
+    groups=1,
+    dilation=_single(1),
+):
     if bias is None:
-        return F.conv_transpose1d(input,
-                                  weight,
-                                  stride=stride,
-                                  padding=padding,
-                                  output_padding=output_padding,
-                                  groups=groups,
-                                  dilation=dilation)
+        return F.conv_transpose1d(
+            input,
+            weight,
+            stride=stride,
+            padding=padding,
+            output_padding=output_padding,
+            groups=groups,
+            dilation=dilation,
+        )
     else:
-        return F.conv_transpose1d(input,
-                                  weight,
-                                  stride=stride,
-                                  padding=padding,
-                                  output_padding=output_padding,
-                                  groups=groups,
-                                  dilation=dilation) + bias.reshape((-1, 1))
+        return F.conv_transpose1d(
+            input,
+            weight,
+            stride=stride,
+            padding=padding,
+            output_padding=output_padding,
+            groups=groups,
+            dilation=dilation,
+        ) + bias.reshape((-1, 1))


-@register_tracer_impl(F.conv_transpose2d, name='_bias_addition_impl')
-def conv_transpose2d_impl(input, weight, bias=None, stride=_pair(1), padding=_pair(0), output_padding=_pair(0), groups=1, dilation=_pair(1)):
+@register_tracer_impl(F.conv_transpose2d, name="_bias_addition_impl")
+def conv_transpose2d_impl(
+    input, weight, bias=None, stride=_pair(1), padding=_pair(0), output_padding=_pair(0), groups=1, dilation=_pair(1)
+):
     if bias is None:
-        return F.conv_transpose2d(input,
-                                  weight,
-                                  stride=stride,
-                                  padding=padding,
-                                  output_padding=output_padding,
-                                  groups=groups,
-                                  dilation=dilation)
+        return F.conv_transpose2d(
+            input,
+            weight,
+            stride=stride,
+            padding=padding,
+            output_padding=output_padding,
+            groups=groups,
+            dilation=dilation,
+        )
     else:
-        return F.conv_transpose2d(input,
-                                  weight,
-                                  stride=stride,
-                                  padding=padding,
-                                  output_padding=output_padding,
-                                  groups=groups,
-                                  dilation=dilation) + bias.reshape((-1, 1, 1))
+        return F.conv_transpose2d(
+            input,
+            weight,
+            stride=stride,
+            padding=padding,
+            output_padding=output_padding,
+            groups=groups,
+            dilation=dilation,
+        ) + bias.reshape((-1, 1, 1))


-@register_tracer_impl(F.conv_transpose3d, name='_bias_addition_impl')
-def conv_transpose3d_impl(input,
-                          weight,
-                          bias=None,
-                          stride=_triple(1),
-                          padding=_triple(0),
-                          output_padding=_triple(0),
-                          groups=1,
-                          dilation=_triple(1)):
+@register_tracer_impl(F.conv_transpose3d, name="_bias_addition_impl")
+def conv_transpose3d_impl(
+    input,
+    weight,
+    bias=None,
+    stride=_triple(1),
+    padding=_triple(0),
+    output_padding=_triple(0),
+    groups=1,
+    dilation=_triple(1),
+):
     if bias is None:
-        return F.conv_transpose3d(input,
-                                  weight,
-                                  stride=stride,
-                                  padding=padding,
-                                  output_padding=output_padding,
-                                  groups=groups,
-                                  dilation=dilation)
+        return F.conv_transpose3d(
+            input,
+            weight,
+            stride=stride,
+            padding=padding,
+            output_padding=output_padding,
+            groups=groups,
+            dilation=dilation,
+        )
     else:
-        return F.conv_transpose3d(input,
-                                  weight,
-                                  stride=stride,
-                                  padding=padding,
-                                  output_padding=output_padding,
-                                  groups=groups,
-                                  dilation=dilation) + bias.reshape((-1, 1, 1, 1))
+        return F.conv_transpose3d(
+            input,
+            weight,
+            stride=stride,
+            padding=padding,
+            output_padding=output_padding,
+            groups=groups,
+            dilation=dilation,
+        ) + bias.reshape((-1, 1, 1, 1))
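The `reshape` calls exist only so the 1-D bias can broadcast across the trailing spatial dimensions: for `conv2d`, a bias of shape `(C,)` becomes `(C, 1, 1)` and adds across an `(N, C, H, W)` output. A quick check of the split form:

    import torch
    import torch.nn.functional as F

    x, w, b = torch.rand(2, 3, 8, 8), torch.rand(5, 3, 3, 3), torch.rand(5)
    fused = F.conv2d(x, w, b)
    split = F.conv2d(x, w) + b.reshape((-1, 1, 1))
    assert torch.allclose(fused, split)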
-@register_tracer_impl(torch.addmm, name='_bias_addition_impl')
-@register_tracer_impl(torch.Tensor.addmm, name='_bias_addition_impl')
+@register_tracer_impl(torch.addmm, name="_bias_addition_impl")
+@register_tracer_impl(torch.Tensor.addmm, name="_bias_addition_impl")
 def addmm_impl(input, mat1, mat2, beta=1, alpha=1):
     if alpha != 1 and beta != 1:
         return F.linear(mat1, mat2.transpose(0, 1)) * alpha + input * beta
...
@@ -141,8 +155,8 @@ def addmm_impl(input, mat1, mat2, beta=1, alpha=1):
         return F.linear(mat1, mat2.transpose(0, 1)) + input


-@register_tracer_impl(torch.addbmm, name='_bias_addition_impl')
-@register_tracer_impl(torch.Tensor.addbmm, name='_bias_addition_impl')
+@register_tracer_impl(torch.addbmm, name="_bias_addition_impl")
+@register_tracer_impl(torch.Tensor.addbmm, name="_bias_addition_impl")
 def addbmm_impl(input, batch1, batch2, beta=1, alpha=1):
     if alpha != 1 and beta != 1:
         return torch.bmm(batch1, batch2.transpose(1, 2)) * alpha + input * beta
...
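Since `F.linear(mat1, mat2.transpose(0, 1))` is exactly `mat1 @ mat2`, the `addmm` rewrite above reproduces `torch.addmm`'s `beta * input + alpha * (mat1 @ mat2)`:

    import torch
    import torch.nn.functional as F

    inp, m1, m2 = torch.rand(4, 6), torch.rand(4, 5), torch.rand(5, 6)
    ref = torch.addmm(inp, m1, m2, beta=2, alpha=3)
    assert torch.allclose(ref, F.linear(m1, m2.transpose(0, 1)) * 3 + inp * 2)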
colossalai/_analyzer/fx/tracer/custom_leaf_module.py View file @ 9e768b59
...
@@ -4,6 +4,7 @@ from .tracer import register_leaf_module, register_leaf_module_impl

 try:
     import apex
+
     register_leaf_module(apex.normalization.FusedLayerNorm)
     register_leaf_module(apex.normalization.FusedRMSNorm)
     register_leaf_module(apex.normalization.MixedFusedLayerNorm)
...
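Registering a class as a leaf stops the tracer from recursing into its `forward`, which is why fused kernels like the apex norms above are natural candidates. A hypothetical user-side registration (the module class is invented; the import path follows the `from .tracer import register_leaf_module` used in this file):

    import torch
    from colossalai._analyzer.fx.tracer.tracer import register_leaf_module

    class MyFusedBlock(torch.nn.Module):
        def forward(self, x):
            return x * 2.0

    register_leaf_module(MyFusedBlock)  # traced as a single call_module node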
colossalai/_analyzer/fx/tracer/proxy.py View file @ 9e768b59
 import operator
-from typing import Any, Callable, Dict, Optional, Set, Union
+from typing import Any, Callable, Dict, Optional, Union

 import torch
 import torch.nn as nn
-from torch.fx import Graph, Node, Proxy, Tracer
-from torch.fx.graph import _Namespace
+from torch.fx import Node, Proxy
 from torch.utils._pytree import tree_map

 from colossalai._analyzer._subclasses import MetaTensor
...
@@ -72,7 +70,7 @@ class ColoProxy(Proxy):
         return ColoAttribute(self, k, getattr(self._meta_data, k, None))

     def __setitem__(self, key, value):
-        proxy = self.tracer.create_proxy('call_function', operator.setitem, (self, key, value), {})
+        proxy = self.tracer.create_proxy("call_function", operator.setitem, (self, key, value), {})
         proxy.meta_data = self._meta_data
         return proxy
...
@@ -89,7 +87,6 @@ class ColoProxy(Proxy):
 class ColoAttribute(ColoProxy):
-
     def __init__(self, root, attr: str, data=None):
         self.root = root
         self.attr = attr
...
@@ -102,11 +99,11 @@ class ColoAttribute(ColoProxy):
         # the node for attributes is added lazily, since most will just be method calls
         # which do not rely on the getitem call
         if self._node is None:
-            self._node = self.tracer.create_proxy('call_function', getattr, (self.root, self.attr), {}).node
+            self._node = self.tracer.create_proxy("call_function", getattr, (self.root, self.attr), {}).node
         return self._node

     def __call__(self, *args, **kwargs):
-        return self.tracer.create_proxy('call_method', self.attr, (self.root,) + args, kwargs)
+        return self.tracer.create_proxy("call_method", self.attr, (self.root,) + args, kwargs)

     def __repr__(self):
         return f"ColoAttribute({self.node.name}, attr={self.attr})"
colossalai/_analyzer/fx/tracer/symbolic_trace.py View file @ 9e768b59
-from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Type, Union
+from typing import Any, Callable, Dict, Optional, Union

 import torch
 from torch.fx import Tracer
...
@@ -8,6 +8,7 @@ from colossalai._analyzer._subclasses import MetaTensor
 try:
     from ..codegen import ActivationCheckpointCodeGen
+
     SUPPORT_ACTIVATION = True
 except:
     SUPPORT_ACTIVATION = False
...
@@ -16,7 +17,7 @@ from .tracer import ColoTracer
 def _default_device():
-    return torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
+    return torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")


 def _current_device(module: torch.nn.Module):
...
@@ -144,10 +145,9 @@ def symbolic_trace(
     if meta_args:
         device, orig_device = _default_device(), _current_device(root)
         wrap_fn = lambda elem: MetaTensor(elem, device=device) if isinstance(elem, torch.Tensor) else elem
-        graph = ColoTracer(trace_act_ckpt=trace_act_ckpt,
-                           bias_addition_split=bias_addition_split).trace(root.to(device),
-                                                                          concrete_args=concrete_args,
-                                                                          meta_args=tree_map(wrap_fn, meta_args))
+        graph = ColoTracer(trace_act_ckpt=trace_act_ckpt, bias_addition_split=bias_addition_split).trace(
+            root.to(device), concrete_args=concrete_args, meta_args=tree_map(wrap_fn, meta_args)
+        )
         if trace_act_ckpt and SUPPORT_ACTIVATION:
             graph.set_codegen(ActivationCheckpointCodeGen())
         root.to(orig_device)
...
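The `wrap_fn` above only wraps tensor leaves; non-tensor entries in `meta_args` pass through untouched. Its standalone behavior, as a sketch:

    import torch
    from torch.utils._pytree import tree_map
    from colossalai._analyzer._subclasses import MetaTensor

    device = torch.device("cpu")
    wrap_fn = lambda elem: MetaTensor(elem, device=device) if isinstance(elem, torch.Tensor) else elem
    meta_args = tree_map(wrap_fn, {"x": torch.rand(2, 3), "flag": True})
    # meta_args["x"] is now a MetaTensor pinned to `device`; meta_args["flag"] is still True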
colossalai/_analyzer/fx/tracer/tracer.py View file @ 9e768b59
...
@@ -20,11 +20,10 @@ def _truncate_suffix(s: str):
     import re

     # FIXME: don't know why but torch.fx always gets a suffix like '_1' in the name
-    return re.sub(r'_\d+$', '', s)
+    return re.sub(r"_\d+$", "", s)


-def register_tracer_impl(func: Callable[..., Any], name: Optional[str] = '_custom_impl'):
+def register_tracer_impl(func: Callable[..., Any], name: Optional[str] = "_custom_impl"):
     def wrapper(impl):
         assert hasattr(ColoTracer, name), f"Cannot register {func.__name__} in ColoTracer.{name}"
         getattr(ColoTracer, name)[func] = impl
...
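The regex strips only a trailing `_<digits>` group, so placeholders renamed by torch.fx map back to their original argument names while names that merely end in digits are untouched:

    import re

    assert re.sub(r"_\d+$", "", "input_1") == "input"
    assert re.sub(r"_\d+$", "", "fc2") == "fc2"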
@@ -34,7 +33,6 @@ def register_tracer_impl(func: Callable[..., Any], name: Optional[str] = "_custom_impl"):
 def register_leaf_module_impl(module: nn.Module):
     def wrapper(impl):
         ColoTracer._custom_leaf_module_impl[module] = impl
         return impl
...
@@ -76,7 +74,7 @@ class ColoTracer(Tracer):
         self.ckpt_regions = []
         self.ckpt_idx = 0

-        self.mod_dir = ''
+        self.mod_dir = ""

         # whether the tracer should split the bias_add ops into two ops
         self.bias_addition_split = bias_addition_split
...
@@ -87,35 +85,41 @@ class ColoTracer(Tracer):
         if self.bias_addition_split and type(m) in self._bias_addition_module and m.bias is not None:
             return False
         # user can specify which modules are leaf modules and which are not
-        return (type(m) not in self._custom_non_leaf_module
-                and (type(m) in self._custom_leaf_module or super().is_leaf_module(m, module_qualified_name)))
+        return type(m) not in self._custom_non_leaf_module and (
+            type(m) in self._custom_leaf_module or super().is_leaf_module(m, module_qualified_name)
+        )

-    def call_module(self, m: torch.nn.Module, forward: Callable[..., Any], args: Tuple[Any, ...],
-                    kwargs: Dict[str, Any]) -> Any:
+    def call_module(
+        self, m: torch.nn.Module, forward: Callable[..., Any], args: Tuple[Any, ...], kwargs: Dict[str, Any]
+    ) -> Any:
         curr_dir = self.mod_dir
-        self.mod_dir = 'self.' + self.path_of_module(m)
+        self.mod_dir = "self." + self.path_of_module(m)
         rst = super().call_module(m, forward, args, kwargs)
         self.mod_dir = curr_dir
         return rst

-    def proxy(self, node: Node) -> 'ColoProxy':
+    def proxy(self, node: Node) -> "ColoProxy":
         return ColoProxy(node, self)

-    def create_proxy(self,
-                     kind: str,
-                     target: Target,
-                     args: Tuple[Any, ...],
-                     kwargs: Dict[str, Any],
-                     name: Optional[str] = None,
-                     type_expr: Optional[Any] = None,
-                     proxy_factory_fn: Callable[[Node], 'Proxy'] = None):
+    def create_proxy(
+        self,
+        kind: str,
+        target: Target,
+        args: Tuple[Any, ...],
+        kwargs: Dict[str, Any],
+        name: Optional[str] = None,
+        type_expr: Optional[Any] = None,
+        proxy_factory_fn: Callable[[Node], "Proxy"] = None,
+    ):
         proxy: ColoProxy = super().create_proxy(kind, target, args, kwargs, name, type_expr, proxy_factory_fn)
         unwrap_fn = lambda p: p.meta_data if isinstance(p, ColoProxy) else p
-        if kind == 'placeholder':
-            proxy.meta_data = self.meta_args[target] if target in self.meta_args else self.concrete_args.get(
-                _truncate_suffix(target), None)
-        elif kind == 'get_attr':
+        if kind == "placeholder":
+            proxy.meta_data = (
+                self.meta_args[target]
+                if target in self.meta_args
+                else self.concrete_args.get(_truncate_suffix(target), None)
+            )
+        elif kind == "get_attr":
             self.disable_module_getattr = True
             try:
                 attr_itr = self.root
...
@@ -125,20 +129,21 @@ class ColoTracer(Tracer):
                 proxy.meta_data = attr_itr
             finally:
                 self.disable_module_getattr = False
-        elif kind == 'call_function':
+        elif kind == "call_function":
             proxy.meta_data = target(*tree_map(unwrap_fn, args), **tree_map(unwrap_fn, kwargs))
-        elif kind == 'call_method':
+        elif kind == "call_method":
             self.disable_module_getattr = True
             try:
-                if target == '__call__':
+                if target == "__call__":
                     proxy.meta_data = unwrap_fn(args[0])(*tree_map(unwrap_fn, args[1:]), **tree_map(unwrap_fn, kwargs))
                 else:
                     if target not in _TensorPropertyMethod:
-                        proxy._meta_data = getattr(unwrap_fn(args[0]), target)(*tree_map(unwrap_fn, args[1:]),
-                                                                               **tree_map(unwrap_fn, kwargs))
+                        proxy._meta_data = getattr(unwrap_fn(args[0]), target)(
+                            *tree_map(unwrap_fn, args[1:]), **tree_map(unwrap_fn, kwargs)
+                        )
             finally:
                 self.disable_module_getattr = False
-        elif kind == 'call_module':
+        elif kind == "call_module":
             mod = self.root.get_submodule(target)
             self.disable_module_getattr = True
             try:
...
@@ -158,11 +163,12 @@ class ColoTracer(Tracer):
             n_info = MetaInfo(node, mod_dir=self.mod_dir, activation_checkpoint=tuple(self.ckpt_regions))
         return node

-    def trace(self,
-              root: torch.nn.Module,
-              concrete_args: Optional[Dict[str, torch.Tensor]] = None,
-              meta_args: Optional[Dict[str, torch.Tensor]] = None) -> Graph:
+    def trace(
+        self,
+        root: torch.nn.Module,
+        concrete_args: Optional[Dict[str, torch.Tensor]] = None,
+        meta_args: Optional[Dict[str, torch.Tensor]] = None,
+    ) -> Graph:
         if meta_args is None:
             meta_args = {}
...
@@ -177,9 +183,7 @@ class ColoTracer(Tracer):
         non_concrete_arg_names = sig_names - concrete_arg_names
         # update concrete args with default values
         for k, v in sig.parameters.items():
-            if k in sig_names - meta_arg_names and \
-                    k not in concrete_args and \
-                    v.default is not inspect.Parameter.empty:
+            if k in sig_names - meta_arg_names and k not in concrete_args and v.default is not inspect.Parameter.empty:
                 concrete_args[k] = v.default

         def _check_arg_name_valid(names: Iterable[str]):
...
@@ -194,9 +198,9 @@ class ColoTracer(Tracer):
         self.meta_args = meta_args

         with self._torch_factory_override(), self._tracer_override(), torch.no_grad():
-            self.mod_dir = 'self'
+            self.mod_dir = "self"
             self.graph = super().trace(root, concrete_args=concrete_args)
-            self.mod_dir = ''
+            self.mod_dir = ""
         self.graph.lint()

         for node in self.graph.nodes:
...
@@ -266,17 +270,17 @@ class ColoTracer(Tracer):
         # override the torch factory functions to create a proxy when the method
         # is called during ``symbolic_trace()``.
         def wrap_factory_method(target):
             @functools.wraps(target)
             def wrapper(*args, **kwargs):
                 is_proxy = any(isinstance(p, ColoProxy) for p in args) | any(
-                    isinstance(p, ColoProxy) for p in kwargs.values())
+                    isinstance(p, ColoProxy) for p in kwargs.values()
+                )
                 if is_proxy:
                     # if the arg is a proxy, then need to record this function called on this proxy
                     # e.g. torch.ones(size) where size is an input proxy
                     self.disable_module_getattr = True
                     try:
-                        proxy = self.create_proxy('call_function', target, args, kwargs)
+                        proxy = self.create_proxy("call_function", target, args, kwargs)
                     finally:
                         self.disable_module_getattr = False
                     return proxy
...
@@ -341,10 +345,13 @@ class ColoTracer(Tracer):
                 if attr_val is p:
                     if n not in parameter_proxy_cache:
                         kwargs = {}
-                        if 'proxy_factory_fn' in inspect.signature(self.create_proxy).parameters:
-                            kwargs['proxy_factory_fn'] = (None if not self.param_shapes_constant else
-                                                          lambda node: ColoProxy(self, node, n, attr_val))
-                        val_proxy = self.create_proxy('get_attr', n, (), {}, **kwargs)    # type: ignore[arg-type]
+                        if "proxy_factory_fn" in inspect.signature(self.create_proxy).parameters:
+                            kwargs["proxy_factory_fn"] = (
+                                None
+                                if not self.param_shapes_constant
+                                else lambda node: ColoProxy(self, node, n, attr_val)
+                            )
+                        val_proxy = self.create_proxy("get_attr", n, (), {}, **kwargs)  # type: ignore[arg-type]
                         parameter_proxy_cache[n] = val_proxy
                 return parameter_proxy_cache[n]
         return None
...
@@ -355,8 +362,9 @@ class ColoTracer(Tracer):
             return maybe_buffer_proxy

         if isinstance(attr_val, torch.nn.Parameter):
-            maybe_parameter_proxy = maybe_get_proxy_for_attr(attr_val, self.root.named_parameters(),
-                                                             parameter_proxy_cache)
+            maybe_parameter_proxy = maybe_get_proxy_for_attr(
+                attr_val, self.root.named_parameters(), parameter_proxy_cache
+            )
             if maybe_parameter_proxy is not None:
                 return maybe_parameter_proxy
...
colossalai/amp/__init__.py View file @ 9e768b59

-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-
-import torch.nn as nn
-from torch.nn.modules.loss import _Loss
-from torch.optim import Optimizer
-
-from colossalai.context import Config
-
-from .amp_type import AMP_TYPE
-from .apex_amp import convert_to_apex_amp
-from .naive_amp import convert_to_naive_amp
-from .torch_amp import convert_to_torch_amp
-
-__all__ = ['convert_to_amp', 'convert_to_naive_amp', 'convert_to_apex_amp', 'convert_to_torch_amp', 'AMP_TYPE']
-
-
-def convert_to_amp(model: nn.Module, optimizer: Optimizer, criterion: _Loss, mode: AMP_TYPE, amp_config: Config = None):
-    """A helper function to wrap training components with Torch AMP modules.
-
-    Args:
-        model (:class:`torch.nn.Module`): your model object.
-        optimizer (:class:`torch.optim.Optimizer`): your optimizer object.
-        criterion (:class:`torch.nn.modules.loss._Loss`): your loss function object.
-        mode (:class:`colossalai.amp.AMP_TYPE`): amp mode.
-        amp_config (Union[:class:`colossalai.context.Config`, dict]): configuration for different amp modes.
-
-    Returns:
-        A tuple (model, optimizer, criterion).
-
-    Note:
-        ``amp_config`` may vary from different mode you choose. You should check the corresponding amp mode
-        for more details about ``amp_config``.
-        For ``apex_amp``, please check
-        `apex_amp config <https://nvidia.github.io/apex/amp.html?highlight=apex%20amp>`_.
-        For ``naive_amp``, please check
-        `naive_amp config <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/amp/naive_amp/_fp16_optimizer.py#L42>`_.
-        For ``torch_amp``, please check
-        `torch_amp config <https://github.com/pytorch/pytorch/blob/master/torch/cuda/amp/grad_scaler.py#L97>`_.
-    """
-    assert isinstance(mode, AMP_TYPE), \
-        f'expected the argument mode be AMP_TYPE, but got {type(mode)}'
-
-    if amp_config is None:
-        amp_config = Config()
-
-    if mode == AMP_TYPE.TORCH:
-        model, optimizer, criterion = convert_to_torch_amp(model, optimizer, criterion, amp_config)
-    elif mode == AMP_TYPE.APEX:
-        model, optimizer = convert_to_apex_amp(model, optimizer, amp_config)
-    elif mode == AMP_TYPE.NAIVE:
-        model, optimizer = convert_to_naive_amp(model, optimizer, amp_config)
-
-    return model, optimizer, criterion
colossalai/amp/naive_amp/__init__.py View file @ 9e768b59

-import inspect
-
-import torch.nn as nn
-from torch.optim import Optimizer
-
-from colossalai.utils import is_no_pp_or_last_stage
-
-from ._fp16_optimizer import FP16Optimizer
-from .grad_scaler import ConstantGradScaler, DynamicGradScaler
-from .naive_amp import NaiveAMPModel, NaiveAMPOptimizer
-
-
-def convert_to_naive_amp(model: nn.Module, optimizer: Optimizer, amp_config):
-    """A helper function to wrap training components with naive AMP modules. In this mode,
-    we forcibly cast the model weights and inputs to FP16, and cast the model outputs to FP32 to calculate loss,
-    which is equivalent to Apex O3.
-
-    Args:
-        model (:class:`torch.nn.Module`): your model object
-        optimizer (:class:`torch.optim.Optimizer`): your optimizer object
-        amp_config (:class:`colossalai.context.Config` or dict): configuration for naive mode amp.
-
-    Returns:
-        Tuple: A tuple (model, optimizer)
-
-    The ``amp_config`` should contain parameters below::
-
-        verbose (bool, optional): if set to `True`, will print debug info (Default: False).
-        clip_grad_norm (float, optional): clip gradients with this global L2 norm (Default 0).
-                                          Note that clipping is ignored if clip_grad == 0.
-        dynamic_grad_scale (bool): whether to use dynamic grad scaler.
-    """
-    if isinstance(model, nn.ModuleList):
-        # interleaved pipeline
-        module_list = []
-        for chunk, m in enumerate(model):
-            output_to_fp32 = is_no_pp_or_last_stage() and chunk == len(model) - 1
-            module_list.append(NaiveAMPModel(m, output_to_fp32=output_to_fp32))
-        model = nn.ModuleList(module_list)
-    else:
-        output_to_fp32 = is_no_pp_or_last_stage()
-        model = NaiveAMPModel(model, output_to_fp32=output_to_fp32)
-
-    use_dynamic_grad_scaler = amp_config.pop('dynamic_grad_scale', True)
-    if use_dynamic_grad_scaler:
-        scaler_class = DynamicGradScaler
-    else:
-        scaler_class = ConstantGradScaler
-
-    sig = inspect.signature(scaler_class.__init__)
-    kwargs = dict()
-    for param in sig.parameters.values():
-        if param.name in amp_config:
-            kwargs[param.name] = amp_config.pop(param.name)
-    grad_scaler = scaler_class(**kwargs)
-    optimizer = NaiveAMPOptimizer(optimizer, grad_scaler, **amp_config)
-    return model, optimizer
-
-
-__all__ = ['convert_to_naive_amp', 'NaiveAMPOptimizer', 'FP16Optimizer']
colossalai/amp/naive_amp/grad_scaler/__init__.py View file @ 9e768b59
...
@@ -2,4 +2,4 @@ from .base_grad_scaler import BaseGradScaler
 from .constant_grad_scaler import ConstantGradScaler
 from .dynamic_grad_scaler import DynamicGradScaler

-__all__ = ['BaseGradScaler', 'ConstantGradScaler', 'DynamicGradScaler']
+__all__ = ["BaseGradScaler", "ConstantGradScaler", "DynamicGradScaler"]