OpenDAS / ColossalAI

Commit ab95ec9a (unverified)
Authored May 06, 2022 by Jiarui Fang; committed by GitHub on May 06, 2022

[Tensor] init ColoParameter (#914)

Parent: 193d6293

Showing 6 changed files with 77 additions and 44 deletions (+77 -44):
  colossalai/tensor/__init__.py                 +2   -1
  colossalai/tensor/colo_parameter.py           +28  -0
  colossalai/tensor/colo_tensor.py              +17  -30
  colossalai/tensor/const.py                    +6   -0
  colossalai/utils/model/colo_init_context.py   +3   -6
  tests/test_tensor/test_model.py               +21  -7
colossalai/tensor/__init__.py

```diff
@@ -2,11 +2,12 @@ from .spec import ComputePattern, ParallelAction, TensorSpec, ShardPattern
 from .op_wrapper import (colo_op_impl,)
 from .colo_tensor import ColoTensor
+from .colo_parameter import ColoParameter
 from .utils import convert_parameter, named_params_with_colotensor
 from ._ops import *
 from .optim.colo_optimizer import ColoOptimizer

 __all__ = ['ColoTensor', 'convert_parameter', 'colo_op_impl', 'ComputePattern', 'TensorSpec', 'ParallelAction',
-           'named_params_with_colotensor', 'ShardPattern', 'ColoOptimizer']
+           'named_params_with_colotensor', 'ShardPattern', 'ColoOptimizer', 'ColoParameter']
```
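With the re-export in place, the new class is importable from the package root alongside ColoTensor. A minimal check, assuming ColossalAI at this commit is installed:

```python
# ColoParameter is re-exported next to ColoTensor after this commit.
from colossalai.tensor import ColoTensor, ColoParameter

# ColoParameter subclasses ColoTensor (see colo_parameter.py below).
assert issubclass(ColoParameter, ColoTensor)
```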
colossalai/tensor/colo_parameter.py (new file, mode 100644)

```python
from .colo_tensor import ColoTensor
from .const import TensorType
import torch


class ColoParameter(ColoTensor):
    r"""A kind of ColoTensor to be considered as a module parameter.
    """

    def __init__(self, *args, **kargs):
        super().__init__(*args, **kargs)
        self._type = TensorType.MODEL

    def __new__(cls, *args, **kwargs):
        t = super(ColoParameter, cls).__new__(cls)
        t._type = TensorType.MODEL
        return t

    @staticmethod
    def init_from_torch_tensor(tensor: torch.Tensor, save_payload=True) -> 'ColoParameter':
        colo_p = ColoParameter(*tensor.size(),
                               dtype=tensor.dtype,
                               requires_grad=tensor.requires_grad,
                               pin_memory=tensor.is_pinned(),
                               device=tensor.device,
                               torch_tensor=tensor if save_payload else torch.empty(0))
        return colo_p
```
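A short usage sketch for the new class, assuming ColossalAI at this commit is importable; the tensor shape is purely illustrative:

```python
import torch
from colossalai.tensor import ColoParameter

# A plain torch tensor standing in for a module weight (illustrative only).
weight = torch.randn(4, 8, requires_grad=True)

# save_payload=True keeps the original tensor attached as the payload;
# save_payload=False would store an empty placeholder instead.
colo_p = ColoParameter.init_from_torch_tensor(weight, save_payload=True)

# ColoParameter tags itself TensorType.MODEL in both __new__ and __init__,
# so it always reports as model data.
assert colo_p.is_model_data()
```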
colossalai/tensor/colo_tensor.py

```diff
@@ -7,12 +7,7 @@ from colossalai.core import global_context as gpc
 from colossalai.nn.layer.utils import divide
 from colossalai.tensor import TensorSpec, ComputePattern, ShardPattern
 from colossalai.nn.layer.parallel_1d._utils import split_forward_gather_backward, gather_forward_split_backward
-from enum import Enum
-
-
-class TensorType(Enum):
-    MODEL = 0
-    NONMODEL = 1    # mainly activations
+from .const import TensorType


 class ColoTensor(object):
@@ -26,17 +21,14 @@ class ColoTensor(object):
     def __new__(cls, *args, **kwargs):
         return super(ColoTensor, cls).__new__(cls)

-    def __init__(self,
-                 *size: Tuple[int],
-                 dtype=None,
-                 requires_grad=False,
-                 pin_memory=False,
-                 device=None,
-                 torch_tensor=torch.empty(0),
-                 shard_spec: TensorSpec = TensorSpec(),
-                 is_model_data: bool = False):
+    def __init__(self,
+                 *size: Tuple[int],
+                 dtype=None,
+                 requires_grad=False,
+                 pin_memory=False,
+                 device=None,
+                 torch_tensor=torch.empty(0),
+                 shard_spec: TensorSpec = TensorSpec()):
         self._size = size
         self._dtype = dtype
         self._requires_grad = requires_grad
@@ -45,10 +37,7 @@ class ColoTensor(object):
         self._torch_tensor = torch_tensor
         self._shard_spec = shard_spec
         self._shard_pattern = ShardPattern.NA
-        if is_model_data:
-            self._type = TensorType.MODEL
-        else:
-            self._type = TensorType.NONMODEL
+        self._type = TensorType.NONMODEL

     def __getitem__(self, key):
         return ColoTensor.init_from_torch_tensor(self.torch_tensor()[key])
@@ -97,14 +86,13 @@ class ColoTensor(object):
         return product(self._size)

     @staticmethod
-    def init_from_torch_tensor(tensor: torch.Tensor, save_payload=True, is_model_data=False) -> 'ColoTensor':
+    def init_from_torch_tensor(tensor: torch.Tensor, save_payload=True) -> 'ColoTensor':
         colo_t = ColoTensor(*tensor.size(),
                             dtype=tensor.dtype,
                             requires_grad=tensor.requires_grad,
                             pin_memory=tensor.is_pinned(),
                             device=tensor.device,
-                            torch_tensor=tensor if save_payload else torch.empty(0),
-                            is_model_data=is_model_data)
+                            torch_tensor=tensor if save_payload else torch.empty(0))
         return colo_t

     def del_torch_tensor(self, save_shape=False) -> None:
@@ -143,12 +131,11 @@ class ColoTensor(object):
             self.gather()
         # Model Parameters
         if self._shard_spec.num_action == 1:
-            parallel_action = self._shard_spec.get_action_by_compute_pattern(
-                self._shard_spec.compute_patterns[0])
+            parallel_action = self._shard_spec.get_action_by_compute_pattern(self._shard_spec.compute_patterns[0])
             if parallel_action.compute_pattern in [ComputePattern.TP1DRow_Linear, \
                 ComputePattern.TP1DCol_Embedding]:
                 self._shard_1d(parallel_action=parallel_action, dim=-1)
-                self._shard_pattern = ShardPattern.Col  # We bind our ComputePattern on weight, which has to be transposed when linear().
+                self._shard_pattern = ShardPattern.Col    # We bind our ComputePattern on weight, which has to be transposed when linear().
             elif parallel_action.compute_pattern in [ComputePattern.TP1DCol_Linear, \
                 ComputePattern.TP1DRow_Embedding]:
                 self._shard_1d(parallel_action=parallel_action, dim=0)
@@ -157,7 +144,7 @@ class ColoTensor(object):
             raise NotImplementedError

     def gather(self):
-        assert self.is_activation(), 'Currently we only support gather Activation ColoTensor.'
+        assert not self.is_model_data(), 'Currently we only support gather Activation ColoTensor.'
         assert not self.is_gathered(), 'Only sharded ColoTensor can be gathered.'

         parallel_action = self._shard_spec.get_action_by_compute_pattern(ComputePattern.DP)
         if self._shard_pattern == ShardPattern.Row:
@@ -174,8 +161,8 @@ class ColoTensor(object):
     def has_spec(self) -> bool:
         return self._shard_spec is not None and self._shard_spec.num_action > 0

-    def is_activation(self) -> bool:
-        return self._type == TensorType.NONMODEL
+    def is_model_data(self) -> bool:
+        return self._type == TensorType.MODEL

     def _shard_1d(self, parallel_action, dim=-1):
         num_partition = gpc.get_world_size(parallel_action.parallel_mode)
```
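To make the behavioral change concrete: a plain ColoTensor is now always tagged as non-model data, and model data is expressed by constructing a ColoParameter instead of passing is_model_data=True. A small sketch, assuming ColossalAI at this commit:

```python
import torch
from colossalai.tensor import ColoTensor, ColoParameter

t = torch.randn(2, 3)

# ColoTensor no longer accepts is_model_data; it is unconditionally
# tagged TensorType.NONMODEL (mainly activations).
activation = ColoTensor.init_from_torch_tensor(t)
assert not activation.is_model_data()

# Model data is now represented by ColoParameter.
weight = ColoParameter.init_from_torch_tensor(t)
assert weight.is_model_data()

# gather() supports only activation tensors, which is exactly what the
# rewritten assertion `assert not self.is_model_data()` enforces.
```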
colossalai/tensor/const.py (new file, mode 100644)

```python
from enum import Enum


class TensorType(Enum):
    MODEL = 0
    NONMODEL = 1    # mainly activations
```
colossalai/utils/model/colo_init_context.py

```diff
 from .utils import InsertPostInitMethodToModuleSubClasses
 import torch
-from colossalai.tensor import ColoTensor
+from colossalai.tensor import ColoTensor, ColoParameter
 import types

 from torch import nn
@@ -100,10 +100,7 @@ class ColoInitContext(InsertPostInitMethodToModuleSubClasses):
             tensor_detached = param.to(self._device).detach()
             tensor_detached.requires_grad = requires_grad

-            setattr(module, name,
-                    ColoTensor.init_from_torch_tensor(tensor=tensor_detached,
-                                                      save_payload=save_torch_payload,
-                                                      is_model_data=True))
+            setattr(module, name,
+                    ColoParameter.init_from_torch_tensor(tensor=tensor_detached, save_payload=save_torch_payload))

         ColoModulize(module)
```
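For context, a sketch of how this change surfaces to users. The constructor of ColoInitContext is not shown in this diff, so the device keyword below is an assumption (the hunk above only reveals a self._device attribute); the point is that parameters initialized under the context now come out as ColoParameter rather than generic ColoTensors:

```python
import torch
from torch import nn
from colossalai.utils.model.colo_init_context import ColoInitContext

# Assumption: ColoInitContext is a context manager taking a target device;
# neither detail is confirmed by this diff.
with ColoInitContext(device=torch.device('cpu')):
    model = nn.Linear(8, 4)

# Mirroring the assertion added in tests/test_tensor/test_model.py:
# every parameter swapped in by the context reports as model data.
for name, colo_p in model.colo_named_parameters():
    assert colo_p.is_model_data()
```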
tests/test_tensor/test_model.py

```diff
@@ -38,17 +38,23 @@ def run_1d_col_tp():
     model = model_builder(checkpoint=True)

     parallel_action_list_row = [
-        ParallelAction(priority=1, compute_pattern=ComputePattern.TP1DRow_Linear, parallel_mode=ParallelMode.PARALLEL_1D)
+        ParallelAction(priority=1,
+                       compute_pattern=ComputePattern.TP1DRow_Linear,
+                       parallel_mode=ParallelMode.PARALLEL_1D)
     ]
     spec_row = TensorSpec(parallel_action_list_row)

     parallel_action_list_col = [
-        ParallelAction(priority=1, compute_pattern=ComputePattern.TP1DCol_Linear, parallel_mode=ParallelMode.PARALLEL_1D)
+        ParallelAction(priority=1,
+                       compute_pattern=ComputePattern.TP1DCol_Linear,
+                       parallel_mode=ParallelMode.PARALLEL_1D)
     ]
     spec_col = TensorSpec(parallel_action_list_col)

     parallel_action_list_embedding_col = [
-        ParallelAction(priority=1, compute_pattern=ComputePattern.TP1DCol_Embedding, parallel_mode=ParallelMode.PARALLEL_1D)
+        ParallelAction(priority=1,
+                       compute_pattern=ComputePattern.TP1DCol_Embedding,
+                       parallel_mode=ParallelMode.PARALLEL_1D)
     ]
     spec_embedding_col = TensorSpec(parallel_action_list_embedding_col)
@@ -125,6 +131,9 @@ def test_model_parameters():
         param_cnt += 1
     assert param_cnt == 5

+    for name, colo_p in model.colo_named_parameters():
+        assert colo_p.is_model_data()
+
     param_cnt = 0
     for name, p in model.named_parameters(recurse=False):
         param_cnt += 1
@@ -175,12 +184,16 @@ def run_1d_row_tp():
     model = model_builder(checkpoint=True)

     parallel_action_list = [
-        ParallelAction(priority=1, compute_pattern=ComputePattern.TP1DRow_Linear, parallel_mode=ParallelMode.PARALLEL_1D)
+        ParallelAction(priority=1,
+                       compute_pattern=ComputePattern.TP1DRow_Linear,
+                       parallel_mode=ParallelMode.PARALLEL_1D)
     ]
     spec = TensorSpec(parallel_action_list)

     parallel_action_list_embedding_row = [
-        ParallelAction(priority=1, compute_pattern=ComputePattern.TP1DRow_Embedding, parallel_mode=ParallelMode.PARALLEL_1D)
+        ParallelAction(priority=1,
+                       compute_pattern=ComputePattern.TP1DRow_Embedding,
+                       parallel_mode=ParallelMode.PARALLEL_1D)
     ]
     spec_embedding_row = TensorSpec(parallel_action_list_embedding_row)
@@ -243,6 +256,7 @@ def run_dist(rank, world_size, port):
     run_1d_row_tp()
     run_1d_col_tp()

+
 @pytest.mark.dist
 @parameterize('world_size', [1, 4])
 @rerun_if_address_is_in_use()
@@ -252,6 +266,6 @@ def test_simple_net(world_size):
 if __name__ == '__main__':
-    test_simple_net()
-    # test_model_parameters()
+    # test_simple_net()
+    test_model_parameters()
     # test_colo_optimizer()
```