OpenDAS / ColossalAI · Commits

Unverified commit 2eca4cd3, authored Mar 14, 2023 by YuliangLiu0306; committed by GitHub, Mar 14, 2023.
Parent: ed8f60b9

[DTensor] refactor dtensor with new components (#3089)

* [DTensor] refactor dtensor with new components
* polish
Changes: 3 changed files, with 20 additions and 41 deletions (+20 −41)

* colossalai/tensor/d_tensor/d_tensor.py (+14 −30)
* colossalai/tensor/d_tensor/layout_converter.py (+3 −3)
* tests/test_tensor/test_dtensor/test_dtensor.py (+3 −8)
colossalai/tensor/d_tensor/d_tensor.py (+14 −30)
```diff
@@ -3,12 +3,11 @@ from typing import Optional
 import torch
 from torch.utils._pytree import tree_map
 
-from colossalai.device.device_mesh import DeviceMesh
-from colossalai.tensor.d_tensor.layout import Layout
-from colossalai.tensor.shape_consistency import ShapeConsistencyManager, to_global
-from colossalai.tensor.sharding_spec import ShardingSpec
+from .layout import Layout
+from .layout_converter import LayoutConverter, to_global
+from .sharding_spec import ShardingSpec
 
-shape_consistency_manager = ShapeConsistencyManager()
+layout_converter = LayoutConverter()
 
 
 class DTensor(torch.Tensor):
@@ -17,8 +16,6 @@ class DTensor(torch.Tensor):
         self.local_tensor = local_tensor
         self.data_type = local_tensor.dtype
         self.entire_shape = local_tensor.shape
-        if dist_layout.entire_shape is None:
-            dist_layout.entire_shape = self.entire_shape
         self.dist_layout = dist_layout
         self._apply_layout()
@@ -36,20 +33,19 @@ class DTensor(torch.Tensor):
         '''
         Convert the layout of the tensor from source_spec to target_spec.
         '''
-        source_spec = convert_layout_to_sharding_spec(self.dist_layout)
-        target_spec = convert_layout_to_sharding_spec(target_layout)
-        self.local_tensor = shape_consistency_manager.apply_for_autoparallel_runtime(
-            self.local_tensor, source_spec, target_spec)
+        self.local_tensor = layout_converter.apply(self.local_tensor, self.dist_layout, target_layout)
         self.dist_layout = target_layout
 
     def _apply_layout(self):
         '''
         Apply the layout to the local tensor during initializing process.
         '''
-        source_spec = construct_default_sharding_spec(self.local_tensor, self.device_mesh)
-        target_spec = convert_layout_to_sharding_spec(self.dist_layout)
-        self.local_tensor = shape_consistency_manager.apply_for_autoparallel_runtime(
-            self.local_tensor, source_spec, target_spec)
+        source_spec = construct_default_sharding_spec(self.local_tensor)
+        source_layout = Layout(device_mesh=self.dist_layout.device_mesh,
+                               device_type=self.dist_layout.device_type,
+                               sharding_spec=source_spec,
+                               entire_shape=self.entire_shape)
+        self.local_tensor = layout_converter.apply(self.local_tensor, source_layout, self.dist_layout)
 
     @classmethod
     def __torch_function__(cls, func, types, args=(), kwargs=None):
@@ -108,7 +104,7 @@ class DTensor(torch.Tensor):
         will not change the layout of the DTensor. This function is mainly used for debugging or
         check the correctness of the distributed tensor.
         '''
-        return to_global(self.local_tensor, convert_layout_to_sharding_spec(self.dist_layout))
+        return to_global(self.local_tensor, self.dist_layout)
 
 
 def distribute_tensor(local_tensor: torch.Tensor, dist_layout: Layout) -> DTensor:
@@ -139,20 +135,8 @@ def distribute_module(module: torch.nn.Module, partition_fn: Optional[callable]
     return module
 
 
-def convert_layout_to_sharding_spec(layout: Layout) -> ShardingSpec:
-    '''
-    Convert the layout from Layout class to ShardingSpec class.
-    '''
-    return ShardingSpec(device_mesh=layout.device_mesh,
-                        entire_shape=layout.entire_shape,
-                        dim_partition_dict=layout.sharding_spec.dim_partition_dict)
-
-
-def construct_default_sharding_spec(
-    tensor: torch.Tensor,
-    device_mesh: DeviceMesh,
-) -> ShardingSpec:
+def construct_default_sharding_spec(tensor: torch.Tensor,) -> ShardingSpec:
     '''
     Construct the default sharding specification for the tensor.
     '''
-    return ShardingSpec(device_mesh=device_mesh, entire_shape=tensor.shape, dim_partition_dict={})
+    return ShardingSpec(dim_size=tensor.dim(), dim_partition_dict={})
```
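For readers skimming the diff, here is a minimal sketch of the refactored call shape: construction now builds a replicated source `Layout` and routes all resharding through a single `LayoutConverter.apply`, instead of converting layouts to `ShardingSpec`s and calling `ShapeConsistencyManager`. The classes below are stand-ins written for this note, not ColossalAI's real implementations; only the call pattern mirrors the diff above.

```python
# Stand-in sketch (assumption-labeled): mirrors the new _apply_layout flow.
# The real Layout/LayoutConverter live in colossalai.tensor.d_tensor and
# perform actual resharding; these mocks only reproduce the call shape.
from dataclasses import dataclass, field
from typing import Dict, List, Tuple

import torch


@dataclass
class ShardingSpec:
    # new-style spec: tensor rank plus partition mapping, with no
    # device_mesh / entire_shape arguments
    dim_size: int
    dim_partition_dict: Dict[int, List[int]] = field(default_factory=dict)


@dataclass
class Layout:
    device_mesh: Tuple[int, ...]
    device_type: torch.device
    sharding_spec: ShardingSpec
    entire_shape: Tuple[int, ...]


class LayoutConverter:
    # stand-in: the real converter reshards local_tensor between layouts
    def apply(self, local_tensor, source_layout: Layout, target_layout: Layout):
        return local_tensor  # no-op placeholder


def construct_default_sharding_spec(tensor: torch.Tensor) -> ShardingSpec:
    # an empty dim_partition_dict shards nothing, i.e. fully replicated
    return ShardingSpec(dim_size=tensor.dim(), dim_partition_dict={})


layout_converter = LayoutConverter()
t = torch.randn(4, 8)
dist_layout = Layout(device_mesh=(2, 2), device_type=torch.device('cpu'),
                     sharding_spec=ShardingSpec(t.dim(), {0: [0]}),
                     entire_shape=tuple(t.shape))
# the new _apply_layout: default (replicated) source layout -> dist layout
source_layout = Layout(device_mesh=dist_layout.device_mesh,
                       device_type=dist_layout.device_type,
                       sharding_spec=construct_default_sharding_spec(t),
                       entire_shape=tuple(t.shape))
local_tensor = layout_converter.apply(t, source_layout, dist_layout)
```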
colossalai/tensor/d_tensor/layout_converter.py (+3 −3)
```diff
@@ -22,21 +22,21 @@ __all__ = ['LayoutConverter', 'LayoutConverterOptions', 'set_layout_converting_o
 @dataclass
 class LayoutConverterOptions:
     """
-    LayoutConverterOptions is a dataclass which specifies the preferences for shape consistency.
+    LayoutConverterOptions is a dataclass which specifies the preferences for layout converting.
     """
     # TODO: layout converter option is not implemented yet
     pass
 
 
 def to_global(distributed_tensor: torch.Tensor, layout: Layout) -> torch.Tensor:
-    shape_consistency_manager = LayoutConverter()
+    layout_converter = LayoutConverter()
     global_sharding_spec = ShardingSpec(distributed_tensor.dim(), {})
     global_layout = Layout(device_mesh=layout.device_mesh,
                            device_type=layout.device_type,
                            sharding_spec=global_sharding_spec,
                            entire_shape=layout.entire_shape)
     with torch.no_grad():
-        global_tensor = shape_consistency_manager.apply(distributed_tensor, layout, global_layout)
+        global_tensor = layout_converter.apply(distributed_tensor, layout, global_layout)
     return global_tensor
```
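One detail worth noting in this hunk: the target layout that `to_global` builds uses `ShardingSpec(distributed_tensor.dim(), {})`, and an empty `dim_partition_dict` shards nothing, so the target layout is fully replicated (every rank ends up holding the whole tensor). A tiny hedged illustration of that convention, not library code:

```python
# Assumed sharding-spec convention: dim_partition_dict maps tensor dims to
# the device-mesh axes they are split over; an empty dict replicates.
from typing import Dict, List


def is_fully_replicated(dim_partition_dict: Dict[int, List[int]]) -> bool:
    # nothing is sharded, so every rank holds the full tensor
    return len(dim_partition_dict) == 0


assert is_fully_replicated({})             # the spec to_global builds
assert not is_fully_replicated({0: [0]})   # dim 0 split over mesh axis 0
```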
tests/test_tensor/test_dtensor/test_dtensor.py (+3 −8)
```diff
@@ -4,12 +4,11 @@ import torch
 import torch.multiprocessing as mp
 
 from colossalai.device.device_mesh import DeviceMesh
-from colossalai.fx.tracer import ColoTracer
 from colossalai.initialize import launch
 from colossalai.logging import disable_existing_loggers
 from colossalai.tensor.d_tensor.d_tensor import DTensor, distribute_tensor
 from colossalai.tensor.d_tensor.layout import Layout
-from colossalai.tensor.sharding_spec import ShardingSpec
+from colossalai.tensor.d_tensor.sharding_spec import ShardingSpec
 from colossalai.utils import free_port
@@ -34,9 +33,7 @@ def check_dtensor(rank, world_size, port):
     compare_output = test_model(original_tensor)
 
     device_mesh = DeviceMesh(torch.Tensor([0, 1, 2, 3]), (2, 2), init_process_group=True)
-    target_sharding_spec = ShardingSpec(device_mesh=device_mesh,
-                                        entire_shape=original_tensor.shape,
-                                        dim_partition_dict={0: [0]})
+    target_sharding_spec = ShardingSpec(dim_size=original_tensor.dim(), dim_partition_dict={0: [0]})
     layout = Layout(device_mesh=device_mesh,
                     device_type=torch.device('cuda'),
                     sharding_spec=target_sharding_spec,
@@ -62,9 +59,7 @@ def check_dtensor(rank, world_size, port):
     else:
         raise ValueError(f'rank {rank} is not in the device mesh')
 
-    new_sharding_spec = ShardingSpec(device_mesh=device_mesh,
-                                     entire_shape=original_tensor.shape,
-                                     dim_partition_dict={0: [0, 1]})
+    new_sharding_spec = ShardingSpec(dim_size=original_tensor.dim(), dim_partition_dict={0: [0, 1]})
     new_layout = Layout(device_mesh=device_mesh,
                         device_type=torch.device('cuda'),
                         sharding_spec=new_sharding_spec,
```
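To make the test's specs concrete: on the 2×2 device mesh above, `{0: [0]}` splits tensor dim 0 across mesh axis 0 (each rank holds half the rows), while `{0: [0, 1]}` splits dim 0 across both mesh axes (a quarter of the rows per rank). A small hedged sketch of that arithmetic, under the assumed convention that each listed mesh axis divides the dimension by that axis's size:

```python
# Hedged arithmetic sketch, not library code: compute the local shard shape
# implied by a dim_partition_dict on a (2, 2) mesh, as used in the test above.
from typing import Dict, List, Tuple

MESH_SHAPE = (2, 2)


def local_shard_shape(global_shape: Tuple[int, ...],
                      dim_partition_dict: Dict[int, List[int]]) -> Tuple[int, ...]:
    shape = list(global_shape)
    for tensor_dim, mesh_axes in dim_partition_dict.items():
        for axis in mesh_axes:
            shape[tensor_dim] //= MESH_SHAPE[axis]  # split over that mesh axis
    return tuple(shape)


print(local_shard_shape((4, 8), {0: [0]}))     # (2, 8)
print(local_shard_shape((4, 8), {0: [0, 1]}))  # (1, 8)
```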