Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
e532679c
Commit
e532679c
authored
Jan 10, 2023
by
oahzxl
Browse files
Merge branch 'main' of
https://github.com/oahzxl/ColossalAI
into chunk
parents
c1492e50
7d5640b9
Changes
461
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
818 additions
and
661 deletions
+818
-661
colossalai/fx/tracer/meta_patch/patched_function/__init__.py
colossalai/fx/tracer/meta_patch/patched_function/__init__.py
+1
-2
colossalai/fx/tracer/meta_patch/patched_function/activation_function.py
...tracer/meta_patch/patched_function/activation_function.py
+3
-2
colossalai/fx/tracer/meta_patch/patched_function/arithmetic.py
...salai/fx/tracer/meta_patch/patched_function/arithmetic.py
+22
-2
colossalai/fx/tracer/meta_patch/patched_function/convolution.py
...alai/fx/tracer/meta_patch/patched_function/convolution.py
+5
-3
colossalai/fx/tracer/meta_patch/patched_function/embedding.py
...ssalai/fx/tracer/meta_patch/patched_function/embedding.py
+3
-2
colossalai/fx/tracer/meta_patch/patched_function/normalization.py
...ai/fx/tracer/meta_patch/patched_function/normalization.py
+3
-2
colossalai/fx/tracer/meta_patch/patched_function/python_ops.py
...salai/fx/tracer/meta_patch/patched_function/python_ops.py
+4
-1
colossalai/fx/tracer/meta_patch/patched_function/torch_ops.py
...ssalai/fx/tracer/meta_patch/patched_function/torch_ops.py
+2
-1
colossalai/fx/tracer/meta_patch/patched_module/activation_function.py
...x/tracer/meta_patch/patched_module/activation_function.py
+2
-1
colossalai/fx/tracer/meta_patch/patched_module/convolution.py
...ssalai/fx/tracer/meta_patch/patched_module/convolution.py
+3
-1
colossalai/fx/tracer/meta_patch/patched_module/embedding.py
colossalai/fx/tracer/meta_patch/patched_module/embedding.py
+3
-2
colossalai/fx/tracer/meta_patch/patched_module/linear.py
colossalai/fx/tracer/meta_patch/patched_module/linear.py
+2
-1
colossalai/fx/tracer/meta_patch/patched_module/normalization.py
...alai/fx/tracer/meta_patch/patched_module/normalization.py
+2
-1
colossalai/fx/tracer/meta_patch/patched_module/pooling.py
colossalai/fx/tracer/meta_patch/patched_module/pooling.py
+3
-1
colossalai/fx/tracer/meta_patch/patched_module/rnn.py
colossalai/fx/tracer/meta_patch/patched_module/rnn.py
+4
-2
colossalai/fx/tracer/registry.py
colossalai/fx/tracer/registry.py
+3
-0
colossalai/fx/tracer/tracer.py
colossalai/fx/tracer/tracer.py
+168
-81
colossalai/gemini/__init__.py
colossalai/gemini/__init__.py
+6
-3
colossalai/gemini/chunk/__init__.py
colossalai/gemini/chunk/__init__.py
+3
-1
colossalai/gemini/chunk/chunk.py
colossalai/gemini/chunk/chunk.py
+576
-552
No files found.
Too many changes to show.
To preserve performance only
461 of 461+
files are displayed.
Plain diff
Email patch
colossalai/fx/tracer/meta_patch/patched_function/__init__.py
View file @
e532679c
from
.activation_function
import
*
from
.arithmetic
import
*
from
.convolution
import
*
from
.embedding
import
*
from
.normalization
import
*
from
.python_ops
import
*
from
.torch_ops
import
*
from
.convolution
import
*
\ No newline at end of file
colossalai/fx/tracer/meta_patch/patched_function/activation_function.py
View file @
e532679c
import
torch
from
..registry
import
meta_patched_function
from
...registry
import
meta_patched_function
@
meta_patched_function
.
register
(
torch
.
nn
.
functional
.
relu
)
...
...
colossalai/fx/tracer/meta_patch/patched_function/arithmetic.py
View file @
e532679c
import
torch
from
..registry
import
meta_patched_function
from
..
.
registry
import
meta_patched_function
@
meta_patched_function
.
register
(
torch
.
matmul
)
...
...
@@ -57,16 +57,36 @@ def torch_bmm(input, mat2, *, out=None):
return
torch
.
empty
(
batch_size
,
n
,
p
,
device
=
"meta"
)
@
meta_patched_function
.
register
(
torch
.
nn
.
functional
.
linear
)
def
torch_linear
(
input
,
mat2
,
bias
=
None
,
*
,
out
=
None
):
if
out
is
not
None
:
raise
ValueError
(
"Don't support in-place abs for MetaTensor analysis"
)
output_shape
=
list
(
input
.
shape
)
output_feature
=
list
(
mat2
.
shape
)[
0
]
output_shape
[
-
1
]
=
output_feature
return
torch
.
empty
(
*
output_shape
,
device
=
"meta"
)
@
meta_patched_function
.
register
(
torch
.
addbmm
)
@
meta_patched_function
.
register
(
torch
.
Tensor
.
addbmm
)
def
torch_addbmm
(
input
,
mat1
,
mat2
,
*
,
beta
=
1
,
alpha
=
1
,
out
=
None
):
if
out
is
not
None
:
raise
ValueError
(
"Don't support in-place abs for MetaTensor analysis"
)
batch_size
,
n
,
m
=
mat1
.
shape
_
,
n
,
_
=
mat1
.
shape
_
,
_
,
p
=
mat2
.
shape
return
torch
.
empty
(
n
,
p
,
device
=
"meta"
)
@
meta_patched_function
.
register
(
torch
.
addmm
)
@
meta_patched_function
.
register
(
torch
.
Tensor
.
addmm
)
def
torch_addmm
(
input
,
mat1
,
mat2
,
*
,
beta
=
1
,
alpha
=
1
,
out
=
None
):
if
out
is
not
None
:
raise
ValueError
(
"Don't support in-place abs for MetaTensor analysis"
)
n
,
_
=
mat1
.
shape
_
,
p
=
mat2
.
shape
return
torch
.
empty
(
n
,
p
,
device
=
"meta"
)
@
meta_patched_function
.
register
(
torch
.
var_mean
)
def
torch_var_mean
(
input
,
dim
,
unbiased
=
True
,
keepdim
=
False
,
*
,
out
=
None
):
assert
out
is
None
,
'saving to out is not supported yet'
...
...
colossalai/fx/tracer/meta_patch/patched_function/convolution.py
View file @
e532679c
import
torch
import
collections
from
itertools
import
repeat
from
..registry
import
meta_patched_function
import
math
from
itertools
import
repeat
import
torch
from
...registry
import
meta_patched_function
def
_ntuple
(
n
,
name
=
"parse"
):
...
...
colossalai/fx/tracer/meta_patch/patched_function/embedding.py
View file @
e532679c
import
torch
from
..registry
import
meta_patched_function
from
...registry
import
meta_patched_function
@
meta_patched_function
.
register
(
torch
.
nn
.
functional
.
embedding
)
...
...
colossalai/fx/tracer/meta_patch/patched_function/normalization.py
View file @
e532679c
import
torch
from
..registry
import
meta_patched_function
from
...registry
import
meta_patched_function
@
meta_patched_function
.
register
(
torch
.
nn
.
functional
.
layer_norm
)
...
...
colossalai/fx/tracer/meta_patch/patched_function/python_ops.py
View file @
e532679c
import
operator
import
torch
from
..registry
import
meta_patched_function
from
colossalai.fx.proxy
import
ColoProxy
from
...registry
import
meta_patched_function
@
meta_patched_function
.
register
(
operator
.
getitem
)
def
operator_getitem
(
a
,
b
):
...
...
colossalai/fx/tracer/meta_patch/patched_function/torch_ops.py
View file @
e532679c
import
torch
from
..registry
import
meta_patched_function
from
...registry
import
meta_patched_function
@
meta_patched_function
.
register
(
torch
.
arange
)
...
...
colossalai/fx/tracer/meta_patch/patched_module/activation_function.py
View file @
e532679c
import
torch
from
..registry
import
meta_patched_module
from
...registry
import
meta_patched_module
@
meta_patched_module
.
register
(
torch
.
nn
.
ReLU
)
...
...
colossalai/fx/tracer/meta_patch/patched_module/convolution.py
View file @
e532679c
import
math
import
torch
from
..registry
import
meta_patched_module
from
...registry
import
meta_patched_module
@
meta_patched_module
.
register
(
torch
.
nn
.
Conv1d
)
...
...
colossalai/fx/tracer/meta_patch/patched_module/embedding.py
View file @
e532679c
import
torch
from
..registry
import
meta_patched_module
from
...registry
import
meta_patched_module
@
meta_patched_module
.
register
(
torch
.
nn
.
Embedding
)
...
...
colossalai/fx/tracer/meta_patch/patched_module/linear.py
View file @
e532679c
import
torch
from
..registry
import
meta_patched_module
from
...registry
import
meta_patched_module
@
meta_patched_module
.
register
(
torch
.
nn
.
Linear
)
...
...
colossalai/fx/tracer/meta_patch/patched_module/normalization.py
View file @
e532679c
import
torch
from
..registry
import
meta_patched_module
from
...registry
import
meta_patched_module
@
meta_patched_module
.
register
(
torch
.
nn
.
LayerNorm
)
...
...
colossalai/fx/tracer/meta_patch/patched_module/pooling.py
View file @
e532679c
import
math
import
torch
from
..registry
import
meta_patched_module
from
...registry
import
meta_patched_module
@
meta_patched_module
.
register
(
torch
.
nn
.
AvgPool1d
)
...
...
colossalai/fx/tracer/meta_patch/patched_module/rnn.py
View file @
e532679c
import
torch
from
..registry
import
meta_patched_module
from
typing
import
Optional
import
torch
from
...registry
import
meta_patched_module
@
meta_patched_module
.
register
(
torch
.
nn
.
GRU
)
@
meta_patched_module
.
register
(
torch
.
nn
.
RNN
)
...
...
colossalai/fx/tracer/
meta_patch/
registry.py
→
colossalai/fx/tracer/registry.py
View file @
e532679c
...
...
@@ -23,3 +23,6 @@ class PatchRegistry:
meta_patched_function
=
PatchRegistry
(
name
=
'patched_functions_for_meta_execution'
)
meta_patched_module
=
PatchRegistry
(
name
=
'patched_modules_for_meta_execution'
)
bias_addition_function
=
PatchRegistry
(
name
=
'patched_function_for_bias_addition'
)
bias_addition_module
=
PatchRegistry
(
name
=
'patched_module_for_bias_addition'
)
bias_addition_method
=
PatchRegistry
(
name
=
'patched_method_for_bias_addition'
)
colossalai/fx/tracer/tracer.py
View file @
e532679c
...
...
@@ -5,22 +5,29 @@ tracer.py:
The implementation is partly inspired HuggingFace's fx tracer
"""
import
enum
import
inspect
import
functools
import
inspect
import
operator
from
contextlib
import
contextmanager
from
colossalai.fx.tracer.meta_patch
import
meta_patched_module
from
typing
import
Any
,
Dict
,
Optional
import
torch
import
torch.nn
as
nn
from
torch
import
Tensor
from
torch.fx
import
Tracer
,
Node
from
torch.fx.graph
import
Graph
from
torch.fx.proxy
import
Proxy
,
ParameterProxy
from
torch.fx
import
Node
,
Tracer
from
torch.fx.graph
import
Graph
,
magic_methods
,
reflectable_magic_methods
from
torch.fx.proxy
import
ParameterProxy
,
Proxy
from
..proxy
import
ColoProxy
from
typing
import
Optional
,
Dict
,
Any
from
._tracer_utils
import
is_element_in_list
,
extract_meta
,
compute_meta_data_for_functions_proxy
from
.meta_patch
import
meta_patched_function
,
meta_patched_module
from
torch.fx.graph
import
magic_methods
,
reflectable_magic_methods
from
._tracer_utils
import
compute_meta_data_for_functions_proxy
,
extract_meta
,
is_element_in_list
from
.bias_addition_patch
import
func_to_func_dict
,
method_to_func_dict
,
module_to_func_dict
from
.registry
import
(
bias_addition_function
,
bias_addition_method
,
bias_addition_module
,
meta_patched_function
,
meta_patched_module
,
)
__all__
=
[
'ColoTracer'
]
...
...
@@ -77,54 +84,42 @@ class ColoTracer(Tracer):
"""
Create a proxy for different kinds of operations.
"""
proxy
=
super
().
create_proxy
(
kind
,
target
,
args
,
kwargs
,
name
,
type_expr
,
proxy_factory_fn
)
if
self
.
tracer_type
==
TracerType
.
DEFAULT
:
# since meta_args is not given
# we just fall back to the original torch.fx.Tracer
proxy
=
super
().
create_proxy
(
kind
,
target
,
args
,
kwargs
,
name
,
type_expr
,
proxy_factory_fn
)
return
proxy
proxy
:
ColoProxy
if
kind
==
"placeholder"
and
target
in
self
.
meta_args
and
self
.
meta_args
[
target
].
is_meta
:
proxy
.
meta_data
=
self
.
meta_args
[
target
]
return
proxy
if
target
in
self
.
orig_torch_tensor_methods
:
# NOTE: tensor constructors in PyTorch define the `device` argument as
# *kwargs-only*. That is why this works. If you add methods to
# _TORCH_METHODS_TO_PATCH that do not define `device` as kwarg-only,
# this will break and you will likely see issues where we cannot infer
# the size of the output.
if
"device"
in
kwargs
:
kwargs
[
"device"
]
=
"meta"
try
:
args_metas
,
kwargs_metas
=
extract_meta
(
*
args
,
**
kwargs
)
# if graph is traced for auto parallelism module, some extra node will be added during
# graph construction to deal with the compatability between bias addition and all reduce.
# if no extra manipulation is applied, we just pass the origin arguments to create_proxy function
# to create node on computation graph
origin_arguments
=
(
kind
,
target
,
args
,
kwargs
,
name
,
type_expr
,
proxy_factory_fn
)
# dispatch the arguments generator depending on the kind and target in origin arguments.
args_metas
,
_
=
extract_meta
(
*
args
,
**
kwargs
)
handle
=
None
if
kind
==
"call_function"
:
# fetch patched function
if
meta_patched_function
.
has
(
target
):
meta_target
=
meta_patched_function
.
get
(
target
)
elif
meta_patched_function
.
has
(
target
.
__name__
):
# use name for some builtin op like @ (matmul)
meta_target
=
meta_patched_function
.
get
(
target
.
__name__
)
if
bias_addition_function
.
has
(
target
):
if
target
==
torch
.
nn
.
functional
.
linear
:
if
'bias'
in
kwargs
and
kwargs
[
'bias'
]
is
not
None
:
function_to_substitute
=
func_to_func_dict
[
target
]
handle
=
bias_addition_function
.
get
(
target
)(
self
,
target
,
args
,
kwargs
,
function_to_substitute
)
else
:
meta_target
=
target
function_to_substitute
=
func_to_func_dict
[
target
]
handle
=
bias_addition_function
.
get
(
target
)(
self
,
target
,
args
,
kwargs
,
function_to_substitute
)
elif
bias_addition_function
.
has
(
target
.
__name__
):
# use name for some builtin op like @ (matmul)
function_to_substitute
=
func_to_func_dict
[
target
]
handle
=
bias_addition_function
.
get
(
target
.
__name__
)(
self
,
target
,
args
,
kwargs
,
function_to_substitute
)
meta_out
=
meta_target
(
*
args_metas
,
**
kwargs_metas
)
if
isinstance
(
meta_out
,
torch
.
Tensor
):
meta_out
=
meta_out
.
to
(
device
=
"meta"
)
elif
kind
==
"call_method"
:
method
=
getattr
(
args_metas
[
0
].
__class__
,
target
)
if
bias_addition_method
.
has
(
method
):
function_to_substitute
=
method_to_func_dict
[
method
]
handle
=
bias_addition_method
.
get
(
method
)(
self
,
target
,
args
,
kwargs
,
function_to_substitute
)
# fetch patched method
if
meta_patched_function
.
has
(
method
):
meta_target
=
meta_patched_function
.
get
(
method
)
else
:
meta_target
=
method
meta_out
=
meta_target
(
*
args_metas
,
**
kwargs_metas
)
elif
kind
==
"call_module"
:
if
not
hasattr
(
self
,
"orig_forward"
):
raise
AttributeError
(
f
"
{
self
}
does not have an attribute called orig_forward"
)
...
...
@@ -132,33 +127,26 @@ class ColoTracer(Tracer):
try
:
mod
=
self
.
root
.
get_submodule
(
target
)
mod_type
=
type
(
mod
)
if
meta_patched_module
.
has
(
mod_type
):
meta_out
=
meta_patched_module
.
get
(
mod_type
)(
mod
,
*
args_metas
,
**
kwargs_metas
)
else
:
meta_out
=
self
.
orig_forward
(
*
args_metas
,
**
kwargs_metas
)
if
bias_addition_module
.
has
(
mod_type
)
and
mod
.
bias
is
not
None
:
function_to_substitute
=
module_to_func_dict
[
mod_type
]
handle
=
bias_addition_module
.
get
(
mod_type
)(
self
,
target
,
args
,
kwargs
,
function_to_substitute
)
finally
:
self
.
_disable_module_getattr
=
False
elif
kind
==
"get_attr"
:
self
.
_disable_module_getattr
=
True
try
:
attr_itr
=
self
.
root
atoms
=
target
.
split
(
"."
)
for
atom
in
atoms
:
attr_itr
=
getattr
(
attr_itr
,
atom
)
if
isinstance
(
attr_itr
,
torch
.
Tensor
):
meta_out
=
attr_itr
.
to
(
device
=
"meta"
)
else
:
meta_out
=
attr_itr
finally
:
self
.
_disable_module_getattr
=
False
else
:
return
proxy
if
not
isinstance
(
proxy
,
Proxy
):
raise
ValueError
(
"Don't support composite output yet"
)
if
handle
is
not
None
:
return
handle
.
generate
()
# create nodes using patched arguments
proxy
=
super
().
create_proxy
(
*
origin_arguments
)
proxy
:
ColoProxy
meta_out
=
self
.
_meta_data_computing
(
kind
,
target
,
args
,
kwargs
,
)
proxy
.
meta_data
=
meta_out
except
Exception
as
e
:
raise
RuntimeError
(
f
"Could not compute metadata for
{
kind
}
target
{
target
}
:
{
e
}
"
)
return
proxy
def
_module_getattr
(
self
,
attr
,
attr_val
,
parameter_proxy_cache
):
...
...
@@ -222,6 +210,105 @@ class ColoTracer(Tracer):
else
:
raise
ValueError
(
f
"Unrecognised tracer type
{
tracer_type
}
"
)
def
_meta_data_computing
(
self
,
kind
,
target
,
args
,
kwargs
):
if
kind
==
"placeholder"
and
target
in
self
.
meta_args
and
self
.
meta_args
[
target
].
is_meta
:
meta_out
=
self
.
meta_args
[
target
]
return
meta_out
if
target
in
self
.
orig_torch_tensor_methods
:
# NOTE: tensor constructors in PyTorch define the `device` argument as
# *kwargs-only*. That is why this works. If you add methods to
# _TORCH_METHODS_TO_PATCH that do not define `device` as kwarg-only,
# this will break and you will likely see issues where we cannot infer
# the size of the output.
if
"device"
in
kwargs
:
kwargs
[
"device"
]
=
"meta"
try
:
args_metas
,
kwargs_metas
=
extract_meta
(
*
args
,
**
kwargs
)
if
kind
==
"call_function"
:
# Our meta data will not record the nn.parameter.Parameter attribute。
# It works fine in most of the case, but it may cause some problems after
# the bias addition manipulation.
# Therefore, I need to record the nn.parameter.Parameter attribute for the operation
# added by the bias addition manipulation following the get_attr node.
convert_to_parameter
=
False
if
target
in
(
torch
.
transpose
,
torch
.
reshape
)
and
isinstance
(
args_metas
[
0
],
torch
.
nn
.
parameter
.
Parameter
):
convert_to_parameter
=
True
# fetch patched function
if
meta_patched_function
.
has
(
target
):
meta_target
=
meta_patched_function
.
get
(
target
)
elif
meta_patched_function
.
has
(
target
.
__name__
):
# use name for some builtin op like @ (matmul)
meta_target
=
meta_patched_function
.
get
(
target
.
__name__
)
else
:
meta_target
=
target
meta_out
=
meta_target
(
*
args_metas
,
**
kwargs_metas
)
if
isinstance
(
meta_out
,
torch
.
Tensor
):
meta_out
=
meta_out
.
to
(
device
=
"meta"
)
if
convert_to_parameter
:
meta_out
=
torch
.
nn
.
Parameter
(
meta_out
)
elif
kind
==
"call_method"
:
# Our meta data will not record the nn.parameter.Parameter attribute。
# It works fine in most of the case, but it may cause some problems after
# the bias addition manipulation.
# Therefore, I need to record the nn.parameter.Parameter attribute for the operation
# added by the bias addition manipulation following the get_attr node.
convert_to_parameter
=
False
if
target
in
(
torch
.
Tensor
.
view
,)
and
isinstance
(
args_metas
[
0
],
torch
.
nn
.
parameter
.
Parameter
):
convert_to_parameter
=
True
method
=
getattr
(
args_metas
[
0
].
__class__
,
target
)
# fetch patched method
if
meta_patched_function
.
has
(
method
):
meta_target
=
meta_patched_function
.
get
(
method
)
else
:
meta_target
=
method
meta_out
=
meta_target
(
*
args_metas
,
**
kwargs_metas
)
if
convert_to_parameter
:
meta_out
=
torch
.
nn
.
Parameter
(
meta_out
)
elif
kind
==
"call_module"
:
if
not
hasattr
(
self
,
"orig_forward"
):
raise
AttributeError
(
f
"
{
self
}
does not have an attribute called orig_forward"
)
self
.
_disable_module_getattr
=
True
try
:
mod
=
self
.
root
.
get_submodule
(
target
)
mod_type
=
type
(
mod
)
if
meta_patched_module
.
has
(
mod_type
):
meta_out
=
meta_patched_module
.
get
(
mod_type
)(
mod
,
*
args_metas
,
**
kwargs_metas
)
else
:
meta_out
=
self
.
orig_forward
(
*
args_metas
,
**
kwargs_metas
)
finally
:
self
.
_disable_module_getattr
=
False
elif
kind
==
"get_attr"
:
self
.
_disable_module_getattr
=
True
try
:
attr_itr
=
self
.
root
atoms
=
target
.
split
(
"."
)
for
atom
in
atoms
:
attr_itr
=
getattr
(
attr_itr
,
atom
)
if
isinstance
(
attr_itr
,
torch
.
nn
.
parameter
.
Parameter
):
meta_out
=
torch
.
nn
.
Parameter
(
attr_itr
.
to
(
device
=
"meta"
))
elif
isinstance
(
attr_itr
,
torch
.
Tensor
):
meta_out
=
attr_itr
.
to
(
device
=
"meta"
)
else
:
meta_out
=
attr_itr
finally
:
self
.
_disable_module_getattr
=
False
else
:
return
None
except
Exception
as
e
:
raise
RuntimeError
(
f
"Could not compute metadata for
{
kind
}
target
{
target
}
:
{
e
}
"
)
return
meta_out
def
trace
(
self
,
root
:
nn
.
Module
,
concrete_args
:
Optional
[
Dict
[
str
,
Tensor
]]
=
None
,
...
...
@@ -383,7 +470,7 @@ class ColoTracer(Tracer):
if
self
.
inside_torch_checkpoint_func
:
# annotate the activation checkpoint module
setattr
(
node
,
'activation_checkpoint'
,
self
.
act_ckpt_region_count
)
node
.
meta
[
'activation_checkpoint'
]
=
self
.
act_ckpt_region_count
return
node
...
...
colossalai/gemini/__init__.py
View file @
e532679c
from
.chunk
import
TensorInfo
,
TensorState
from
.chunk
import
ChunkManager
,
TensorInfo
,
TensorState
,
search_chunk_configuration
from
.gemini_mgr
import
GeminiManager
from
.stateful_tensor_mgr
import
StatefulTensorMgr
from
.tensor_placement_policy
import
TensorPlacementPolicyFactory
from
.gemini_mgr
import
GeminiManager
__all__
=
[
'StatefulTensorMgr'
,
'TensorPlacementPolicyFactory'
,
'GeminiManager'
,
'TensorInfo'
,
'TensorState'
]
__all__
=
[
'StatefulTensorMgr'
,
'TensorPlacementPolicyFactory'
,
'GeminiManager'
,
'TensorInfo'
,
'TensorState'
,
'ChunkManager'
,
'search_chunk_configuration'
]
colossalai/gemini/chunk/__init__.py
View file @
e532679c
from
.chunk
import
Chunk
,
ChunkFullError
,
TensorInfo
,
TensorState
from
.manager
import
ChunkManager
from
.search_utils
import
clasify_params
,
search_chunk_configuration
from
.search_utils
import
clas
s
ify_params
_by_dp_degree
,
search_chunk_configuration
from
.utils
import
init_chunk_manager
__all__
=
[
'Chunk'
,
'ChunkManager'
,
'classify_params_by_dp_degree'
,
'search_chunk_configuration'
,
'init_chunk_manager'
]
colossalai/gemini/chunk/chunk.py
View file @
e532679c
import
torch
import
torch.distributed
as
dist
from
dataclasses
import
dataclass
from
enum
import
Enum
from
typing
import
Optional
,
Dict
,
List
from
typing
import
Dict
,
List
,
Optional
import
torch
import
torch.distributed
as
dist
from
colossalai.utils
import
get_current_device
from
colossalai.tensor
import
ProcessGroup
as
ColoProcessGroup
from
colossalai.utils
import
get_current_device
class
TensorState
(
Enum
):
...
...
@@ -17,9 +18,9 @@ class TensorState(Enum):
STATE_TRANS
=
((
TensorState
.
FREE
,
TensorState
.
HOLD
),
(
TensorState
.
FREE
,
TensorState
.
COMPUTE
),
(
TensorState
.
HOLD
,
TensorState
.
FREE
),
(
TensorState
.
HOLD
,
TensorState
.
COMPUTE
),
(
TensorState
.
COMPUTE
,
TensorState
.
HOLD
),
(
TensorState
.
COMPUTE
,
TensorState
.
HOLD
_AFTER_BWD
),
(
TensorState
.
COMPUTE
,
TensorState
.
READY_FOR_REDUCE
),
(
TensorState
.
HOLD_AFTER_BWD
,
TensorState
.
COMPUTE
),
(
TensorState
.
HOLD
,
TensorState
.
FREE
),
(
TensorState
.
HOLD
,
TensorState
.
COMPUTE
),
(
TensorState
.
COMPUTE
,
TensorState
.
HOLD
),
(
TensorState
.
COMPUTE
,
TensorState
.
HOLD_AFTER_BWD
),
(
TensorState
.
HOLD_AFTER_BWD
,
TensorState
.
COMPUTE
),
(
TensorState
.
HOLD_AFTER_BWD
,
TensorState
.
READY_FOR_REDUCE
),
(
TensorState
.
READY_FOR_REDUCE
,
TensorState
.
HOLD
))
...
...
@@ -50,7 +51,6 @@ def alloc_storage(tensor: torch.Tensor) -> None:
class
Chunk
:
_total_number
=
0
def
__init__
(
self
,
...
...
@@ -58,6 +58,7 @@ class Chunk:
process_group
:
ColoProcessGroup
,
dtype
:
torch
.
dtype
,
init_device
:
Optional
[
torch
.
device
]
=
None
,
cpu_shard_init
:
bool
=
False
,
keep_gathered
:
bool
=
False
,
pin_memory
:
bool
=
False
)
->
None
:
"""
...
...
@@ -70,8 +71,9 @@ class Chunk:
chunk_size (int): the number of elements in the chunk
process_group (ColoProcessGroup): the process group of this chunk
dtype (torch.dtype): the data type of the chunk
init_device (torch.device): optional,
the device
where the tensor is
initializ
ed
init_device (torch.device): optional,
During the chunk construction process,
where the tensor is
stor
ed
.
The default value is None, which is the current GPU
cpu_shard_init (bool): a flag indicates the local chunk shard is resident on CPU.
keep_gathered (bool): optional, if True, this chunk is always gathered in CUDA memory
pin_memory (bool): optional, if True, this chunk always has a shard copied in pinned CPU memory
"""
...
...
@@ -80,13 +82,12 @@ class Chunk:
self
.
chunk_size
=
chunk_size
self
.
utilized_size
=
0
# Here, we use torch process group,
# since ColoProcessGroup might get deprecated soon
self
.
torch_pg
=
process_group
.
dp_process_group
()
self
.
pg_size
=
dist
.
get_world_size
(
self
.
torch_pg
)
self
.
pg_rank
=
dist
.
get_rank
(
self
.
torch_pg
)
# the chunk size should be
able to be divied by the size of GPU
# the chunk size should be
divisible by the dp degree
if
not
keep_gathered
:
assert
chunk_size
%
self
.
pg_size
==
0
self
.
shard_size
=
chunk_size
//
self
.
pg_size
...
...
@@ -96,26 +97,41 @@ class Chunk:
self
.
dtype
=
dtype
device
=
init_device
or
get_current_device
()
# chunk_temp is a global chunk, which only exists during building the chunks.
self
.
chunk_temp
=
torch
.
zeros
(
chunk_size
,
dtype
=
dtype
,
device
=
device
)
# keep all zero
self
.
chunk_total
=
None
# we force chunk_total located in CUDA
self
.
cuda_shard
=
None
# using two attributes for the better interpretation
self
.
cuda_global_chunk
=
None
# we force cuda_global_chunk located in CUDA
# cuda local chunk, which is sharded on GPUs
self
.
cuda_shard
=
None
# cpu local chunk, which is sharded on CPUs
self
.
cpu_shard
=
None
# is the chunks gathers, which means chunks are duplicated on each process,
# and we should use the cuda_global_chunk.
self
.
is_gathered
=
True
# configure the init device of the shard
# no-offload default: fp16, fp32 -> CUDA
# offload default: fp16, fp32 -> CPU
self
.
shard_device
=
torch
.
device
(
"cpu"
)
if
cpu_shard_init
else
get_current_device
()
self
.
chunk_mem
=
self
.
chunk_size
*
self
.
chunk_temp
.
element_size
()
self
.
shard_mem
=
self
.
chunk_mem
//
self
.
pg_size
# each tensor is associated with a TensorInfo to track meta info
# each tensor is associated with a TensorInfo to track its meta info
# (state, offset, end)
self
.
tensors_info
:
Dict
[
torch
.
Tensor
,
TensorInfo
]
=
{}
# the total number of
all
tensors
# the total number of tensors
in the chunk
self
.
num_tensors
=
0
# monitor the states of all tensors
self
.
tensors_state_monitor
:
Dict
[
TensorState
,
int
]
=
dict
()
# Record the number of tensors in different states
self
.
tensor_state_cnter
:
Dict
[
TensorState
,
int
]
=
dict
()
for
state
in
TensorState
:
self
.
tensor
s
_state_
monito
r
[
state
]
=
0
self
.
tensor_state_
cnte
r
[
state
]
=
0
#
some
chunk
s can
ke
e
p gathered
all the time
#
so their computation patterns are
the same as that of the parameters in DDP
#
If a
chunk
is
kep
t
gathered
,
#
they are treated
the same as that of the parameters in DDP
during training.
self
.
keep_gathered
=
keep_gathered
if
self
.
keep_gathered
:
pin_memory
=
False
# since this chunk is gathered, it doesn't need to pin
...
...
@@ -133,6 +149,10 @@ class Chunk:
# if the cpu_shard has been visited during the training step, the flag is True
self
.
cpu_vis_flag
=
False
# whether to record l2 norm for the gradient clipping calculation
self
.
l2_norm_flag
=
False
self
.
l2_norm
=
None
@
property
def
memory_usage
(
self
)
->
Dict
[
str
,
int
]:
cuda_memory
=
0
...
...
@@ -172,7 +192,7 @@ class Chunk:
assert
self
.
chunk_temp
is
None
if
self
.
is_gathered
:
return
self
.
c
hunk_total
return
self
.
c
uda_global_chunk
elif
self
.
cuda_shard
is
not
None
:
return
self
.
cuda_shard
else
:
...
...
@@ -197,25 +217,37 @@ class Chunk:
if
self
.
keep_gathered
:
return
False
else
:
return
self
.
tensor
s
_state_
monito
r
[
TensorState
.
HOLD
]
+
\
self
.
tensor
s
_state_
monito
r
[
TensorState
.
HOLD_AFTER_BWD
]
==
self
.
num_tensors
return
self
.
tensor_state_
cnte
r
[
TensorState
.
HOLD
]
+
\
self
.
tensor_state_
cnte
r
[
TensorState
.
HOLD_AFTER_BWD
]
==
self
.
num_tensors
@
property
def
can_reduce
(
self
):
return
self
.
tensor
s
_state_
monito
r
[
TensorState
.
READY_FOR_REDUCE
]
==
self
.
num_tensors
return
self
.
tensor_state_
cnte
r
[
TensorState
.
READY_FOR_REDUCE
]
==
self
.
num_tensors
@
property
def
has_inf_or_nan
(
self
)
->
bool
:
"""Check if the chunk has inf or nan values
i
n CUDA.
"""Check if the chunk has inf or nan values
o
n CUDA.
"""
if
self
.
is_gathered
:
valid_tensor
=
self
.
c
hunk_total
[:
self
.
utilized_size
]
valid_tensor
=
self
.
c
uda_global_chunk
[:
self
.
utilized_size
]
else
:
assert
self
.
cuda_shard
is
not
None
# only check
i
n CUDA
assert
self
.
cuda_shard
is
not
None
# only check
o
n CUDA
valid_tensor
=
self
.
cuda_shard
[:
self
.
valid_end
]
return
torch
.
isinf
(
valid_tensor
).
any
().
item
()
|
torch
.
isnan
(
valid_tensor
).
any
().
item
()
def
set_l2_norm
(
self
)
->
None
:
"""Record l2 norm of this chunks on CUDA.
"""
assert
self
.
l2_norm
is
None
,
"you are calculating the l2 norm twice"
if
self
.
is_gathered
:
valid_tensor
=
self
.
cuda_global_chunk
[:
self
.
utilized_size
]
else
:
assert
self
.
cuda_shard
is
not
None
# calculate on CUDA
valid_tensor
=
self
.
cuda_shard
[:
self
.
valid_end
]
chunk_l2_norm
=
valid_tensor
.
data
.
float
().
norm
(
2
)
self
.
l2_norm
=
chunk_l2_norm
.
item
()
**
2
def
append_tensor
(
self
,
tensor
:
torch
.
Tensor
):
"""Add a tensor to the chunk.
...
...
@@ -239,14 +271,11 @@ class Chunk:
self
.
num_tensors
+=
1
tensor_state
=
TensorState
.
HOLD
self
.
tensors_info
[
tensor
]
=
TensorInfo
(
tensor_state
,
self
.
utilized_size
,
new_utilized_size
)
self
.
tensor
s
_state_
monito
r
[
tensor_state
]
+=
1
self
.
tensor_state_
cnte
r
[
tensor_state
]
+=
1
self
.
utilized_size
=
new_utilized_size
def
close_chunk
(
self
,
shard_dev
:
Optional
[
torch
.
device
]
=
None
):
def
close_chunk
(
self
):
"""Close the chunk. Any tensor can't be appended to a closed chunk later.
Args:
shard_dev: the device where the shard locates
"""
# sanity check
assert
self
.
chunk_temp
is
not
None
...
...
@@ -258,28 +287,23 @@ class Chunk:
self
.
valid_end
=
self
.
utilized_size
-
self
.
shard_begin
if
self
.
chunk_temp
.
device
.
type
==
'cpu'
:
self
.
c
hunk_total
=
self
.
chunk_temp
.
to
(
get_current_device
())
self
.
c
uda_global_chunk
=
self
.
chunk_temp
.
to
(
get_current_device
())
self
.
__update_tensors_ptr
()
else
:
self
.
c
hunk_total
=
self
.
chunk_temp
self
.
c
uda_global_chunk
=
self
.
chunk_temp
self
.
chunk_temp
=
None
self
.
__scatter
()
# gathered chunk never have shard attribute
if
self
.
keep_gathered
:
if
shard_dev
is
None
:
shard_dev
=
get_current_device
()
else
:
assert
shard_dev
.
type
==
'cuda'
elif
shard_dev
is
None
:
shard_dev
=
torch
.
device
(
'cpu'
)
return
if
self
.
pin_memory
or
shard_dev
.
type
==
'cpu'
:
if
self
.
pin_memory
or
self
.
shard_dev
ice
.
type
==
'cpu'
:
self
.
cpu_shard
=
torch
.
empty
(
self
.
shard_size
,
dtype
=
self
.
dtype
,
pin_memory
=
self
.
pin_memory
)
self
.
cpu_shard
.
copy_
(
self
.
cuda_shard
)
self
.
cpu_vis_flag
=
True
# cpu_shard has been visited
if
shard_dev
.
type
==
'cpu'
:
if
self
.
shard_dev
ice
.
type
==
'cpu'
:
self
.
cuda_shard
=
None
def
shard_move
(
self
,
device
:
torch
.
device
,
force_copy
:
bool
=
False
):
...
...
@@ -352,19 +376,19 @@ class Chunk:
if
self
.
pg_size
==
1
:
# tricky code here
# just move c
hunk_total
to cuda_shard
# just move c
uda_global_chunk
to cuda_shard
# the communication is not necessary
self
.
__scatter
()
elif
self
.
keep_gathered
:
# we use all-reduce here
dist
.
all_reduce
(
self
.
c
hunk_total
,
group
=
self
.
torch_pg
)
dist
.
all_reduce
(
self
.
c
uda_global_chunk
,
group
=
self
.
torch_pg
)
else
:
self
.
cuda_shard
=
torch
.
empty
(
self
.
shard_size
,
dtype
=
self
.
dtype
,
device
=
get_current_device
())
input_list
=
list
(
torch
.
chunk
(
self
.
c
hunk_total
,
chunks
=
self
.
pg_size
,
dim
=
0
))
input_list
=
list
(
torch
.
chunk
(
self
.
c
uda_global_chunk
,
chunks
=
self
.
pg_size
,
dim
=
0
))
dist
.
reduce_scatter
(
self
.
cuda_shard
,
input_list
,
group
=
self
.
torch_pg
)
free_storage
(
self
.
c
hunk_total
)
free_storage
(
self
.
c
uda_global_chunk
)
self
.
is_gathered
=
False
self
.
__update_tensors_state
(
TensorState
.
HOLD
)
...
...
@@ -399,8 +423,8 @@ class Chunk:
assert
self
.
is_gathered
tensor_info
=
self
.
tensors_info
[
tensor
]
self
.
c
hunk_total
[
tensor_info
.
offset
:
tensor_info
.
end
].
copy_
(
data_slice
.
data
.
flatten
())
tensor
.
data
=
self
.
c
hunk_total
[
tensor_info
.
offset
:
tensor_info
.
end
].
view
(
tensor
.
shape
)
self
.
c
uda_global_chunk
[
tensor_info
.
offset
:
tensor_info
.
end
].
copy_
(
data_slice
.
data
.
flatten
())
tensor
.
data
=
self
.
c
uda_global_chunk
[
tensor_info
.
offset
:
tensor_info
.
end
].
view
(
tensor
.
shape
)
def
get_valid_length
(
self
)
->
int
:
"""Get the valid length of the chunk's payload.
...
...
@@ -429,7 +453,7 @@ class Chunk:
friend_chunk
=
self
.
paired_chunk
if
self
.
is_gathered
is
True
:
assert
friend_chunk
.
is_gathered
is
True
self
.
c
hunk_total
.
copy_
(
friend_chunk
.
c
hunk_total
)
self
.
c
uda_global_chunk
.
copy_
(
friend_chunk
.
c
uda_global_chunk
)
self
.
optim_sync_flag
=
True
elif
friend_chunk
.
device_type
==
'cuda'
and
self
.
device_type
==
'cuda'
:
self
.
cuda_shard
.
copy_
(
friend_chunk
.
cuda_shard
)
...
...
@@ -451,8 +475,8 @@ class Chunk:
# sanity check
assert
self
.
cuda_shard
is
not
None
alloc_storage
(
self
.
c
hunk_total
)
gather_list
=
list
(
torch
.
chunk
(
input
=
self
.
c
hunk_total
,
chunks
=
self
.
pg_size
,
dim
=
0
))
alloc_storage
(
self
.
c
uda_global_chunk
)
gather_list
=
list
(
torch
.
chunk
(
input
=
self
.
c
uda_global_chunk
,
chunks
=
self
.
pg_size
,
dim
=
0
))
dist
.
all_gather
(
gather_list
,
self
.
cuda_shard
,
self
.
torch_pg
)
self
.
cuda_shard
=
None
...
...
@@ -466,11 +490,11 @@ class Chunk:
# sanity check
assert
self
.
cuda_shard
is
None
self
.
cuda_shard
=
torch
.
empty
(
self
.
shard_size
,
dtype
=
self
.
dtype
,
device
=
self
.
c
hunk_total
.
device
)
self
.
cuda_shard
=
torch
.
empty
(
self
.
shard_size
,
dtype
=
self
.
dtype
,
device
=
self
.
c
uda_global_chunk
.
device
)
self
.
cuda_shard
.
copy_
(
self
.
c
hunk_total
[
self
.
shard_begin
:
self
.
shard_end
])
self
.
cuda_shard
.
copy_
(
self
.
c
uda_global_chunk
[
self
.
shard_begin
:
self
.
shard_end
])
free_storage
(
self
.
c
hunk_total
)
free_storage
(
self
.
c
uda_global_chunk
)
self
.
is_gathered
=
False
def
__paired_shard_move
(
self
):
...
...
@@ -491,15 +515,15 @@ class Chunk:
def
__update_tensors_ptr
(
self
)
->
None
:
# sanity check
assert
self
.
is_gathered
assert
type
(
self
.
c
hunk_total
)
==
torch
.
Tensor
assert
type
(
self
.
c
uda_global_chunk
)
==
torch
.
Tensor
for
tensor
,
tensor_info
in
self
.
tensors_info
.
items
():
tensor
.
data
=
self
.
c
hunk_total
[
tensor_info
.
offset
:
tensor_info
.
end
].
view
(
tensor
.
shape
)
tensor
.
data
=
self
.
c
uda_global_chunk
[
tensor_info
.
offset
:
tensor_info
.
end
].
view
(
tensor
.
shape
)
def
__update_one_tensor_info
(
self
,
tensor_info
:
TensorInfo
,
next_state
:
TensorState
):
self
.
tensor
s
_state_
monito
r
[
tensor_info
.
state
]
-=
1
self
.
tensor_state_
cnte
r
[
tensor_info
.
state
]
-=
1
tensor_info
.
state
=
next_state
self
.
tensor
s
_state_
monito
r
[
tensor_info
.
state
]
+=
1
self
.
tensor_state_
cnte
r
[
tensor_info
.
state
]
+=
1
def
__update_tensors_state
(
self
,
next_state
:
TensorState
,
prev_state
:
Optional
[
TensorState
]
=
None
):
for
tensor_info
in
self
.
tensors_info
.
values
():
...
...
@@ -529,9 +553,9 @@ class Chunk:
output
.
append
(
"
\t
chunk temp:
\n
"
)
print_tensor
(
tensor
=
self
.
chunk_temp
,
prefix
=
'
\t\t
'
)
if
self
.
c
hunk_total
is
not
None
and
self
.
c
hunk_total
.
storage
().
size
()
>
0
:
if
self
.
c
uda_global_chunk
is
not
None
and
self
.
c
uda_global_chunk
.
storage
().
size
()
>
0
:
output
.
append
(
"
\t
chunk total:
\n
"
)
print_tensor
(
tensor
=
self
.
c
hunk_total
,
prefix
=
'
\t\t
'
)
print_tensor
(
tensor
=
self
.
c
uda_global_chunk
,
prefix
=
'
\t\t
'
)
if
self
.
cuda_shard
is
not
None
:
output
.
append
(
"
\t
cuda shard:
\n
"
)
...
...
@@ -547,6 +571,6 @@ class Chunk:
if
detailed
:
output
.
append
(
"
\t
tensor state monitor:
\n
"
)
for
st
in
TensorState
:
output
.
append
(
"
\t\t
# of {}: {}
\n
"
.
format
(
st
,
self
.
tensor
s
_state_
monito
r
[
st
]))
output
.
append
(
"
\t\t
# of {}: {}
\n
"
.
format
(
st
,
self
.
tensor_state_
cnte
r
[
st
]))
return
''
.
join
(
output
)
Prev
1
…
6
7
8
9
10
11
12
13
14
…
24
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment