ColossalAI (OpenDAS) commit 8cdce039 (unverified)
Authored Jun 21, 2022 by Jiarui Fang; committed by GitHub on Jun 21, 2022

[ColoTensor] improves init functions. (#1150)

parent 8106d7b8
Showing 5 changed files with 103 additions and 40 deletions (+103 -40)
colossalai/tensor/colo_parameter.py    +1  -1
colossalai/tensor/colo_tensor.py       +38 -17
colossalai/tensor/dist_spec_mgr.py     +10 -0
colossalai/tensor/spec.py              +5  -20
tests/test_tensor/test_tensor.py       +49 -2
colossalai/tensor/colo_parameter.py

@@ -35,7 +35,7 @@ class ColoParameter(ColoTensor, torch.nn.Parameter):
                  data: Optional[torch.Tensor] = None,
                  requires_grad: bool = True,
                  spec: TensorSpec = TensorSpec(distspec.replicate())) -> None:
-        self._spec = copy(spec)
+        self._tensor_spec = copy(spec)
         self._type = TensorType.MODEL
         self._graph_node = None
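For reference, a minimal sketch of constructing a ColoParameter with the signature shown above. This is an illustration only, not part of the commit; it assumes the module path of the file in this diff and that the keyword arguments are accepted as written.

# Illustrative construction of a ColoParameter with the default replicated spec
# (sketch based on the signature above; not part of this commit).
import torch
from colossalai.tensor import TensorSpec, distspec
from colossalai.tensor.colo_parameter import ColoParameter

weight = ColoParameter(torch.randn(8, 8),
                       requires_grad=True,
                       spec=TensorSpec(distspec.replicate()))
# The spec is copied into the renamed _tensor_spec attribute and the parameter
# is tagged as model data (TensorType.MODEL).
assert weight.is_model_data()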
colossalai/tensor/colo_tensor.py
 from .op_wrapper import _COLOSSAL_OPS
+from .const import TensorType
 from copy import copy
 import torch
+from torch.overrides import get_default_nowrap_functions
 from colossalai.tensor import TensorSpec
-from .const import TensorType
 from colossalai.tensor import distspec
 from colossalai.tensor.dist_spec_mgr import DistSpecManager
 from colossalai.tensor.distspec import _DistSpec
-from torch.overrides import get_default_nowrap_functions


 def _convert_output(output):
@@ -18,34 +19,54 @@ def _convert_output(output):
 class ColoTensor(torch.Tensor):
-    """ Data Structure for Tensor in Colossal-AI
-    1. It contains a torch.Tensor as an attribute.
-    2. It supports lazy init of the tensor's payload.
-    3. It can hijack the torch functions which use ColoTensors as args to our customized functions.
-    4. It supports distributing the tensor's payload to shards among processes. (TODO)
+    """ Data Structure for Tensor in Colossal-AI. It is a subclass of torch.Tensor.
+
+    Args:
+        data (torch.Tensor): a torch tensor used as the payload of the colotensor.
+        spec (TensorSpec, optional): the tensor spec of initialization. Defaults to TensorSpec(distspec.replicate()).
+
+    The signature of __init__ has to be consistent with __new__ except for the 1st arg.
+
+    The class should be initialized with a torch tensor in one of the following ways.
+    1. directly init.
+    >>> colo_t1 = ColoTensor(torch.randn(2, 3), spec=TensorSpec(distspec.replicate()))
+    >>> # If initialized in a shard model, the tensor passed in is one shard of the global tensor.
+    >>> shard_spec = distspec.shard(process_group=gpc.get_group(ParallelMode.DATA),
+    >>>                             dims=[0],
+    >>>                             num_partitions=[world_size])
+    >>> tensor_spec = TensorSpec(shard_spec)
+    >>> colo_t2 = ColoTensor.from_torch_tensor(t_ref.clone(), tensor_spec)
+    2. use the static method from_torch_tensor.
+    >>> colo_t = ColoTensor.from_torch_tensor(torch.randn(2, 3), spec=TensorSpec(distspec.replicate()))
     """

     def __new__(cls, data: torch.Tensor, spec: TensorSpec = TensorSpec(distspec.replicate())) -> 'ColoTensor':
         """__new__
         The signature of __new__ has to be consistent with torch.Tensor.
         Args:
             data (torch.Tensor): a torch tensor used as the payload of the colotensor.
             spec (TensorSpec, optional): the tensor spec of initialization. Defaults to TensorSpec(distspec.replicate())
         Returns:
             ColoTensor: a ColoTensor that wraps the data.
         """
         if data is None:
             data = torch.empty(0)
         return torch.Tensor._make_subclass(cls, data, data.requires_grad)

     def __init__(self, data: torch.Tensor, spec: TensorSpec = TensorSpec(distspec.replicate())) -> None:
-        self._spec = copy(spec)
+        self._tensor_spec = copy(spec)
         self._type = TensorType.NONMODEL
         self._graph_node = None

     @property
     def spec(self) -> TensorSpec:
-        return self._spec
+        return self._tensor_spec

     def set_spec(self, spec: TensorSpec) -> None:
         spec = copy(spec)
-        self.convert_to_dist_spec_(spec.dist_spec)
-        self._spec = spec
+        self._convert_to_dist_spec(spec.dist_spec)
+        self._tensor_spec = spec

     def has_spec(self) -> bool:
-        return self._spec.parallel_action is not None
+        return self._tensor_spec.parallel_action is not None

     def is_model_data(self) -> bool:
         return self._type == TensorType.MODEL
@@ -74,16 +95,16 @@ class ColoTensor(torch.Tensor):
     def is_model_data(self) -> bool:
         return self._type == TensorType.MODEL

-    def convert_to_dist_spec_(self, dist_spec: _DistSpec) -> None:
+    def _convert_to_dist_spec(self, dist_spec: _DistSpec) -> None:
         with DistSpecManager.no_grad():
             self.data = DistSpecManager.handle_trans_spec(self, self.spec.dist_spec, dist_spec)
-        self._spec.dist_spec = dist_spec
+        self._tensor_spec.dist_spec = dist_spec

     def convert_to_dist_spec(self, dist_spec: _DistSpec) -> 'ColoTensor':
-        spec = copy(self._spec)
-        spec.dist_spec = dist_spec
+        tensor_spec = copy(self._tensor_spec)
+        tensor_spec.dist_spec = dist_spec
         ret = DistSpecManager.handle_trans_spec(self, self.spec.dist_spec, dist_spec)
-        return ColoTensor.from_torch_tensor(ret, spec)
+        return ColoTensor.from_torch_tensor(ret, tensor_spec)

     @staticmethod
     def from_torch_tensor(tensor: torch.Tensor, spec: TensorSpec = TensorSpec(distspec.replicate())) -> 'ColoTensor':
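To make the two conversion paths above concrete, here is a minimal usage sketch of the in-place rename (set_spec, which now delegates to _convert_to_dist_spec) versus the out-of-place convert_to_dist_spec. It is illustrative only, not part of the diff, and assumes a process group already launched via colossalai.launch, mirroring the class docstring and the new test further below.

# Sketch: in-place vs. out-of-place layout conversion on a ColoTensor
# (illustrative; assumes colossalai.launch() has set up the DATA group).
import torch
from colossalai.core import global_context as gpc
from colossalai.context import ParallelMode
from colossalai.tensor import ColoTensor, TensorSpec, distspec

world_size = gpc.get_group(ParallelMode.DATA).size()
shard = distspec.shard(process_group=gpc.get_group(ParallelMode.DATA),
                       dims=[0],
                       num_partitions=[world_size])

# Each rank wraps its local (4, 5) shard of a (4 * world_size, 5) global tensor.
t = ColoTensor.from_torch_tensor(torch.randn(4, 5), TensorSpec(shard))

# In-place: set_spec() calls _convert_to_dist_spec(), so the payload of t is
# gathered into the replicated layout and _tensor_spec is updated.
t.set_spec(TensorSpec(distspec.replicate()))

# Out-of-place: convert_to_dist_spec() returns a new ColoTensor with the
# requested layout and leaves t unchanged.
t_sharded = t.convert_to_dist_spec(shard)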
colossalai/tensor/dist_spec_mgr.py
@@ -4,6 +4,7 @@ from numpy import prod
 from contextlib import contextmanager

 import torch
 import torch.distributed as dist
+from packaging import version

 # TODO(jiaruifang) circle import, move the divide to colossalai.commons.
@@ -56,6 +57,12 @@ class DistSpecManager:
     @staticmethod
     def _gather(tensor: torch.Tensor, old_dist_spec: _DistSpec) -> torch.Tensor:
+        if version.parse(torch.__version__) < version.parse("1.11.0"):
+            # PyTorch versions below 1.11 do not support gathering a CPU tensor,
+            # so we move the tensor to the GPU before the gather.
+            saved_dev = tensor.device
+            tensor.data = tensor.data.cuda()
+
         buffer = [torch.empty_like(tensor) for _ in range(old_dist_spec.process_group.size())]
         dist.all_gather(buffer, tensor, group=old_dist_spec.process_group)
         for i in range(len(old_dist_spec.dims) - 1, -1, -1):
@@ -66,6 +73,9 @@ class DistSpecManager:
                 new_buffer.append(torch.cat(buffer[start:start + num_parts], dim))
             buffer = new_buffer
         assert len(buffer) == 1
+
+        if version.parse(torch.__version__) < version.parse("1.11.0"):
+            buffer[0].data = buffer[0].data.to(saved_dev)
         return buffer[0]

     @staticmethod
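A standalone sketch of the version guard added above, for readers who want to check which branch their environment takes. It is illustrative only and not part of the commit.

# Check how the packaging.version guard above resolves for the local install
# (illustration; not part of this commit).
import torch
from packaging import version

if version.parse(torch.__version__) < version.parse("1.11.0"):
    print("old PyTorch: _gather will round-trip CPU tensors through the GPU")
else:
    print("PyTorch >= 1.11: _gather can all_gather the tensor on its current device")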
colossalai/tensor/spec.py
@@ -24,28 +24,13 @@ class ParallelAction(object):
 class TensorSpec(object):
     """
-    It contains two aspects of information:
-    First, how the tensor is distributed in heterogeneous memory space.
-    Second, if the tensor is a model parameter, the spec contains the parallel
-    computation pattern of the operator (layer). We have to consider the hybrid
-    parallel mode.
+    The specification of the ColoTensor.
+    Args:
+        dist_spec (_DistSpec): describing the layout among processes.
+        parallel_action (Optional[ParallelAction], optional): actions conducted on the tensor
+            after initialization if it's a model data tensor. Defaults to None.
     """
-    # a list of parallel actions.
-    # For example: On 8 GPUs, a hybrid parallel strategy is applied
-    # using ZeRO with DP-degree = 4 and 1DRowTP with TP-degree = 2.
-    # parallel_action_list = [
-    #     ParallelAction(10, ComputePattern.ZeRO, gpc.get_group(ParallelMode.DATA)),
-    #     ParallelAction(1, ComputePattern.TP1D_Linear, gpc.get_group(ParallelMode.PARALLEL_1D))
-    # ]
-    # When the ColoTensor is initialized,
-    # we first split the tensor according to the ParallelAction of ZeRO,
-    # then split the tensor according to the ParallelAction of TP1D_Linear.
-    # During Linear computation:
-    # Before the Linear op, we gather the tensors according to ZeRO.
-    # We perform the Linear op according to the compute pattern of TP1D_Linear.
-    # After the Linear op, we split the tensors according to ZeRO.

     def __init__(self, dist_spec: _DistSpec, parallel_action: Optional[ParallelAction] = None):
         self.parallel_action = parallel_action
         self.dist_spec = dist_spec
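For illustration, a short sketch of building a TensorSpec against the DATA group. The names follow this diff; the sketch assumes an initialized global context and is not part of the commit.

# Sketch: a TensorSpec describing a tensor sharded along dim 0 over the DATA group
# (illustrative; assumes colossalai.launch() has already been called).
from colossalai.core import global_context as gpc
from colossalai.context import ParallelMode
from colossalai.tensor import distspec, TensorSpec

data_group = gpc.get_group(ParallelMode.DATA)
shard_spec = distspec.shard(process_group=data_group,
                            dims=[0],
                            num_partitions=[data_group.size()])

# No parallel_action is attached, so ColoTensor.has_spec() would report False
# for a tensor carrying this spec.
spec = TensorSpec(shard_spec)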
tests/test_tensor/test_tensor.py
@@ -3,6 +3,17 @@ import pytest
 from colossalai.tensor import ColoTensor
 from numpy import allclose

+import colossalai
+from colossalai.utils import free_port
+from colossalai.tensor import distspec, TensorSpec
+from colossalai.core import global_context as gpc
+import torch.multiprocessing as mp
+from colossalai.testing import rerun_if_address_is_in_use
+from colossalai.utils import free_port
+from colossalai.tensor import distspec, TensorSpec, ColoTensor
+from colossalai.context import ParallelMode
+from functools import partial
+

 def test_tensor_indexing():
     torch_t = torch.randn(2, 3)
@@ -25,8 +36,6 @@ def test_wrapped_tensor_func():
     # non-func attr
     assert t.is_cuda == t_ref.is_cuda

-    # TODO I don't find out a tensor function which returns None.
-
     # return 1 torch.Tensor
     t_abs = t.abs()
     assert isinstance(t_abs, ColoTensor) and torch.equal(t_abs, t_ref.abs())
@@ -47,3 +56,41 @@ def test_operand():
     t_res = t + t
     assert torch.allclose(t_ref_res, t_res)
+
+
+#### Test Distributed init a Colotensor
+
+
+def _run_tensor_shard_init(world_size):
+    t_ref = torch.randn(4, 5)
+    print(gpc.get_group(ParallelMode.DATA).size())
+    shard_spec = distspec.shard(process_group=gpc.get_group(ParallelMode.DATA),
+                                dims=[0],
+                                num_partitions=[world_size])
+    tensor_spec = TensorSpec(shard_spec)
+    t = ColoTensor.from_torch_tensor(t_ref.clone(), tensor_spec)
+    t.set_spec(TensorSpec(dist_spec=distspec.replicate()))
+    assert t.shape == torch.Size((4 * world_size, 5))
+
+
+def _run_tensor_replicated_init(world_size):
+    t_ref = torch.randn(4 * world_size, 5)
+    t = ColoTensor.from_torch_tensor(t_ref.clone())
+
+    assert t.shape == torch.Size((4 * world_size, 5)), f"{t.shape}"
+
+
+def run_tensor_init(rank, world_size, port):
+    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    _run_tensor_shard_init(world_size)
+    _run_tensor_replicated_init(world_size)
+
+
+@pytest.mark.dist
+@pytest.mark.parametrize('world_size', [1, 2])
+@rerun_if_address_is_in_use()
+def _test_dist_init(world_size):
+    run_func = partial(run_tensor_init, world_size=world_size, port=free_port())
+    mp.spawn(run_func, nprocs=world_size)
+
+
+if __name__ == '__main__':
+    # _test_dist_init(4)
+    test_new()
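The assertion in _run_tensor_shard_init relies on the fact that gathering world_size shards of shape (4, 5) along dim 0 reconstructs a (4 * world_size, 5) global tensor. A standalone sanity check of that arithmetic, requiring no distributed setup and included here only as an illustration:

# Single-process sanity check of the shape arithmetic behind _run_tensor_shard_init
# (mimics the gather along dim 0; not part of the commit).
import torch

world_size = 2
local_shards = [torch.randn(4, 5) for _ in range(world_size)]   # one shard per rank
global_tensor = torch.cat(local_shards, dim=0)                  # what replicate() recovers
assert global_tensor.shape == torch.Size((4 * world_size, 5))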