OpenDAS / ColossalAI · Commits · cd2b0eaa

Unverified commit cd2b0eaa, authored Mar 07, 2023 by YuliangLiu0306, committed by GitHub on Mar 07, 2023

[DTensor] refactor sharding spec (#2987)

* [autoparallel] refactor sharding spec
* rename function name
Parent: 400f6301
Showing 6 changed files with 341 additions and 7 deletions (+341, −7)
colossalai/tensor/d_tensor/__init__.py                  +0    −0
colossalai/tensor/d_tensor/layout.py                    +52   −6
colossalai/tensor/d_tensor/misc.py                      +14   −0
colossalai/tensor/d_tensor/sharding_spec.py             +237  −0
tests/test_tensor/test_dtensor/test_dtensor.py          +4    −1
tests/test_tensor/test_dtensor/test_sharding_spec.py    +34   −0
colossalai/tensor/d_tensor/__init__.py (new file, 0 → 100644; empty)
colossalai/tensor/d_tensor/layout.py (at cd2b0eaa)

+import operator
 from dataclasses import dataclass
+from functools import reduce

 import torch

 from colossalai.device.device_mesh import DeviceMesh
-from colossalai.tensor.sharding_spec import ShardingSpec
+
+from .misc import DuplicatedShardingDimensionError, LayoutException, ShardingNotDivisibleError
+from .sharding_spec import ShardingSpec


-@dataclass
 class Layout:
     """Layout of a tensor.
 ...
@@ -16,7 +19,50 @@ class Layout:
         sharding_spec: the sharding specification to describe how the tensor is sharded.
         entire_shape: the entire shape of the global tensor.
     """
-    device_mesh: DeviceMesh
-    device_type: torch.device
-    sharding_spec: ShardingSpec
-    entire_shape: torch.Size = None
+
+    def __init__(self, device_mesh: DeviceMesh, device_type: torch.device, sharding_spec: ShardingSpec,
+                 entire_shape: torch.Size):
+        self.device_mesh = device_mesh
+        self.device_type = device_type
+        self.sharding_spec = sharding_spec
+        self.entire_shape = entire_shape
+        self._sanity_check()
+
+    def __hash__(self) -> int:
+        return hash(f'{self.sharding_spec}')
+
+    def get_sharded_shape_per_device(self):
+        sharded_shape = list(self.entire_shape)
+        for dim, shard_list in self.sharding_spec.dim_partition_dict.items():
+            mesh_list = [self.device_mesh.mesh_shape[mesh_dim] for mesh_dim in shard_list]
+            shard_partitions = reduce(operator.mul, mesh_list, 1)
+            assert sharded_shape[dim] % shard_partitions == 0, \
+                f'Cannot shard dimension {dim} into {shard_partitions} partitions.'
+            sharded_shape[dim] //= shard_partitions
+        return torch.Size(sharded_shape)
+
+    def _sanity_check(self):
+        sharding_spec = self.sharding_spec
+        # make sure each axis in the logical device mesh is only used once
+        dim_check_list = list(range(self.device_mesh.logical_mesh_id.dim()))
+        for dim, shard_list in sharding_spec.dim_partition_dict.items():
+            for element in shard_list:
+                if element in dim_check_list:
+                    dim_check_list.remove(element)
+                else:
+                    raise DuplicatedShardingDimensionError(
+                        f"find an invalid sharding axis {element} in dim_partition_dict in tensor dimension {dim}.")
+
+        # make sure that the sharding for a dimension is divisible by the number of devices
+        for dim, shard_list in sharding_spec.dim_partition_dict.items():
+            tensor_dim_size = self.entire_shape[dim]
+            num_devices = 1
+
+            for element in shard_list:
+                num_devices *= self.device_mesh.mesh_shape[element]
+
+            if tensor_dim_size % num_devices != 0:
+                raise ShardingNotDivisibleError(
+                    f'The size of dimension at index {dim} is {tensor_dim_size}, it cannot be sharded over {num_devices} devices.')
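Editor's note, not part of the commit: the sketch below replays the arithmetic of Layout.get_sharded_shape_per_device in plain Python, assuming a hypothetical 2 x 2 logical device mesh and a global tensor of shape (8, 12). Only the standard library is used, so it runs standalone.

# Shard tensor dim 0 over both mesh axes ('S01'): 8 rows split over 2 * 2 = 4 devices.
import operator
from functools import reduce

mesh_shape = (2, 2)                 # assumed logical device mesh shape
entire_shape = [8, 12]              # assumed global tensor shape
dim_partition_dict = {0: [0, 1]}    # tensor dim 0 sharded along mesh axes 0 and 1

sharded_shape = list(entire_shape)
for dim, shard_list in dim_partition_dict.items():
    shard_partitions = reduce(operator.mul, (mesh_shape[m] for m in shard_list), 1)
    assert sharded_shape[dim] % shard_partitions == 0    # the ShardingNotDivisibleError case otherwise
    sharded_shape[dim] //= shard_partitions

print(sharded_shape)    # [2, 12]: each device holds a (2, 12) shard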
colossalai/tensor/d_tensor/misc.py (new file, 0 → 100644)

class LayoutException(Exception):
    pass


class DuplicatedShardingDimensionError(LayoutException):
    pass


class ShardingNotDivisibleError(LayoutException):
    pass


class ShardingOutOfIndexError(LayoutException):
    pass
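Editor's note, not part of the commit: all of the specialised errors above derive from LayoutException, so callers can catch the base class to handle any layout or sharding validation failure. A minimal check, assuming the package at this commit is importable:

from colossalai.tensor.d_tensor.misc import (DuplicatedShardingDimensionError, LayoutException,
                                             ShardingNotDivisibleError, ShardingOutOfIndexError)

# Each specialised error is caught by the common LayoutException base class.
for exc_cls in (DuplicatedShardingDimensionError, ShardingNotDivisibleError, ShardingOutOfIndexError):
    try:
        raise exc_cls('example failure')
    except LayoutException as err:
        print(f'caught {type(err).__name__}: {err}')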
colossalai/tensor/d_tensor/sharding_spec.py (new file, 0 → 100644)

from copy import deepcopy
from typing import Dict, List

from ..utils import merge_same_dim_mesh_list
from .misc import ShardingOutOfIndexError

__all__ = ['DimSpec', 'ShardingException', 'ShardingSpec']

ALLGATHER_COST = 20
SHARD_COST = 5
STEP_PENALTY = 6
NAN = 'nan'


class DimSpec:
    '''
    Sharding spec for a single dimension of the sharded tensor. It describes the sharding dimension of the
    logical device mesh and provides a method to compute the difference between two DimSpec objects.
    This class is used internally in ShardingSpec.

    Argument:
        shard_list(List[int]): if shard_list is None, the dim spec will be 'R' type.
            Otherwise, the element in shard_list means the data will be sharded in that dimension.
    '''

    def __init__(self, shard_list):
        self.is_replica = len(shard_list) == 0
        self.shard_list = shard_list
        self.build_difference_2d_dict()

    def __eq__(self, other):
        return str(self) == str(other)

    def __repr__(self):
        if self.is_replica:
            return 'R'
        target = 'S'
        for dim in self.shard_list:
            target += str(dim)
        return target

    def _convert_str_to_shard_list(self, str_spec):
        '''
        Convert str_spec into shard_list.

        Argument:
            str_spec(str): dim spec in str type.
        '''
        if str_spec == 'R':
            return []
        if str_spec == 'S0':
            return [0]
        if str_spec == 'S1':
            return [1]
        if str_spec == 'S01':
            return [0, 1]

    def build_difference_2d_dict(self):
        '''
        Build a difference mapping for the 2D device mesh case. It will be used to
        compute the difference between DimSpec pairs.
        '''
        source_spec_list = ['R', 'S0', 'S1', 'S01']
        target_spec_list = ['R', 'S0', 'S1', 'S01']
        difference_dict = {}
        for source_spec in source_spec_list:
            for target_spec in target_spec_list:
                legal_sharding_dims = []
                spec_pair = (deepcopy(source_spec), deepcopy(target_spec))
                source_shard_list = self._convert_str_to_shard_list(source_spec)
                target_shard_list = self._convert_str_to_shard_list(target_spec)

                # source same as target
                if source_shard_list == target_shard_list:
                    difference = 0

                # all_gather(source) -> target
                elif len(source_shard_list) == len(target_shard_list) + 1 and source_shard_list[:-1] == target_shard_list:
                    difference = ALLGATHER_COST

                # shard(source) -> target
                elif len(source_shard_list) == len(target_shard_list) - 1 and source_shard_list == target_shard_list[:-1] and target_shard_list[-1] not in source_shard_list:
                    difference = SHARD_COST

                # S1 -> S0 or S0 -> S1
                elif len(source_shard_list) == len(target_shard_list):
                    # source -> R -> target
                    difference = ALLGATHER_COST + STEP_PENALTY + SHARD_COST

                # R -> S01
                elif len(source_shard_list) == len(target_shard_list) - 2:
                    difference = SHARD_COST + STEP_PENALTY + SHARD_COST

                # S01 -> R
                elif len(source_shard_list) == len(target_shard_list) + 2:
                    difference = ALLGATHER_COST + STEP_PENALTY + ALLGATHER_COST

                # S1 -> S01
                elif len(source_shard_list) == len(target_shard_list) - 1:
                    difference = ALLGATHER_COST + STEP_PENALTY + SHARD_COST + STEP_PENALTY + SHARD_COST

                # S01 -> S1
                elif len(source_shard_list) == len(target_shard_list) + 1:
                    difference = ALLGATHER_COST + STEP_PENALTY + ALLGATHER_COST + STEP_PENALTY + SHARD_COST

                else:
                    difference = NAN
                difference_dict[spec_pair] = difference

        self.difference_dict = difference_dict

    def dim_diff(self, other):
        '''
        The difference between two _DimSpec.

        Argument:
            other(_DimSpec): the dim spec to compare with.

        Return:
            difference(int): the difference between two _DimSpec.

        Example:
            dim_spec = _DimSpec([0])
            other_dim_spec = _DimSpec([0, 1])
            print(dim_spec.difference(other_dim_spec))

        Output:
            5
        '''
        difference = self.difference_dict[(str(self), str(other))]
        return difference


class ShardingSpec:
    '''
    Sharding spec describes how to shard a tensor with dim_size dimensions. The sharding sequence looks like
    [R, R, S0, S1], which means the first two tensor dimensions are replicated, the third is sharded along
    logical mesh axis 0, and the fourth is sharded along logical mesh axis 1.

    Argument:
        dim_partition_dict(Dict[int, List[int]], optional): The key is the dimension of tensor to be sharded,
            and the value of the key describes which logical axis will be sharded in that dimension.
        sharding_sequence(List[DimSpec], optional): A straight view of ShardingSpec looks like [R, R, S0, S1].
    '''

    def __init__(self,
                 dim_size: int,
                 dim_partition_dict: Dict[int, List[int]] = None,
                 sharding_sequence: List[DimSpec] = None):
        self.dims = dim_size
        self.dim_partition_dict = dim_partition_dict
        self.sharding_sequence = sharding_sequence
        if self.sharding_sequence is None:
            assert self.dim_partition_dict is not None, \
                f'dim_partition_dict should not be None, if sharding_sequence is NoneType object.'
            self.dim_partition_dict = merge_same_dim_mesh_list(dim_size=self.dims,
                                                               dim_partition_dict=self.dim_partition_dict)
            self.sharding_sequence = self.convert_dict_to_shard_sequence()

        elif self.dim_partition_dict is None:
            assert self.sharding_sequence is not None, \
                f'sharding_sequence should not be None, if dim_partition_dict is NoneType object.'
            self.dim_partition_dict = self.convert_shard_sequence_to_dict()

        self._sanity_check()

    def _sanity_check(self):
        if len(self.sharding_sequence) > self.dims:
            raise ShardingOutOfIndexError(
                f'sharding_sequence should have {self.dims} elements, but got index {len(self.sharding_sequence)}.')

        if max(list(self.dim_partition_dict.keys())) >= self.dims:
            raise ShardingOutOfIndexError(
                f'the key of dim_partition_dict should be less than {self.dims}, but got {max(list(self.dim_partition_dict.keys()))}.')

    def __repr__(self):
        res_list = ["ShardingSpec:"]
        res_list.append(f"\n\tshard_sequence: " + ",".join(str(dimspec) for dimspec in self.sharding_sequence))
        return ' '.join(res_list)

    def convert_dict_to_shard_sequence(self):
        '''
        Convert dim_partition_dict into a list of DimSpec, and assign it to sharding_sequence.
        '''
        sharding_sequence = [DimSpec([])] * self.dims
        for dim, shard_list in self.dim_partition_dict.items():
            sharding_sequence[dim] = DimSpec(shard_list)
        return sharding_sequence

    def convert_shard_sequence_to_dict(self):
        '''
        Convert sharding_sequence into dim_partition_dict.
        '''
        new_dim_partition_dict = {}
        for index, dim_spec in enumerate(self.sharding_sequence):
            if not dim_spec.is_replica:
                if index not in new_dim_partition_dict:
                    new_dim_partition_dict[index] = []
                new_dim_partition_dict[index].extend(dim_spec.shard_list)
        return new_dim_partition_dict

    def spec_diff(self, other):
        '''
        This function is a naive version of difference computation. It simply accumulates the difference of every
        dimension between the pair of sharding sequences.

        Example:
            dim_partition_dict = {0: [0, 1]}
            # DistSpec:
            #     shard_sequence: S01,R,R
            #     device_mesh_shape: (4, 4)
            sharding_spec = ShardingSpec(device_mesh, entire_shape, dim_partition_dict)
            dim_partition_dict_to_compare = {0: [0], 1: [1]}
            # DistSpec:
            #     shard_sequence: S0,S1,R
            #     device_mesh_shape: (4, 4)
            sharding_spec_to_compare = ShardingSpec(device_mesh, entire_shape, dim_partition_dict_to_compare)
            print(sharding_spec.sharding_sequence_difference(sharding_spec_to_compare))

        Output:
            25

        Argument:
            other(ShardingSpec): The ShardingSpec to compare with.

        Return:
            difference(int): Difference between two ShardingSpec.
        '''
        assert len(self.sharding_sequence) == len(other.sharding_sequence), \
            f'Cannot compare difference for two sharding specs with different length.'
        difference = 0
        for orig_dim_spec, other_dim_spec in zip(self.sharding_sequence, other.sharding_sequence):
            difference += orig_dim_spec.dim_diff(other_dim_spec)
        return difference
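Editor's note, not part of the commit: with ALLGATHER_COST = 20, SHARD_COST = 5 and STEP_PENALTY = 6, a few entries of the 2D conversion-cost table built by DimSpec.build_difference_2d_dict work out as below. The snippet assumes the package at this commit is importable and only uses names defined in the file above.

from colossalai.tensor.d_tensor.sharding_spec import ALLGATHER_COST, SHARD_COST, STEP_PENALTY, DimSpec

replica = DimSpec([])    # 'R'
s0 = DimSpec([0])        # 'S0'
s01 = DimSpec([0, 1])    # 'S01'

assert replica.dim_diff(s0) == SHARD_COST                                         # R   -> S0  : 5
assert s0.dim_diff(replica) == ALLGATHER_COST                                      # S0  -> R   : 20
assert replica.dim_diff(s01) == SHARD_COST + STEP_PENALTY + SHARD_COST             # R   -> S01 : 16
assert s01.dim_diff(replica) == ALLGATHER_COST + STEP_PENALTY + ALLGATHER_COST     # S01 -> R   : 46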
tests/test_tensor/test_dtensor.py → tests/test_tensor/test_dtensor/test_dtensor.py (renamed, at cd2b0eaa)
...
@@ -37,7 +37,10 @@ def check_dtensor(rank, world_size, port):
     target_sharding_spec = ShardingSpec(device_mesh=device_mesh,
                                         entire_shape=original_tensor.shape,
                                         dim_partition_dict={0: [0]})
-    layout = Layout(device_mesh=device_mesh, device_type=torch.device('cuda'), sharding_spec=target_sharding_spec)
+    layout = Layout(device_mesh=device_mesh,
+                    device_type=torch.device('cuda'),
+                    sharding_spec=target_sharding_spec,
+                    entire_shape=original_tensor.shape)
     d_tensor = DTensor(original_tensor, layout)

     assert d_tensor.entire_shape == original_tensor.shape
...
tests/test_tensor/test_dtensor/test_sharding_spec.py (new file, 0 → 100644)
import operator
from functools import reduce

from colossalai.tensor.d_tensor.sharding_spec import ALLGATHER_COST, SHARD_COST, STEP_PENALTY, ShardingSpec


def test_sharding_spec():
    dims = 4
    dim_partition_dict_0 = {0: [0, 1]}
    # DistSpec:
    #     shard_sequence: S01,R,R,R
    sharding_spec_0 = ShardingSpec(dims, dim_partition_dict=dim_partition_dict_0)
    assert str(sharding_spec_0.sharding_sequence) == "[S01, R, R, R]"

    dim_partition_dict_1 = {1: [0, 1]}
    # DistSpec:
    #     shard_sequence: R,S01,R,R
    sharding_spec_1 = ShardingSpec(dims, dim_partition_dict=dim_partition_dict_1)
    assert str(sharding_spec_1.sharding_sequence) == "[R, S01, R, R]"

    dim_spec_list_0 = [dim_spec for dim_spec in sharding_spec_0.sharding_sequence]
    dim_spec_list_1 = [dim_spec for dim_spec in sharding_spec_1.sharding_sequence]

    assert dim_spec_list_0[0].dim_diff(dim_spec_list_1[0]) == ALLGATHER_COST + STEP_PENALTY + ALLGATHER_COST
    assert dim_spec_list_0[1].dim_diff(dim_spec_list_1[1]) == SHARD_COST + STEP_PENALTY + SHARD_COST
    assert dim_spec_list_0[2].dim_diff(dim_spec_list_1[2]) == 0
    assert dim_spec_list_0[3].dim_diff(dim_spec_list_1[3]) == 0
    assert sharding_spec_0.spec_diff(sharding_spec_1) == \
        reduce(operator.add, [dim_spec_list_0[i].dim_diff(dim_spec_list_1[i]) for i in range(dims)], 0)


if __name__ == '__main__':
    test_sharding_spec()