OpenDAS / ColossalAI · Commits

Commit e414e409 (Unverified)
Authored Mar 01, 2023 by YuliangLiu0306; committed by GitHub on Mar 01, 2023
Parent: 489a9566

[DTensor] implementation of dtensor (#2946)

* [DTensor] implementation of dtensor
* test layout convert
* polish
Showing 3 changed files with 284 additions and 0 deletions:

    colossalai/tensor/d_tensor/d_tensor.py    +158  -0
    colossalai/tensor/d_tensor/layout.py       +22  -0
    tests/test_tensor/test_dtensor.py         +104  -0
colossalai/tensor/d_tensor/d_tensor.py (new file, mode 100644)
from typing import Optional

import torch
from torch.utils._pytree import tree_map

from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.d_tensor.layout import Layout
from colossalai.tensor.shape_consistency import ShapeConsistencyManager, to_global
from colossalai.tensor.sharding_spec import ShardingSpec

shape_consistency_manager = ShapeConsistencyManager()


class DTensor(torch.Tensor):

    def __init__(self, local_tensor: torch.Tensor, dist_layout: Layout):
        self.local_tensor = local_tensor
        self.data_type = local_tensor.dtype
        self.entire_shape = local_tensor.shape
        if dist_layout.entire_shape is None:
            dist_layout.entire_shape = self.entire_shape
        self.dist_layout = dist_layout
        self._apply_layout()

    @staticmethod
    def __new__(cls, local_tensor, layout):
        return torch.Tensor._make_subclass(cls, local_tensor, local_tensor.requires_grad)

    def __repr__(self):
        return f"DTensor({self.to_global()}, {self.dist_layout})"

    def __str__(self):
        return self.__repr__()

    def layout_convert(self, target_layout):
        '''
        Convert the layout of the tensor from source_spec to target_spec.
        '''
        source_spec = convert_layout_to_sharding_spec(self.dist_layout)
        target_spec = convert_layout_to_sharding_spec(target_layout)
        self.local_tensor = shape_consistency_manager.apply_for_autoparallel_runtime(
            self.local_tensor, source_spec, target_spec)
        self.dist_layout = target_layout

    def _apply_layout(self):
        '''
        Apply the layout to the local tensor during the initializing process.
        '''
        source_spec = construct_default_sharding_spec(self.local_tensor, self.device_mesh)
        target_spec = convert_layout_to_sharding_spec(self.dist_layout)
        self.local_tensor = shape_consistency_manager.apply_for_autoparallel_runtime(
            self.local_tensor, source_spec, target_spec)

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}

        def filter_arg(arg):
            if isinstance(arg, DTensor):
                return arg.local_tensor
            else:
                return arg

        args = tree_map(filter_arg, args)
        kwargs = tree_map(filter_arg, kwargs)
        # If we want to convert the result into a DTensor, we need to infer the layout of the result
        # from the layout of the input tensors and the op type.
        return func(*args, **kwargs)

    @property
    def device_mesh(self):
        '''
        Return the device mesh of the tensor.
        '''
        return self.dist_layout.device_mesh

    @property
    def sharding_spec(self):
        '''
        Return the sharding specification of the tensor.
        '''
        return self.dist_layout.sharding_spec

    def to(self, *args, **kwargs):
        '''
        Move the tensor to a new device or convert the tensor to a new dtype.
        '''
        self.local_tensor = self.local_tensor.to(*args, **kwargs)
        self.data_type = self.local_tensor.dtype
        self.dist_layout.device_type = self.local_tensor.device
        # TODO: update the device mesh process groups, or we should just cache
        # both the cpu process groups and the cuda process groups?
        return self

    def to_local(self):
        '''
        Return the local tensor on this rank.
        '''
        return self.local_tensor

    def to_global(self):
        '''
        Recover the global tensor from the distributed tensor.

        Note: This function will all_gather the local tensor into the global tensor and it
        will not change the layout of the DTensor. This function is mainly used for debugging or
        checking the correctness of the distributed tensor.
        '''
        return to_global(self.local_tensor, convert_layout_to_sharding_spec(self.dist_layout))


def distribute_tensor(local_tensor: torch.Tensor, dist_layout: Layout) -> DTensor:
    '''
    Distribute the local tensor to the distributed tensor according to the dist_layout specified.

    Args:
        local_tensor: tensor to be distributed.
        dist_layout: the layout specification of the distributed tensor.

    Returns:
        A 'DTensor' object.
    '''
    return DTensor(local_tensor, dist_layout)


def distribute_module(module: torch.nn.Module, partition_fn: Optional[callable] = None) -> torch.nn.Module:
    '''
    This function converts all the parameters in the module to DTensor(DParam).

    Note: This function is subject to future change as DParam has not been implemented yet.
    '''
    for name, param in module.named_parameters():
        if param is not None and not isinstance(param, DTensor):
            # TODO: we could convert the parameter to DParam here,
            # the type of the parameter could be an optional argument.
            setattr(module, name, torch.nn.Parameter(partition_fn(name, param.data)))
    return module


def convert_layout_to_sharding_spec(layout: Layout) -> ShardingSpec:
    '''
    Convert the layout from the Layout class to the ShardingSpec class.
    '''
    return ShardingSpec(device_mesh=layout.device_mesh,
                        entire_shape=layout.entire_shape,
                        dim_partition_dict=layout.sharding_spec.dim_partition_dict)


def construct_default_sharding_spec(
    tensor: torch.Tensor,
    device_mesh: DeviceMesh,
) -> ShardingSpec:
    '''
    Construct the default sharding specification for the tensor.
    '''
    return ShardingSpec(device_mesh=device_mesh, entire_shape=tensor.shape, dim_partition_dict={})
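For orientation, here is a minimal usage sketch of the API above, modeled directly on the test added in this commit (tests/test_tensor/test_dtensor.py); it assumes the distributed environment has already been launched across 4 ranks, e.g. via colossalai.initialize.launch, and that the constructor arguments match the test:

import torch

from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.d_tensor.d_tensor import distribute_tensor
from colossalai.tensor.d_tensor.layout import Layout
from colossalai.tensor.sharding_spec import ShardingSpec

# assumes torch.distributed is already initialized across 4 ranks (see the test below)
local_tensor = torch.rand(4, 8).to('cuda')

# 2x2 device mesh over ranks 0-3
device_mesh = DeviceMesh(torch.Tensor([0, 1, 2, 3]), (2, 2), init_process_group=True)

# shard dim 0 of the tensor across mesh axis 0
sharding_spec = ShardingSpec(device_mesh=device_mesh,
                             entire_shape=local_tensor.shape,
                             dim_partition_dict={0: [0]})
layout = Layout(device_mesh=device_mesh,
                device_type=torch.device('cuda'),
                sharding_spec=sharding_spec)

d_tensor = distribute_tensor(local_tensor, layout)
shard = d_tensor.to_local()     # this rank's local shard
full = d_tensor.to_global()     # all-gathered copy; the DTensor's layout is unchanged

Note that, as the comment in __torch_function__ indicates, an op applied to a DTensor currently unwraps it to the local tensor and returns a plain torch.Tensor rather than a new DTensor; re-inferring an output layout is left as future work.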
colossalai/tensor/d_tensor/layout.py (new file, mode 100644)
from dataclasses import dataclass

import torch

from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.sharding_spec import ShardingSpec


@dataclass
class Layout:
    """Layout of a tensor.

    Attributes:
        device_mesh: the device mesh to store the tensor distributedly.
        device_type: the type of the device mesh, e.g. 'cpu' or 'cuda'.
        sharding_spec: the sharding specification to describe how the tensor is sharded.
        entire_shape: the entire shape of the global tensor.
    """
    device_mesh: DeviceMesh
    device_type: torch.device
    sharding_spec: ShardingSpec
    entire_shape: torch.Size = None
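A small construction sketch, using the same assumptions as the test below (a 2x2 mesh over ranks 0-3 and a (4, 8) global tensor). entire_shape may be left as None here, in which case DTensor.__init__ fills it in from the local tensor's shape:

import torch

from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.d_tensor.layout import Layout
from colossalai.tensor.sharding_spec import ShardingSpec

device_mesh = DeviceMesh(torch.Tensor([0, 1, 2, 3]), (2, 2), init_process_group=True)
spec = ShardingSpec(device_mesh=device_mesh,
                    entire_shape=torch.Size((4, 8)),
                    dim_partition_dict={0: [0]})    # shard dim 0 over mesh axis 0

# entire_shape is omitted; DTensor.__init__ will set it from the local tensor
layout = Layout(device_mesh=device_mesh,
                device_type=torch.device('cuda'),
                sharding_spec=spec)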
tests/test_tensor/test_dtensor.py (new file, mode 100644)
from functools import partial

import torch
import torch.multiprocessing as mp

from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx.tracer import ColoTracer
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.tensor.d_tensor.d_tensor import DTensor, distribute_tensor
from colossalai.tensor.d_tensor.layout import Layout
from colossalai.tensor.sharding_spec import ShardingSpec
from colossalai.utils import free_port


class TestModel(torch.nn.Module):

    def __init__(self, in_features, out_features):
        super().__init__()
        self.linear_1 = torch.nn.Linear(in_features, out_features)
        self.linear_2 = torch.nn.Linear(out_features, in_features)

    def forward(self, x):
        x = self.linear_1(x)
        x = self.linear_2(x)
        return x


def check_dtensor(rank, world_size, port):
    disable_existing_loggers()
    launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')

    test_model = TestModel(8, 8).to('cuda')
    original_tensor = torch.rand(4, 8).to('cuda')
    compare_output = test_model(original_tensor)

    device_mesh = DeviceMesh(torch.Tensor([0, 1, 2, 3]), (2, 2), init_process_group=True)
    target_sharding_spec = ShardingSpec(device_mesh=device_mesh,
                                        entire_shape=original_tensor.shape,
                                        dim_partition_dict={0: [0]})
    layout = Layout(device_mesh=device_mesh,
                    device_type=torch.device('cuda'),
                    sharding_spec=target_sharding_spec)
    d_tensor = DTensor(original_tensor, layout)

    assert d_tensor.entire_shape == original_tensor.shape
    assert d_tensor.data_type == original_tensor.dtype

    if rank in (0, 1):
        assert d_tensor.to_local().equal(original_tensor.narrow(0, 0, 2))
    elif rank in (2, 3):
        assert d_tensor.to_local().equal(original_tensor.narrow(0, 2, 2))
    else:
        raise ValueError(f'rank {rank} is not in the device mesh')
    assert d_tensor.to_global().equal(original_tensor)

    output = test_model(d_tensor)
    if rank in (0, 1):
        assert output.equal(compare_output.narrow(0, 0, 2))
    elif rank in (2, 3):
        assert output.equal(compare_output.narrow(0, 2, 2))
    else:
        raise ValueError(f'rank {rank} is not in the device mesh')

    new_sharding_spec = ShardingSpec(device_mesh=device_mesh,
                                     entire_shape=original_tensor.shape,
                                     dim_partition_dict={0: [0, 1]})
    new_layout = Layout(device_mesh=device_mesh,
                        device_type=torch.device('cuda'),
                        sharding_spec=new_sharding_spec,
                        entire_shape=original_tensor.shape)

    d_tensor.layout_convert(new_layout)

    if rank == 0:
        assert d_tensor.local_tensor.equal(original_tensor.narrow(0, 0, 1))
    elif rank == 1:
        assert d_tensor.local_tensor.equal(original_tensor.narrow(0, 1, 1))
    elif rank == 2:
        assert d_tensor.local_tensor.equal(original_tensor.narrow(0, 2, 1))
    elif rank == 3:
        assert d_tensor.local_tensor.equal(original_tensor.narrow(0, 3, 1))
    else:
        raise ValueError(f'rank {rank} is not in the device mesh')

    dtensor_from_local = distribute_tensor(original_tensor, new_layout)

    if rank == 0:
        assert dtensor_from_local.local_tensor.equal(original_tensor.narrow(0, 0, 1))
    elif rank == 1:
        assert dtensor_from_local.local_tensor.equal(original_tensor.narrow(0, 1, 1))
    elif rank == 2:
        assert dtensor_from_local.local_tensor.equal(original_tensor.narrow(0, 2, 1))
    elif rank == 3:
        assert dtensor_from_local.local_tensor.equal(original_tensor.narrow(0, 3, 1))
    else:
        raise ValueError(f'rank {rank} is not in the device mesh')


def test_dtensor():
    world_size = 4
    run_func = partial(check_dtensor, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_dtensor()
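For readers tracing the asserts above, here is a hypothetical helper (not part of the commit) that reproduces the expected row slices from the mesh shape and the sharded mesh axes, assuming the row-major 2x2 mesh [[0, 1], [2, 3]] used in the test; the actual placement in the library is computed by ShapeConsistencyManager, and this only encodes what the asserts check:

def expected_row_slice(rank, mesh_shape=(2, 2), shard_axes=(0,), dim0_size=4):
    """Return (start, length) for tensor.narrow(0, start, length) on this rank."""
    # coordinates of this rank in a row-major 2x2 mesh [[0, 1], [2, 3]]
    coords = (rank // mesh_shape[1], rank % mesh_shape[1])
    # combine the coordinates along the sharded mesh axes into a shard index
    shard_idx, num_shards = 0, 1
    for axis in shard_axes:
        shard_idx = shard_idx * mesh_shape[axis] + coords[axis]
        num_shards *= mesh_shape[axis]
    length = dim0_size // num_shards
    return shard_idx * length, length

# dim_partition_dict={0: [0]}   : ranks 0, 1 -> rows 0-1; ranks 2, 3 -> rows 2-3
# dim_partition_dict={0: [0, 1]}: rank r -> row r

Since test_dtensor() spawns world_size=4 processes with the nccl backend, running the file directly (python tests/test_tensor/test_dtensor.py) requires a node with 4 CUDA devices.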