OpenDAS / ColossalAI · Commits

Unverified commit 2eca4cd3, authored Mar 14, 2023 by YuliangLiu0306, committed by GitHub on Mar 14, 2023
Parent: ed8f60b9

[DTensor] refactor dtensor with new components (#3089)

* [DTensor] refactor dtensor with new components
* polish

Showing 3 changed files with 20 additions and 41 deletions:
colossalai/tensor/d_tensor/d_tensor.py            +14  -30
colossalai/tensor/d_tensor/layout_converter.py     +3   -3
tests/test_tensor/test_dtensor/test_dtensor.py     +3   -8
colossalai/tensor/d_tensor/d_tensor.py

@@ -3,12 +3,11 @@ from typing import Optional
 import torch
 from torch.utils._pytree import tree_map
 
-from colossalai.device.device_mesh import DeviceMesh
-from colossalai.tensor.d_tensor.layout import Layout
-from colossalai.tensor.shape_consistency import ShapeConsistencyManager, to_global
-from colossalai.tensor.sharding_spec import ShardingSpec
+from .layout import Layout
+from .layout_converter import LayoutConverter, to_global
+from .sharding_spec import ShardingSpec
 
-shape_consistency_manager = ShapeConsistencyManager()
+layout_converter = LayoutConverter()
 
 class DTensor(torch.Tensor):
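For orientation, a minimal sketch of how downstream code would import the refactored components: the absolute paths below mirror the relative imports in this hunk and the imports used in the updated test file, and the module-level converter replaces the former ShapeConsistencyManager singleton.

```python
# Sketch: importing the DTensor stack after the refactor (paths mirror the
# relative imports in d_tensor.py and the updated test file).
from colossalai.tensor.d_tensor.d_tensor import DTensor, distribute_tensor
from colossalai.tensor.d_tensor.layout import Layout
from colossalai.tensor.d_tensor.layout_converter import LayoutConverter, to_global
from colossalai.tensor.d_tensor.sharding_spec import ShardingSpec

# A single process-wide converter now drives all layout transformations,
# replacing the ShapeConsistencyManager singleton.
layout_converter = LayoutConverter()
```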
@@ -17,8 +16,6 @@ class DTensor(torch.Tensor):
         self.local_tensor = local_tensor
         self.data_type = local_tensor.dtype
         self.entire_shape = local_tensor.shape
-        if dist_layout.entire_shape is None:
-            dist_layout.entire_shape = self.entire_shape
         self.dist_layout = dist_layout
         self._apply_layout()
@@ -36,20 +33,19 @@ class DTensor(torch.Tensor):
         '''
         Convert the layout of the tensor from source_spec to target_spec.
         '''
-        source_spec = convert_layout_to_sharding_spec(self.dist_layout)
-        target_spec = convert_layout_to_sharding_spec(target_layout)
-        self.local_tensor = shape_consistency_manager.apply_for_autoparallel_runtime(self.local_tensor, source_spec, target_spec)
+        self.local_tensor = layout_converter.apply(self.local_tensor, self.dist_layout, target_layout)
         self.dist_layout = target_layout
 
     def _apply_layout(self):
         '''
         Apply the layout to the local tensor during initializing process.
         '''
-        source_spec = construct_default_sharding_spec(self.local_tensor, self.device_mesh)
-        target_spec = convert_layout_to_sharding_spec(self.dist_layout)
-        self.local_tensor = shape_consistency_manager.apply_for_autoparallel_runtime(self.local_tensor, source_spec, target_spec)
+        source_spec = construct_default_sharding_spec(self.local_tensor)
+        source_layout = Layout(device_mesh=self.dist_layout.device_mesh,
+                               device_type=self.dist_layout.device_type,
+                               sharding_spec=source_spec,
+                               entire_shape=self.entire_shape)
+        self.local_tensor = layout_converter.apply(self.local_tensor, source_layout, self.dist_layout)
 
     @classmethod
     def __torch_function__(cls, func, types, args=(), kwargs=None):
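The core simplification in the hunk above: instead of lowering both layouts to ShardingSpec objects and routing through ShapeConsistencyManager, resharding is now a single LayoutConverter.apply(tensor, source_layout, target_layout) call. A minimal standalone sketch of the new initialization path, using only names that appear in this diff (the entire_shape keyword and the default spec follow the added lines); this is an illustration, not the class method itself.

```python
import torch
from colossalai.tensor.d_tensor.layout import Layout
from colossalai.tensor.d_tensor.layout_converter import LayoutConverter
from colossalai.tensor.d_tensor.sharding_spec import ShardingSpec

layout_converter = LayoutConverter()

def shard_from_full(full_tensor: torch.Tensor, dist_layout: Layout) -> torch.Tensor:
    # 1. Describe the unsharded tensor with the default (empty) sharding spec.
    source_spec = ShardingSpec(dim_size=full_tensor.dim(), dim_partition_dict={})
    # 2. Wrap it in a Layout on the same device mesh as the target layout.
    source_layout = Layout(device_mesh=dist_layout.device_mesh,
                           device_type=dist_layout.device_type,
                           sharding_spec=source_spec,
                           entire_shape=full_tensor.shape)
    # 3. One converter call produces the local shard that matches dist_layout.
    return layout_converter.apply(full_tensor, source_layout, dist_layout)
```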
@@ -108,7 +104,7 @@ class DTensor(torch.Tensor):
         will not change the layout of the DTensor. This function is mainly used for debugging or
         check the correctness of the distributed tensor.
         '''
-        return to_global(self.local_tensor, convert_layout_to_sharding_spec(self.dist_layout))
+        return to_global(self.local_tensor, self.dist_layout)
 
 
 def distribute_tensor(local_tensor: torch.Tensor, dist_layout: Layout) -> DTensor:
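After this change the debugging helper passes the Layout straight to to_global instead of first deriving a ShardingSpec from it. A hedged sketch of how that pattern can be used for correctness checks; attribute names are taken from this diff, and the DTensor instance is assumed to come from distribute_tensor.

```python
import torch
from colossalai.tensor.d_tensor.d_tensor import DTensor
from colossalai.tensor.d_tensor.layout_converter import to_global

def gather_for_debug(d_tensor: DTensor) -> torch.Tensor:
    # Rebuild the full tensor from the local shard; to_global now consumes the
    # Layout directly rather than a ShardingSpec derived from it.
    full = to_global(d_tensor.local_tensor, d_tensor.dist_layout)
    assert full.shape == d_tensor.entire_shape  # every rank sees the global shape
    return full
```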
@@ -139,20 +135,8 @@ def distribute_module(module: torch.nn.Module, partition_fn: Optional[callable]
     return module
 
 
-def convert_layout_to_sharding_spec(layout: Layout) -> ShardingSpec:
-    '''
-    Convert the layout from Layout class to ShardingSpec class.
-    '''
-    return ShardingSpec(device_mesh=layout.device_mesh,
-                        entire_shape=layout.entire_shape,
-                        dim_partition_dict=layout.sharding_spec.dim_partition_dict)
-
-
-def construct_default_sharding_spec(tensor: torch.Tensor, device_mesh: DeviceMesh,) -> ShardingSpec:
+def construct_default_sharding_spec(tensor: torch.Tensor,) -> ShardingSpec:
     '''
     Construct the default sharding specification for the tensor.
     '''
-    return ShardingSpec(device_mesh=device_mesh, entire_shape=tensor.shape, dim_partition_dict={})
+    return ShardingSpec(dim_size=tensor.dim(), dim_partition_dict={})
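The default-spec helper now only needs the tensor itself, because the new ShardingSpec is described by the tensor's rank rather than by a device mesh and a global shape. A small sketch contrasting the two call styles; the old form appears only as a comment, since it is the removed code above.

```python
import torch
from colossalai.tensor.d_tensor.sharding_spec import ShardingSpec

x = torch.empty(8, 16)  # arbitrary example tensor

# New style: only the tensor's rank and the partition mapping are needed.
# An empty dim_partition_dict means no dimension is sharded (fully replicated).
default_spec = ShardingSpec(dim_size=x.dim(), dim_partition_dict={})

# Old style (removed above) also carried the mesh and the global shape:
#   ShardingSpec(device_mesh=device_mesh, entire_shape=x.shape, dim_partition_dict={})
```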
colossalai/tensor/d_tensor/layout_converter.py

@@ -22,21 +22,21 @@ __all__ = ['LayoutConverter', 'LayoutConverterOptions', 'set_layout_converting_o
 @dataclass
 class LayoutConverterOptions:
     """
-    LayoutConverterOptions is a dataclass which specifies the preferences for shape consistency.
+    LayoutConverterOptions is a dataclass which specifies the preferences for layout converting.
     """
     # TODO: layout converter option is not implemented yet
     pass
 
 
 def to_global(distributed_tensor: torch.Tensor, layout: Layout) -> torch.Tensor:
-    shape_consistency_manager = LayoutConverter()
+    layout_converter = LayoutConverter()
     global_sharding_spec = ShardingSpec(distributed_tensor.dim(), {})
     global_layout = Layout(device_mesh=layout.device_mesh,
                            device_type=layout.device_type,
                            sharding_spec=global_sharding_spec,
                            entire_shape=layout.entire_shape)
     with torch.no_grad():
-        global_tensor = shape_consistency_manager.apply(distributed_tensor, layout, global_layout)
+        global_tensor = layout_converter.apply(distributed_tensor, layout, global_layout)
     return global_tensor
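to_global builds a fully replicated target layout (an empty dim_partition_dict) on the same mesh and lets LayoutConverter.apply move the shard back to the global view. A hedged usage sketch, with the mesh and launch plumbing assumed to exist outside the snippet:

```python
import torch
from colossalai.tensor.d_tensor.layout import Layout
from colossalai.tensor.d_tensor.layout_converter import to_global

def check_gather(local_shard: torch.Tensor, layout: Layout) -> None:
    # `layout` is assumed to describe how `local_shard` is placed on the mesh.
    full = to_global(local_shard, layout)
    # Every rank should end up holding the same, fully assembled tensor.
    assert full.shape == layout.entire_shape
```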
tests/test_tensor/test_dtensor/test_dtensor.py

@@ -4,12 +4,11 @@ import torch
 import torch.multiprocessing as mp
 
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.fx.tracer import ColoTracer
 from colossalai.initialize import launch
 from colossalai.logging import disable_existing_loggers
 from colossalai.tensor.d_tensor.d_tensor import DTensor, distribute_tensor
 from colossalai.tensor.d_tensor.layout import Layout
-from colossalai.tensor.sharding_spec import ShardingSpec
+from colossalai.tensor.d_tensor.sharding_spec import ShardingSpec
 from colossalai.utils import free_port
@@ -34,9 +33,7 @@ def check_dtensor(rank, world_size, port):
     compare_output = test_model(original_tensor)
 
     device_mesh = DeviceMesh(torch.Tensor([0, 1, 2, 3]), (2, 2), init_process_group=True)
-    target_sharding_spec = ShardingSpec(device_mesh=device_mesh,
-                                        entire_shape=original_tensor.shape,
-                                        dim_partition_dict={0: [0]})
+    target_sharding_spec = ShardingSpec(dim_size=original_tensor.dim(), dim_partition_dict={0: [0]})
     layout = Layout(device_mesh=device_mesh,
                     device_type=torch.device('cuda'),
                     sharding_spec=target_sharding_spec,
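Putting the test's updated construction path together, the setup now reads roughly as below. This is a sketch for a 4-rank job: the example tensor shape is illustrative, and the entire_shape keyword on Layout is an assumption, since the Layout call is truncated at the hunk boundary above (it matches the Layout signature used in d_tensor.py).

```python
import torch
from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.d_tensor.d_tensor import distribute_tensor
from colossalai.tensor.d_tensor.layout import Layout
from colossalai.tensor.d_tensor.sharding_spec import ShardingSpec

original_tensor = torch.rand(4, 8).cuda()           # illustrative shape
# Four ranks arranged as a 2x2 mesh, as in the test above.
device_mesh = DeviceMesh(torch.Tensor([0, 1, 2, 3]), (2, 2), init_process_group=True)
# New-style spec: {0: [0]} shards tensor dim 0 across mesh axis 0.
target_sharding_spec = ShardingSpec(dim_size=original_tensor.dim(),
                                    dim_partition_dict={0: [0]})
layout = Layout(device_mesh=device_mesh,
                device_type=torch.device('cuda'),
                sharding_spec=target_sharding_spec,
                entire_shape=original_tensor.shape)  # assumption: truncated in the hunk
d_tensor = distribute_tensor(original_tensor, layout)
```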
@@ -62,9 +59,7 @@ def check_dtensor(rank, world_size, port):
     else:
         raise ValueError(f'rank {rank} is not in the device mesh')
 
-    new_sharding_spec = ShardingSpec(device_mesh=device_mesh,
-                                     entire_shape=original_tensor.shape,
-                                     dim_partition_dict={0: [0, 1]})
+    new_sharding_spec = ShardingSpec(dim_size=original_tensor.dim(), dim_partition_dict={0: [0, 1]})
     new_layout = Layout(device_mesh=device_mesh,
                         device_type=torch.device('cuda'),
                         sharding_spec=new_sharding_spec,
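The second spec shards tensor dim 0 across both mesh axes at once ({0: [0, 1]}), so on the 2×2 mesh that dimension is split into four pieces and each rank keeps one quarter of it. A small arithmetic sketch with illustrative numbers:

```python
# Illustrative only: a (4, 8) global tensor on a 2x2 mesh.
entire_shape = (4, 8)
mesh_shape = (2, 2)
# dim_partition_dict={0: [0, 1]}: dim 0 is sharded over mesh axes 0 and 1.
shard_factor = mesh_shape[0] * mesh_shape[1]     # 4 shards along dim 0
local_shape = (entire_shape[0] // shard_factor, entire_shape[1])
print(local_shape)                               # (1, 8) local shard per rank
```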