OpenDAS / ColossalAI · Commits

Unverified commit eee84908, authored Oct 19, 2022 by Frank Lee, committed by GitHub on Oct 19, 2022

[autoparallel] handled illegal sharding strategy (#1728)

* [autoparallel] handled illegal sharding strategy
* polish code

parent cbe9a4cb

Showing 16 changed files with 175 additions and 61 deletions (+175 / -61)
+26   -5   colossalai/auto_parallel/tensor_shard/node_handler/strategy/strategy_generator.py
 +2   -6   colossalai/auto_parallel/tensor_shard/node_handler/strategy/unary_elementwise_generator.py
 +2   -6   colossalai/auto_parallel/tensor_shard/node_handler/strategy/where_generator.py
 +2   -2   colossalai/auto_parallel/tensor_shard/utils/__init__.py
+12   -7   colossalai/auto_parallel/tensor_shard/utils/misc.py
+47  -12   colossalai/tensor/sharding_spec.py
 +0   -0   tests/test_auto_parallel/__init__.py
 +0   -0   tests/test_auto_parallel/test_tensor_shard/__init__.py
 +7   -3   tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_op_handler/test_deprecated_bcast_handler.py
 +5   -3   tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_op_handler/test_deprecated_bcast_matmul.py
 +0   -0   tests/test_auto_parallel/test_tensor_shard/test_node_handler/__init__.py
+37   -0   tests/test_auto_parallel/test_tensor_shard/test_node_handler/common.py
+22   -6   tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_linear_handler.py
 +0   -1   tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_unary_element_wise_handler.py
+10   -8   tests/test_tensor/test_sharded_linear.py
 +3   -2   tests/test_tensor/test_sharding_spec.py
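Taken together, the diffs below introduce one pattern: a candidate sharding strategy whose ShardingSpec is illegal is no longer a hard failure. Spec construction raises typed ShardingSpecException subclasses, the renamed ignore_sharding_exception decorator turns those exceptions into None, and StrategyGenerator.generate filters the None entries out of what collate_strategies produced. A minimal, self-contained sketch of that pattern (simplified stand-ins, not the actual ColossalAI classes or signatures):

# Sketch of the pattern this commit introduces; the names below are toy stand-ins.

class ShardingSpecException(Exception):
    """Raised when a candidate sharding spec is illegal."""


def ignore_sharding_exception(func):
    # instead of propagating the exception, turn an illegal candidate into None
    def wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except ShardingSpecException:
            return None
    return wrapper


class ToyStrategyGenerator:

    @ignore_sharding_exception
    def build_strategy(self, dim_size, num_devices):
        # stand-in legality rule: the dimension must be divisible by the device count
        if dim_size % num_devices != 0:
            raise ShardingSpecException(f'{dim_size} is not divisible by {num_devices}')
        return f'shard {dim_size} over {num_devices} devices'

    def collate_strategies(self):
        return [self.build_strategy(6, n) for n in (1, 2, 3, 4)]

    def generate(self):
        # drop the None entries produced for illegal candidates,
        # as the real StrategyGenerator.generate now does
        return [s for s in self.collate_strategies() if s]


print(ToyStrategyGenerator().generate())
# ['shard 6 over 1 devices', 'shard 6 over 2 devices', 'shard 6 over 3 devices']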
colossalai/auto_parallel/tensor_shard/node_handler/strategy/strategy_generator.py
View file @ eee84908

@@ -4,13 +4,12 @@ from functools import reduce
 from typing import Any, Dict, List, Union

 import torch
+from torch.fx import Node

 from colossalai.auto_parallel.tensor_shard.sharding_strategy import (OperationData, OperationDataType, ShardingStrategy,
                                                                      TrainCycleItem)
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.tensor.shape_consistency import CollectiveCommPattern, CommSpec
 from colossalai.tensor.sharding_spec import ShardingSpec
-from torch.fx import Node


 class StrategyGenerator(ABC):

@@ -24,6 +23,9 @@ class StrategyGenerator(ABC):
         self.op_data = operation_data_mapping
         self.device_mesh = device_mesh
+
+        # validate the whether operation data is of desired value
+        self.validate()

     @property
     def has_bias(self):
         """

@@ -102,9 +104,9 @@ class StrategyGenerator(ABC):
         comm_cost = TrainCycleItem(fwd=0, bwd=0, total=0)

-        def _compute_and_add(data: OperationData, comm_spec: CommSpec):
+        def _compute_and_add(op_data: OperationData, comm_spec: CommSpec):
             num_ele_in_comm = comm_spec.get_comm_cost()
-            dtype = operand.data.dtype
+            dtype = op_data.data.dtype
             size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
             for phase, cost in num_ele_in_comm.items():
                 num_ele_in_comm[phase] = num_ele_in_comm[phase] * size_per_elem_bytes

@@ -151,11 +153,30 @@ class StrategyGenerator(ABC):
         size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
         return reduce(operator.mul, sharded_shape) * size_per_elem_bytes

-    @abstractmethod
     def generate(self) -> List[ShardingStrategy]:
         """
         Generate all possible sharding strategies for this operation.
         """
+        strategies = self.collate_strategies()
+
+        # some strategies may be None as ignore_sharding_exception may return None
+        # when ShardingSpecException occurs.
+        # thus, remove those None values
+        strategies = [strategy for strategy in strategies if strategy]
+
+        # update the costs
+        # update mete info on cost
+        # these update methods are all in-place, the default method will do nothing
+        # the cost info will only be added if the child class overrides these methods
+        for strategy in strategies:
+            self.update_communication_cost(strategy)
+            self.update_compute_cost(strategy)
+            self.update_memory_cost(strategy)
+
+        return strategies
+
+    @abstractmethod
+    def collate_strategies(self) -> List[ShardingStrategy]:
         pass

     @abstractmethod
 ...
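With this hunk, generate becomes a shared template method: it collects candidates from collate_strategies, drops the None entries that ignore_sharding_exception produces for illegal specs, and applies the three cost updates in one place. Concrete generators, including the two below, now override only collate_strategies. A rough sketch of the expected subclass shape, assuming colossalai is importable; the class name and the two helper methods are illustrative, not taken from the repository:

from typing import List

from colossalai.auto_parallel.tensor_shard.node_handler.strategy.strategy_generator import StrategyGenerator
from colossalai.auto_parallel.tensor_shard.sharding_strategy import ShardingStrategy


class MyOpStrategyGenerator(StrategyGenerator):
    """Hypothetical generator; only the override point mirrors the diff above."""

    def collate_strategies(self) -> List[ShardingStrategy]:
        strategy_list = []
        for dim_partition in self.enumerate_candidate_partitions():    # hypothetical helper
            # in the real handlers the builder is wrapped with @ignore_sharding_exception,
            # so an illegal partition comes back as None and is filtered out by generate()
            strategy_list.append(self.build_strategy_for(dim_partition))    # hypothetical helper
        return strategy_list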
colossalai/auto_parallel/tensor_shard/node_handler/strategy/unary_elementwise_generator.py
View file @ eee84908

 import copy
+from typing import List

 from colossalai.auto_parallel.tensor_shard.sharding_strategy import (MemoryCost, ShardingStrategy, TrainCycleItem)

@@ -48,7 +49,7 @@ class UnaryElementwiseGenerator(FollowingStrategyGenerator):
         memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
         strategy.memory_cost = memory_cost

-    def generate(self):
+    def collate_strategies(self) -> List[ShardingStrategy]:
         strategy_list = []
         # For element-wise function, we keep the sharding spec of output node same as
         # the input. Therefore, the different strategies of input node with same

@@ -73,9 +74,4 @@ class UnaryElementwiseGenerator(FollowingStrategyGenerator):
                 communication_action_mapping=communication_action_mapping)
             strategy_list.append(strategy)

-        for strategy in strategy_list:
-            self.update_communication_cost(strategy)
-            self.update_compute_cost(strategy)
-            self.update_memory_cost(strategy)
-
         return strategy_list
colossalai/auto_parallel/tensor_shard/node_handler/strategy/where_generator.py
View file @ eee84908

 import copy
+from typing import List

 from colossalai.auto_parallel.tensor_shard.sharding_strategy import (MemoryCost, ShardingStrategy, TrainCycleItem)
 from colossalai.auto_parallel.tensor_shard.utils import (enumerate_all_possible_1d_sharding,

@@ -78,7 +79,7 @@ class WhereGenerator(StrategyGenerator):
         return dim_partition_list

-    def generate(self):
+    def collate_strategies(self) -> List[ShardingStrategy]:
         '''
         Generate every possible strategies for a where node, and record all strategies into the strategies_vector.
         '''

@@ -90,9 +91,4 @@ class WhereGenerator(StrategyGenerator):
             strategy = self._generate_strategy_with_dim_partition(dim_partition)
             strategy_list.append(strategy)

-        for strategy in strategy_list:
-            self.update_communication_cost(strategy)
-            self.update_compute_cost(strategy)
-            self.update_memory_cost(strategy)
-
         return strategy_list
colossalai/auto_parallel/tensor_shard/utils/__init__.py
View file @ eee84908

 from .broadcast import (BroadcastType, get_broadcast_shape, is_broadcastable, recover_sharding_spec_for_broadcast_shape)
 from .factory import generate_resharding_costs, generate_sharding_spec
-from .misc import exception_handler
+from .misc import ignore_sharding_exception
 from .sharding import (enumerate_all_possible_1d_sharding, enumerate_all_possible_2d_sharding, generate_sharding_size,
                        switch_partition_dim, update_partition_dim)

 __all__ = [
     'BroadcastType', 'get_broadcast_shape', 'is_broadcastable', 'recover_sharding_spec_for_broadcast_shape',
-    'generate_resharding_costs', 'generate_sharding_spec', 'exception_handler', 'switch_partition_dim',
+    'generate_resharding_costs', 'generate_sharding_spec', 'ignore_sharding_exception', 'switch_partition_dim',
     'update_partition_dim', 'enumerate_all_possible_1d_sharding', 'enumerate_all_possible_2d_sharding',
     'generate_sharding_size'
 ]
colossalai/auto_parallel/tensor_shard/utils/misc.py
View file @ eee84908

 import functools
-import warnings

-__all__ = ['exception_handler']
+from colossalai.logging import get_dist_logger
+from colossalai.tensor.sharding_spec import ShardingSpecException
+
+__all__ = ['ignore_sharding_exception']


-def exception_handler(func):
+def ignore_sharding_exception(func):
     """
-    A function wrapper to handle the AssertionError in the function.
+    A function wrapper to handle the ShardingSpecException in the function.
+    If ShardingSpecException occurs, this function will return None.

     Usage:
         # mute the assertion error in the function
-        @exception_handler
+        @ignore_sharding_exception
         def do_something():
         ...
     """

@@ -18,9 +21,11 @@ def exception_handler(func):
     @functools.wraps(func)
     def wrapper(*args, **kwargs):
         try:
+            logger = get_dist_logger()
             rst = func(*args, **kwargs)
             return rst
-        except AssertionError as e:
-            warnings.warn(f'{e}')
+        except ShardingSpecException as e:
+            logger.debug(e)
             return None

     return wrapper
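A hedged usage sketch of the renamed decorator, assuming colossalai is importable; build_candidate and its legality rule are made up for illustration, only the decorator and the exception type come from this commit:

from colossalai.auto_parallel.tensor_shard.utils import ignore_sharding_exception
from colossalai.tensor.sharding_spec import ShardingSpecException


@ignore_sharding_exception
def build_candidate(dim_partition_dict):
    # hypothetical builder: raise the typed exception for an illegal partition
    if 0 in dim_partition_dict and len(dim_partition_dict[0]) > 2:
        raise ShardingSpecException('cannot shard one dimension over more than two mesh axes')
    return dim_partition_dict


assert build_candidate({0: [0, 1]}) == {0: [0, 1]}
assert build_candidate({0: [0, 1, 2]}) is None    # exception is muted and None is returned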
colossalai/tensor/sharding_spec.py
View file @ eee84908

-import torch
-import operator
-from colossalai.device.device_mesh import DeviceMesh
-from colossalai.tensor.utils import all_gather_simulator, all_to_all_simulator, shard_simulator
 from copy import deepcopy
 from enum import Enum
 from functools import reduce
+import operator
+
+import torch
+
+from colossalai.device.device_mesh import DeviceMesh
+from colossalai.tensor.utils import (all_gather_simulator, all_to_all_simulator, shard_simulator)

 __all__ = ['_DimSpec', 'ShardingException', 'ShardingSpec']

@@ -138,7 +140,19 @@ class _DimSpec:
         return difference


-class ShardingException(Exception):
+class ShardingSpecException(Exception):
+    pass
+
+
+class ShardingOutOfIndexError(ShardingSpecException):
+    pass
+
+
+class DuplicatedShardingDimensionError(ShardingSpecException):
+    pass
+
+
+class ShardingNotDivisibleError(ShardingSpecException):
     pass

@@ -156,7 +170,11 @@ class ShardingSpec:
         sharding_sequence(List[_DimSpec], optional): A straight view of ShardingSpec looks like [R, R, S0, S1].
     '''

-    def __init__(self, device_mesh, entire_shape, dim_partition_dict=None, sharding_sequence=None):
+    def __init__(self, device_mesh: DeviceMesh, entire_shape: torch.Size, dim_partition_dict=None, sharding_sequence=None):
         self.device_mesh = device_mesh
         self.entire_shape = entire_shape
         self.dim_partition_dict = dim_partition_dict

@@ -174,19 +192,36 @@ class ShardingSpec:
         return ' '.join(res_list)

     def _sanity_check(self):
-        '''
-        In sanity check, we need make sure all axes in logical device mesh only be used
-        once.
-        '''
-        dim_check_list = [i for i in range(self.device_mesh.logical_mesh_id.dim())]
+        # make sure all axes in logical device mesh only be used once
+        dim_check_list = list(range(self.device_mesh.logical_mesh_id.dim()))
         for dim, shard_list in self.dim_partition_dict.items():
             for element in shard_list:
                 if element in dim_check_list:
                     dim_check_list.remove(element)
                 else:
-                    raise ValueError(
+                    raise DuplicatedShardingDimensionError(
                         f"find an invalid sharding axis {element} in dim_partition_dict in tensor dimension {dim}.")

+        # make sure that the dimension is not out of index
+        for dim in self.dim_partition_dict.keys():
+            if dim >= len(self.entire_shape):
+                raise ShardingOutOfIndexError(
+                    f"The dim_partition_dict specifies to shard dimension {dim} but the entire_shape only has {len(self.entire_shape)} dimensions")
+
+        # make sure that the sharding for a dimension is divisible by the number of devices
+        for dim, shard_list in self.dim_partition_dict.items():
+            tensor_dim_size = self.entire_shape[dim]
+            num_devices = 1
+
+            for element in shard_list:
+                num_devices *= self.device_mesh.mesh_shape[element]
+
+            if tensor_dim_size % num_devices != 0:
+                raise ShardingNotDivisibleError(
+                    f'The size of dimension at index {dim} is {tensor_dim_size}, it cannot be sharded over {num_devices} devices.')

     def convert_dict_to_shard_sequence(self):
         '''
         Convert dim_partition_dict into list of _DimSpec, and assign it to sharding_sequence.
 ...
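A short sketch of the new failure modes on a 2x2 DeviceMesh like the one used in the tests below, assuming _sanity_check runs when the spec is constructed (which is what the updated tests rely on); the shapes here are illustrative:

import torch

from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.sharding_spec import (ShardingNotDivisibleError, ShardingOutOfIndexError, ShardingSpec)

physical_mesh_id = torch.arange(0, 4)
device_mesh = DeviceMesh(physical_mesh_id, (2, 2))

# legal: dimension 0 of size 16 is sharded over both mesh axes (2 * 2 = 4 devices)
ShardingSpec(device_mesh, torch.Size((16, 8, 6)), dim_partition_dict={0: [0, 1]})

# illegal: 6 is not divisible by 4, so the sanity check now raises a typed error
try:
    ShardingSpec(device_mesh, torch.Size((16, 8, 6)), dim_partition_dict={2: [0, 1]})
except ShardingNotDivisibleError as e:
    print(e)

# illegal: dimension index 3 does not exist for a 3-D tensor
try:
    ShardingSpec(device_mesh, torch.Size((16, 8, 6)), dim_partition_dict={3: [0]})
except ShardingOutOfIndexError as e:
    print(e)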
tests/test_auto_parallel/__init__.py (new file, 0 → 100644)
View file @ eee84908

tests/test_auto_parallel/test_tensor_shard/__init__.py (new file, 0 → 100644)
View file @ eee84908
tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_op_handler/test_deprecated_bcast_handler.py
View file @ eee84908

-from cProfile import run
+import pytest
 import torch
-from torch.fx import GraphModule
 import torch.nn as nn
-import pytest
+from torch.fx import GraphModule

 from colossalai.auto_parallel.tensor_shard.deprecated.options import SolverOptions
 from colossalai.auto_parallel.tensor_shard.deprecated.strategies_constructor import StrategiesConstructor
-from colossalai.fx.tracer.tracer import ColoTracer
 from colossalai.device.device_mesh import DeviceMesh
+from colossalai.fx.tracer.tracer import ColoTracer
+from colossalai.testing.pytest_wrapper import run_on_environment_flag


 class ConvModel(nn.Module):

@@ -27,6 +30,7 @@ class ConvModel(nn.Module):
         return x


+@run_on_environment_flag(name='AUTO_PARALLEL')
 def test_conv_handler():
     physical_mesh_id = torch.arange(0, 4)
     mesh_shape = (2, 2)
 ...
tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_op_handler/test_deprecated_bcast_matmul.py
View file @ eee84908

+import pytest
 import torch
-from torch.fx import GraphModule
 import torch.nn as nn
-import pytest
+from torch.fx import GraphModule

 from colossalai.auto_parallel.tensor_shard.deprecated.options import SolverOptions
 from colossalai.auto_parallel.tensor_shard.deprecated.strategies_constructor import StrategiesConstructor
-from colossalai.fx.tracer.tracer import ColoTracer
 from colossalai.device.device_mesh import DeviceMesh
+from colossalai.fx.tracer.tracer import ColoTracer
+from colossalai.testing.pytest_wrapper import run_on_environment_flag


 class MatmulModel(nn.Module):

@@ -20,6 +21,7 @@ class MatmulModel(nn.Module):
         return x


+@run_on_environment_flag(name='AUTO_PARALLEL')
 def test_conv_handler():
     physical_mesh_id = torch.arange(0, 4)
     mesh_shape = (2, 2)
 ...
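Both deprecated-handler tests gain the @run_on_environment_flag(name='AUTO_PARALLEL') guard from colossalai.testing.pytest_wrapper, which presumably skips them unless that environment variable is set. A hedged sketch of opting in locally; the activation value '1' is an assumption, not confirmed by this diff:

# Hypothetical local opt-in for the guarded auto-parallel tests.
import os
import subprocess

env = dict(os.environ, AUTO_PARALLEL='1')    # assumed activation value
subprocess.run(
    ['pytest', 'tests/test_auto_parallel/test_tensor_shard/test_deprecated', '-v'],
    env=env,
    check=False,
)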
tests/test_auto_parallel/test_tensor_shard/test_node_handler/__init__.py (new file, 0 → 100644)
View file @ eee84908
tests/test_auto_parallel/test_tensor_shard/test_node_handler/common.py (new file, 0 → 100644)
View file @ eee84908

import torch

from colossalai.tensor.sharding_spec import ShardingSpec


def is_sharding_spec_valid(sharding_spec: ShardingSpec, tensor: torch.Tensor):
    """
    This function checks whether the ShardingSpec is valid for the physical tensor.
    This check includes 2 items:
    1. the sharding spec covers all dimensions of the physical tensor
    2. the sharding spec for each dimension is divisible by the number of devices.
    """
    # make sure all dims are covered in sharding spec
    sharding_len = len(sharding_spec.sharding_sequence)
    tensor_num_dim = tensor.dim()
    num_devices_in_col = sharding_spec.device_mesh.mesh_shape[0]
    num_devices_in_row = sharding_spec.device_mesh.mesh_shape[1]
    assert sharding_len == tensor_num_dim, \
        f'The ShardingSpec ({sharding_spec.sharding_sequence}) is created for {sharding_len}-dimension tensor, but the given tensor is {tensor_num_dim}-dimension ({tensor.shape}).'

    # make sure the sharding is valid for each dim
    for i in range(tensor_num_dim):
        dim_size = tensor.shape[i]
        dim_spec = sharding_spec.sharding_sequence[i]

        if str(dim_spec).startswith('S'):
            devices_str = str(dim_spec).lstrip('S')
            num_devices = 1

            if '0' in devices_str:
                num_devices *= num_devices_in_col
            if '1' in devices_str:
                num_devices *= num_devices_in_row

            assert dim_size >= num_devices and dim_size % num_devices == 0, \
                f'The dimension at index {i} has value {dim_size}, but it is sharded over {num_devices} devices.'
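A hedged usage sketch of the new helper, mirroring how the linear-handler tests below call it; the spec built here is illustrative and assumes the repository root is on the import path and that sharding_sequence is populated when the spec is constructed:

import torch

from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.sharding_spec import ShardingSpec
from tests.test_auto_parallel.test_tensor_shard.test_node_handler.common import is_sharding_spec_valid

device_mesh = DeviceMesh(torch.arange(0, 4), (2, 2))

# S0,R sharding of an (8, 16) tensor: dim 0 is split over the 2 devices of mesh axis 0
spec = ShardingSpec(device_mesh, torch.Size((8, 16)), dim_partition_dict={0: [0]})

# passes silently; an invalid spec would trip one of the asserts in common.py
is_sharding_spec_valid(spec, torch.rand(8, 16))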
tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_linear_handler.py
View file @ eee84908

@@ -6,12 +6,13 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import (OperationData,
                                                                      StrategiesVector)
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.fx import ColoGraphModule, ColoTracer
-from colossalai.fx.tracer.meta_patch.patched_module import linear
-from colossalai.tensor.sharding_spec import ShardingSpec
+from tests.test_auto_parallel.test_tensor_shard.test_node_handler.common import \
+    is_sharding_spec_valid


 def test_linear_module_handler():
     model = nn.Sequential(nn.Linear(16, 32).to('meta'))
     tracer = ColoTracer()
     graph = tracer.trace(model, meta_args={"input": torch.rand(2, 2, 4, 16).to('meta')})
     gm = ColoGraphModule(model, graph)

@@ -91,6 +92,12 @@ def test_linear_module_handler():
         bias_sharding_spec = strategy.get_sharding_spec_by_name('bias')
         output_sharding_spec = strategy.get_sharding_spec_by_name('_0')

+        # make sure the sharding spec is valid
+        is_sharding_spec_valid(input_sharding_spec, torch.rand(2, 2, 4, 16))
+        is_sharding_spec_valid(weight_sharding_spec, model.get_parameter('0.weight'))
+        is_sharding_spec_valid(bias_sharding_spec, model.get_parameter('0.bias'))
+        is_sharding_spec_valid(output_sharding_spec, torch.rand([2, 2, 4, 32]))
+
         # make sure the sharding matches across different operation data
         assert input_sharding_spec.sharding_sequence[:-1] == output_sharding_spec.sharding_sequence[:-1]
         assert weight_sharding_spec.sharding_sequence[1] == input_sharding_spec.sharding_sequence[-1]

@@ -101,7 +108,7 @@ def test_linear_module_handler():
 def test_linear_function_handler():
     model = nn.Linear(16, 32).to('meta')
     tracer = ColoTracer()
-    graph = tracer.trace(model, meta_args={"input": torch.rand(4, 16).to('meta')})
+    graph = tracer.trace(model, meta_args={"input": torch.rand(2, 2, 4, 16).to('meta')})
     gm = ColoGraphModule(model, graph)
     physical_mesh_id = torch.arange(0, 4)

@@ -117,11 +124,13 @@ def test_linear_function_handler():
     # # check operation data mapping
     mapping = handler.get_operation_data_mapping()

+    print(mapping['input'].logical_shape)
     assert mapping['input'].name == "input_1"
     assert mapping['input'].data.is_meta
-    assert mapping['input'].data.shape == torch.Size([4, 16])
+    assert mapping['input'].data.shape == torch.Size([2, 2, 4, 16])
     assert mapping['input'].type == OperationDataType.ARG
-    assert mapping['input'].logical_shape == torch.Size([4, 16])
+    assert mapping['input'].logical_shape == torch.Size([16, 16])
     assert mapping['other'].name == "weight"
     assert mapping['other'].data.is_meta

@@ -137,7 +146,7 @@ def test_linear_function_handler():
     assert mapping['output'].name == "linear"
     assert mapping['output'].data.is_meta
-    assert mapping['output'].data.shape == torch.Size([4, 32])
+    assert mapping['output'].data.shape == torch.Size([2, 2, 4, 32])
     assert mapping['output'].type == OperationDataType.OUTPUT

     strategies_vector = handler.register_strategy(compute_resharding_cost=False)

@@ -167,11 +176,18 @@ def test_linear_function_handler():
     for strategy in strategies_vector:
         strategy: ShardingStrategy
-        print(strategy)
         input_sharding_spec = strategy.get_sharding_spec_by_name('input_1')
         weight_sharding_spec = strategy.get_sharding_spec_by_name('weight')
         bias_sharding_spec = strategy.get_sharding_spec_by_name('bias')
         output_sharding_spec = strategy.get_sharding_spec_by_name('linear')

+        # make sure the sharding spec is valid
+        is_sharding_spec_valid(input_sharding_spec, torch.rand(2, 2, 4, 16))
+        is_sharding_spec_valid(weight_sharding_spec, model.get_parameter('weight'))
+        is_sharding_spec_valid(bias_sharding_spec, model.get_parameter('bias'))
+        is_sharding_spec_valid(output_sharding_spec, torch.rand([2, 2, 4, 32]))
+
         # make sure the sharding matches across different operation data
         assert input_sharding_spec.sharding_sequence[:-1] == output_sharding_spec.sharding_sequence[:-1]
         assert weight_sharding_spec.sharding_sequence[1] == input_sharding_spec.sharding_sequence[-1]
 ...
tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_unary_element_wise_handler.py
View file @ eee84908

 import torch
 import torch.nn as nn

 from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import \
     ConvFunctionHandler
 from colossalai.auto_parallel.tensor_shard.node_handler.unary_elementwise_handler import \
 ...
tests/test_tensor/test_sharded_linear.py
View file @ eee84908

 from functools import partial
 from lib2to3 import pgen2

-import colossalai
-import torch
 import pytest
+import torch
 import torch.multiprocessing as mp
 import torch.nn.functional as F
-from colossalai.testing import rerun_if_address_is_in_use
-from colossalai.utils import free_port

+import colossalai
 from colossalai.device.device_mesh import DeviceMesh
-from colossalai.tensor.sharding_spec import ShardingSpec
-from colossalai.tensor import ColoTensor, ColoParameter, ProcessGroup
 from colossalai.nn._ops._utils import gather_forward_split_backward
+from colossalai.tensor import ColoParameter, ColoTensor, ProcessGroup
+from colossalai.tensor.sharding_spec import ShardingSpec
+from colossalai.testing import rerun_if_address_is_in_use
+from colossalai.utils import free_port


 def run_dist(rank, world_size, port):

@@ -18,7 +20,7 @@ def run_dist(rank, world_size, port):
     colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')

     # create mlp vars
-    x = ColoTensor.from_torch_tensor(torch.rand(2, 4, 8, requires_grad=True)).cuda()
+    x = ColoTensor.from_torch_tensor(torch.rand(4, 4, 8, requires_grad=True)).cuda()
     w = ColoParameter.from_torch_tensor(torch.rand(16, 8, requires_grad=True)).cuda()
     b = ColoParameter.from_torch_tensor(torch.rand(16, requires_grad=True)).cuda()
 ...
tests/test_tensor/test_sharding_spec.py
View file @ eee84908

 import torch

-from colossalai.tensor.sharding_spec import _DimSpec, ShardingSpec
 from colossalai.device.device_mesh import DeviceMesh
+from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec


 def test_sharding_spec():

@@ -11,7 +12,7 @@ def test_sharding_spec():
     #  [8, 9, 10,11],
     #  [12,13,14,15]]
     device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
-    entire_shape = torch.Size((4, 8, 6))
+    entire_shape = torch.Size((16, 8, 6))
     dim_partition_dict = {0: [0, 1]}
     # DistSpec:
     #     shard_sequence: S01,R,R
 ...