OpenDAS / ColossalAI / Commits / eee84908

Unverified commit eee84908, authored Oct 19, 2022 by Frank Lee, committed via GitHub on Oct 19, 2022.

[autoparallel] handled illegal sharding strategy (#1728)

* [autoparallel] handled illegal sharding strategy
* polish code

parent cbe9a4cb
Showing 16 changed files with 175 additions and 61 deletions (+175 -61).
colossalai/auto_parallel/tensor_shard/node_handler/strategy/strategy_generator.py  +26 -5
colossalai/auto_parallel/tensor_shard/node_handler/strategy/unary_elementwise_generator.py  +2 -6
colossalai/auto_parallel/tensor_shard/node_handler/strategy/where_generator.py  +2 -6
colossalai/auto_parallel/tensor_shard/utils/__init__.py  +2 -2
colossalai/auto_parallel/tensor_shard/utils/misc.py  +12 -7
colossalai/tensor/sharding_spec.py  +47 -12
tests/test_auto_parallel/__init__.py  +0 -0
tests/test_auto_parallel/test_tensor_shard/__init__.py  +0 -0
tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_op_handler/test_deprecated_bcast_handler.py  +7 -3
tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_op_handler/test_deprecated_bcast_matmul.py  +5 -3
tests/test_auto_parallel/test_tensor_shard/test_node_handler/__init__.py  +0 -0
tests/test_auto_parallel/test_tensor_shard/test_node_handler/common.py  +37 -0
tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_linear_handler.py  +22 -6
tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_unary_element_wise_handler.py  +0 -1
tests/test_tensor/test_sharded_linear.py  +10 -8
tests/test_tensor/test_sharding_spec.py  +3 -2
colossalai/auto_parallel/tensor_shard/node_handler/strategy/strategy_generator.py  +26 -5

@@ -4,13 +4,12 @@ from functools import reduce
 from typing import Any, Dict, List, Union

 import torch
+from torch.fx import Node

 from colossalai.auto_parallel.tensor_shard.sharding_strategy import (OperationData, OperationDataType,
                                                                      ShardingStrategy, TrainCycleItem)
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.tensor.shape_consistency import CollectiveCommPattern, CommSpec
 from colossalai.tensor.sharding_spec import ShardingSpec
-from torch.fx import Node


 class StrategyGenerator(ABC):
...
@@ -24,6 +23,9 @@ class StrategyGenerator(ABC):
         self.op_data = operation_data_mapping
         self.device_mesh = device_mesh

+        # validate the whether operation data is of desired value
+        self.validate()
+
     @property
     def has_bias(self):
         """
...
@@ -102,9 +104,9 @@ class StrategyGenerator(ABC):
         comm_cost = TrainCycleItem(fwd=0, bwd=0, total=0)

-        def _compute_and_add(data: OperationData, comm_spec: CommSpec):
+        def _compute_and_add(op_data: OperationData, comm_spec: CommSpec):
             num_ele_in_comm = comm_spec.get_comm_cost()
-            dtype = operand.data.dtype
+            dtype = op_data.data.dtype
             size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
             for phase, cost in num_ele_in_comm.items():
                 num_ele_in_comm[phase] = num_ele_in_comm[phase] * size_per_elem_bytes
...
@@ -151,11 +153,30 @@ class StrategyGenerator(ABC):
         size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
         return reduce(operator.mul, sharded_shape) * size_per_elem_bytes

-    @abstractmethod
     def generate(self) -> List[ShardingStrategy]:
         """
         Generate all possible sharding strategies for this operation.
         """
+        strategies = self.collate_strategies()
+
+        # some strategies may be None as ignore_sharding_exception may return None
+        # when ShardingSpecException occurs.
+        # thus, remove those None values
+        strategies = [strategy for strategy in strategies if strategy]
+
+        # update the costs
+        # update mete info on cost
+        # these update methods are all in-place, the default method will do nothing
+        # the cost info will only be added if the child class overrides these methods
+        for strategy in strategies:
+            self.update_communication_cost(strategy)
+            self.update_compute_cost(strategy)
+            self.update_memory_cost(strategy)
+
+        return strategies
+
+    @abstractmethod
+    def collate_strategies(self) -> List[ShardingStrategy]:
+        pass

     @abstractmethod
...
colossalai/auto_parallel/tensor_shard/node_handler/strategy/unary_elementwise_generator.py  +2 -6

 import copy
+from typing import List

 from colossalai.auto_parallel.tensor_shard.sharding_strategy import (MemoryCost, ShardingStrategy, TrainCycleItem)
...
@@ -48,7 +49,7 @@ class UnaryElementwiseGenerator(FollowingStrategyGenerator):
         memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
         strategy.memory_cost = memory_cost

-    def generate(self):
+    def collate_strategies(self) -> List[ShardingStrategy]:
         strategy_list = []
         # For element-wise function, we keep the sharding spec of output node same as
         # the input. Therefore, the different strategies of input node with same
...
@@ -73,9 +74,4 @@ class UnaryElementwiseGenerator(FollowingStrategyGenerator):
                                                  communication_action_mapping=communication_action_mapping)
             strategy_list.append(strategy)

-        for strategy in strategy_list:
-            self.update_communication_cost(strategy)
-            self.update_compute_cost(strategy)
-            self.update_memory_cost(strategy)
-
         return strategy_list
colossalai/auto_parallel/tensor_shard/node_handler/strategy/where_generator.py  +2 -6

 import copy
+from typing import List

 from colossalai.auto_parallel.tensor_shard.sharding_strategy import (MemoryCost, ShardingStrategy, TrainCycleItem)
 from colossalai.auto_parallel.tensor_shard.utils import (enumerate_all_possible_1d_sharding, ...
...
@@ -78,7 +79,7 @@ class WhereGenerator(StrategyGenerator):
         return dim_partition_list

-    def generate(self):
+    def collate_strategies(self) -> List[ShardingStrategy]:
         '''
         Generate every possible strategies for a where node, and record all strategies into the strategies_vector.
         '''
...
@@ -90,9 +91,4 @@ class WhereGenerator(StrategyGenerator):
             strategy = self._generate_strategy_with_dim_partition(dim_partition)
             strategy_list.append(strategy)

-        for strategy in strategy_list:
-            self.update_communication_cost(strategy)
-            self.update_compute_cost(strategy)
-            self.update_memory_cost(strategy)
-
         return strategy_list
colossalai/auto_parallel/tensor_shard/utils/__init__.py  +2 -2

 from .broadcast import (BroadcastType, get_broadcast_shape, is_broadcastable,
                         recover_sharding_spec_for_broadcast_shape)
 from .factory import generate_resharding_costs, generate_sharding_spec
-from .misc import exception_handler
+from .misc import ignore_sharding_exception
 from .sharding import (enumerate_all_possible_1d_sharding, enumerate_all_possible_2d_sharding,
                        generate_sharding_size, switch_partition_dim, update_partition_dim)

 __all__ = [
     'BroadcastType', 'get_broadcast_shape', 'is_broadcastable', 'recover_sharding_spec_for_broadcast_shape',
-    'generate_resharding_costs', 'generate_sharding_spec', 'exception_handler', 'switch_partition_dim',
+    'generate_resharding_costs', 'generate_sharding_spec', 'ignore_sharding_exception', 'switch_partition_dim',
     'update_partition_dim', 'enumerate_all_possible_1d_sharding', 'enumerate_all_possible_2d_sharding',
     'generate_sharding_size'
 ]
colossalai/auto_parallel/tensor_shard/utils/misc.py  +12 -7

 import functools
-import warnings

-__all__ = ['exception_handler']
+from colossalai.logging import get_dist_logger
+from colossalai.tensor.sharding_spec import ShardingSpecException
+
+__all__ = ['ignore_sharding_exception']


-def exception_handler(func):
+def ignore_sharding_exception(func):
     """
-    A function wrapper to handle the AssertionError in the function.
+    A function wrapper to handle the ShardingSpecException in the function.
+    If ShardingSpecException occurs, this function will return None.

     Usage:
         # mute the assertion error in the function
-        @exception_handler
+        @ignore_sharding_exception
         def do_something():
             ...
     """
...
@@ -18,9 +21,11 @@ def exception_handler(func):
     @functools.wraps(func)
     def wrapper(*args, **kwargs):
         try:
+            logger = get_dist_logger()
             rst = func(*args, **kwargs)
             return rst
-        except AssertionError as e:
-            warnings.warn(f'{e}')
+        except ShardingSpecException as e:
+            logger.debug(e)
+            return None

     return wrapper
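
To illustrate the new behaviour, here is a minimal usage sketch of the renamed decorator. The build_strategy function and its error message are made up for illustration; only ignore_sharding_exception and ShardingSpecException come from the modules changed in this commit, and the decorated call is expected to swallow the exception and return None.

from colossalai.auto_parallel.tensor_shard.utils import ignore_sharding_exception
from colossalai.tensor.sharding_spec import ShardingSpecException


@ignore_sharding_exception
def build_strategy():
    # hypothetical helper: any ShardingSpecException raised here is caught by the
    # wrapper, logged at debug level, and the call returns None instead of failing
    raise ShardingSpecException('illegal sharding strategy')


assert build_strategy() is None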
colossalai/tensor/sharding_spec.py  +47 -12

-import torch
-from colossalai.device.device_mesh import DeviceMesh
-from colossalai.tensor.utils import all_gather_simulator, all_to_all_simulator, shard_simulator
-import operator
 from copy import deepcopy
 from enum import Enum
 from functools import reduce
+import operator
+
+import torch
+
+from colossalai.device.device_mesh import DeviceMesh
+from colossalai.tensor.utils import (all_gather_simulator, all_to_all_simulator, shard_simulator)

 __all__ = ['_DimSpec', 'ShardingException', 'ShardingSpec']
...
@@ -138,7 +140,19 @@ class _DimSpec:
         return difference


-class ShardingException(Exception):
+class ShardingSpecException(Exception):
     pass


+class ShardingOutOfIndexError(ShardingSpecException):
+    pass
+
+
+class DuplicatedShardingDimensionError(ShardingSpecException):
+    pass
+
+
+class ShardingNotDivisibleError(ShardingSpecException):
+    pass
...
@@ -156,7 +170,11 @@ class ShardingSpec:
         sharding_sequence(List[_DimSpec], optional): A straight view of ShardingSpec looks like [R, R, S0, S1].
     '''

-    def __init__(self, device_mesh, entire_shape, dim_partition_dict=None, sharding_sequence=None):
+    def __init__(self,
+                 device_mesh: DeviceMesh,
+                 entire_shape: torch.Size,
+                 dim_partition_dict=None,
+                 sharding_sequence=None):
         self.device_mesh = device_mesh
         self.entire_shape = entire_shape
         self.dim_partition_dict = dim_partition_dict
...
@@ -174,19 +192,36 @@ class ShardingSpec:
         return ' '.join(res_list)

     def _sanity_check(self):
-        '''
-        In sanity check, we need make sure all axes in logical device mesh only be used
-        once.
-        '''
-        dim_check_list = [i for i in range(self.device_mesh.logical_mesh_id.dim())]
+        # make sure all axes in logical device mesh only be used once
+        dim_check_list = list(range(self.device_mesh.logical_mesh_id.dim()))
         for dim, shard_list in self.dim_partition_dict.items():
             for element in shard_list:
                 if element in dim_check_list:
                     dim_check_list.remove(element)
                 else:
-                    raise ValueError(
+                    raise DuplicatedShardingDimensionError(
                         f"find an invalid sharding axis {element} in dim_partition_dict in tensor dimension {dim}.")

+        # make sure that the dimension is not out of index
+        for dim in self.dim_partition_dict.keys():
+            if dim >= len(self.entire_shape):
+                raise ShardingOutOfIndexError(
+                    f"The dim_partition_dict specifies to shard dimension {dim} but the entire_shape only has {len(self.entire_shape)} dimensions")
+
+        # make sure that the sharding for a dimension is divisible by the number of devices
+        for dim, shard_list in self.dim_partition_dict.items():
+            tensor_dim_size = self.entire_shape[dim]
+            num_devices = 1
+
+            for element in shard_list:
+                num_devices *= self.device_mesh.mesh_shape[element]
+
+            if tensor_dim_size % num_devices != 0:
+                raise ShardingNotDivisibleError(
+                    f'The size of dimension at index {dim} is {tensor_dim_size}, it cannot be sharded over {num_devices} devices.')

     def convert_dict_to_shard_sequence(self):
         '''
         Convert dim_partition_dict into list of _DimSpec, and assign it to sharding_sequence.
...
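
As a rough sketch of how these new checks surface to a caller: the DeviceMesh construction below mirrors tests/test_tensor/test_sharding_spec.py, the 6-element first dimension is an invented value chosen so that the divisibility check fails, and the example assumes _sanity_check runs while the spec is being constructed.

import torch

from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.sharding_spec import ShardingNotDivisibleError, ShardingSpec

# 4 physical devices arranged as a 2 x 2 logical mesh, as in the tests below
physical_mesh_id = torch.arange(0, 4)
device_mesh = DeviceMesh(physical_mesh_id, (2, 2))

# S01 on dimension 0 shards it over 2 * 2 = 4 devices, but the dimension size is 6,
# so 6 % 4 != 0 and the spec should be rejected with ShardingNotDivisibleError
try:
    ShardingSpec(device_mesh, torch.Size((6, 8)), dim_partition_dict={0: [0, 1]})
except ShardingNotDivisibleError as e:
    print(e)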
tests/test_auto_parallel/__init__.py  0 → 100644 (new empty file)
tests/test_auto_parallel/test_tensor_shard/__init__.py  0 → 100644 (new empty file)
tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_op_handler/test_deprecated_bcast_handler.py  +7 -3

-from cProfile import run
 import pytest
 import torch
 import torch.nn as nn
 from torch.fx import GraphModule

 from colossalai.auto_parallel.tensor_shard.deprecated.options import SolverOptions
 from colossalai.auto_parallel.tensor_shard.deprecated.strategies_constructor import StrategiesConstructor
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.fx.tracer.tracer import ColoTracer
+from colossalai.testing.pytest_wrapper import run_on_environment_flag


 class ConvModel(nn.Module):
...
@@ -27,6 +30,7 @@ class ConvModel(nn.Module):
         return x


+@run_on_environment_flag(name='AUTO_PARALLEL')
 def test_conv_handler():
     physical_mesh_id = torch.arange(0, 4)
     mesh_shape = (2, 2)
...
tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_op_handler/test_deprecated_bcast_matmul.py  +5 -3

 import pytest
 import torch
 import torch.nn as nn
 from torch.fx import GraphModule

 from colossalai.auto_parallel.tensor_shard.deprecated.options import SolverOptions
 from colossalai.auto_parallel.tensor_shard.deprecated.strategies_constructor import StrategiesConstructor
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.fx.tracer.tracer import ColoTracer
+from colossalai.testing.pytest_wrapper import run_on_environment_flag


 class MatmulModel(nn.Module):
...
@@ -20,6 +21,7 @@ class MatmulModel(nn.Module):
         return x


+@run_on_environment_flag(name='AUTO_PARALLEL')
 def test_conv_handler():
     physical_mesh_id = torch.arange(0, 4)
     mesh_shape = (2, 2)
...
tests/test_auto_parallel/test_tensor_shard/test_node_handler/__init__.py  0 → 100644 (new empty file)
tests/test_auto_parallel/test_tensor_shard/test_node_handler/common.py  0 → 100644 (new file, +37 -0)

import torch

from colossalai.tensor.sharding_spec import ShardingSpec


def is_sharding_spec_valid(sharding_spec: ShardingSpec, tensor: torch.Tensor):
    """
    This function checks whether the ShardingSpec is valid for the physical tensor.
    This check includes 2 items:
    1. the sharding spec covers all dimensions of the physical tensor
    2. the sharding spec for each dimension is divisible by the number of devices.
    #
    """
    # make sure all dims are covered in sharding spec
    sharding_len = len(sharding_spec.sharding_sequence)
    tensor_num_dim = tensor.dim()
    num_devices_in_col = sharding_spec.device_mesh.mesh_shape[0]
    num_devices_in_row = sharding_spec.device_mesh.mesh_shape[1]
    assert sharding_len == tensor_num_dim, \
        f'The ShardingSpec ({sharding_spec.sharding_sequence}) is created for {sharding_len}-dimension tensor, but the given tensor is {tensor_num_dim}-dimension ({tensor.shape}).'

    # make sure the sharding is valid for each dim
    for i in range(tensor_num_dim):
        dim_size = tensor.shape[i]
        dim_spec = sharding_spec.sharding_sequence[i]

        if str(dim_spec).startswith('S'):
            devices_str = str(dim_spec).lstrip('S')
            num_devices = 1

            if '0' in devices_str:
                num_devices *= num_devices_in_col
            if '1' in devices_str:
                num_devices *= num_devices_in_row

            assert dim_size >= num_devices and dim_size % num_devices == 0, \
                f'The dimension at index {i} has value {dim_size}, but it is sharded over {num_devices} devices.'
tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_linear_handler.py  +22 -6

...
@@ -6,12 +6,13 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import (OperationDa
                                                                      StrategiesVector)
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.fx import ColoGraphModule, ColoTracer
 from colossalai.fx.tracer.meta_patch.patched_module import linear
 from colossalai.tensor.sharding_spec import ShardingSpec
+from tests.test_auto_parallel.test_tensor_shard.test_node_handler.common import \
+    is_sharding_spec_valid


 def test_linear_module_handler():
     model = nn.Sequential(nn.Linear(16, 32).to('meta'))
     tracer = ColoTracer()
     graph = tracer.trace(model, meta_args={"input": torch.rand(2, 2, 4, 16).to('meta')})
     gm = ColoGraphModule(model, graph)
...
@@ -91,6 +92,12 @@ def test_linear_module_handler():
         bias_sharding_spec = strategy.get_sharding_spec_by_name('bias')
         output_sharding_spec = strategy.get_sharding_spec_by_name('_0')

+        # make sure the sharding spec is valid
+        is_sharding_spec_valid(input_sharding_spec, torch.rand(2, 2, 4, 16))
+        is_sharding_spec_valid(weight_sharding_spec, model.get_parameter('0.weight'))
+        is_sharding_spec_valid(bias_sharding_spec, model.get_parameter('0.bias'))
+        is_sharding_spec_valid(output_sharding_spec, torch.rand([2, 2, 4, 32]))
+
         # make sure the sharding matches across different operation data
         assert input_sharding_spec.sharding_sequence[:-1] == output_sharding_spec.sharding_sequence[:-1]
         assert weight_sharding_spec.sharding_sequence[1] == input_sharding_spec.sharding_sequence[-1]
...
@@ -101,7 +108,7 @@ def test_linear_module_handler():
 def test_linear_function_handler():
     model = nn.Linear(16, 32).to('meta')
     tracer = ColoTracer()
-    graph = tracer.trace(model, meta_args={"input": torch.rand(4, 16).to('meta')})
+    graph = tracer.trace(model, meta_args={"input": torch.rand(2, 2, 4, 16).to('meta')})
     gm = ColoGraphModule(model, graph)
     physical_mesh_id = torch.arange(0, 4)
...
@@ -117,11 +124,13 @@ def test_linear_function_handler():
     # # check operation data mapping
     mapping = handler.get_operation_data_mapping()

+    print(mapping['input'].logical_shape)
     assert mapping['input'].name == "input_1"
     assert mapping['input'].data.is_meta
-    assert mapping['input'].data.shape == torch.Size([4, 16])
+    assert mapping['input'].data.shape == torch.Size([2, 2, 4, 16])
     assert mapping['input'].type == OperationDataType.ARG
-    assert mapping['input'].logical_shape == torch.Size([4, 16])
+    assert mapping['input'].logical_shape == torch.Size([16, 16])

     assert mapping['other'].name == "weight"
     assert mapping['other'].data.is_meta
...
@@ -137,7 +146,7 @@ def test_linear_function_handler():
     assert mapping['output'].name == "linear"
     assert mapping['output'].data.is_meta
-    assert mapping['output'].data.shape == torch.Size([4, 32])
+    assert mapping['output'].data.shape == torch.Size([2, 2, 4, 32])
     assert mapping['output'].type == OperationDataType.OUTPUT

     strategies_vector = handler.register_strategy(compute_resharding_cost=False)
...
@@ -167,11 +176,18 @@ def test_linear_function_handler():
     for strategy in strategies_vector:
+        strategy: ShardingStrategy
         print(strategy)
         input_sharding_spec = strategy.get_sharding_spec_by_name('input_1')
         weight_sharding_spec = strategy.get_sharding_spec_by_name('weight')
         bias_sharding_spec = strategy.get_sharding_spec_by_name('bias')
         output_sharding_spec = strategy.get_sharding_spec_by_name('linear')

+        # make sure the sharding spec is valid
+        is_sharding_spec_valid(input_sharding_spec, torch.rand(2, 2, 4, 16))
+        is_sharding_spec_valid(weight_sharding_spec, model.get_parameter('weight'))
+        is_sharding_spec_valid(bias_sharding_spec, model.get_parameter('bias'))
+        is_sharding_spec_valid(output_sharding_spec, torch.rand([2, 2, 4, 32]))
+
         # make sure the sharding matches across different operation data
         assert input_sharding_spec.sharding_sequence[:-1] == output_sharding_spec.sharding_sequence[:-1]
         assert weight_sharding_spec.sharding_sequence[1] == input_sharding_spec.sharding_sequence[-1]
...
tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_unary_element_wise_handler.py  +0 -1

 import torch
 import torch.nn as nn

 from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import \
     ConvFunctionHandler
 from colossalai.auto_parallel.tensor_shard.node_handler.unary_elementwise_handler import \
...
tests/test_tensor/test_sharded_linear.py  +10 -8

 from functools import partial
-from lib2to3 import pgen2

+import pytest
 import torch
 import torch.multiprocessing as mp
 import torch.nn.functional as F

 import colossalai
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.nn._ops._utils import gather_forward_split_backward
-from colossalai.tensor import ColoTensor, ColoParameter, ProcessGroup
+from colossalai.tensor import ColoParameter, ColoTensor, ProcessGroup
 from colossalai.tensor.sharding_spec import ShardingSpec
 from colossalai.testing import rerun_if_address_is_in_use
 from colossalai.utils import free_port


 def run_dist(rank, world_size, port):
...
@@ -18,7 +20,7 @@ def run_dist(rank, world_size, port):
     colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')

     # create mlp vars
-    x = ColoTensor.from_torch_tensor(torch.rand(2, 4, 8, requires_grad=True)).cuda()
+    x = ColoTensor.from_torch_tensor(torch.rand(4, 4, 8, requires_grad=True)).cuda()
     w = ColoParameter.from_torch_tensor(torch.rand(16, 8, requires_grad=True)).cuda()
     b = ColoParameter.from_torch_tensor(torch.rand(16, requires_grad=True)).cuda()
...
tests/test_tensor/test_sharding_spec.py  +3 -2

 import torch

-from colossalai.tensor.sharding_spec import _DimSpec, ShardingSpec
 from colossalai.device.device_mesh import DeviceMesh
+from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec


 def test_sharding_spec():
...
@@ -11,7 +12,7 @@ def test_sharding_spec():
     #  [8, 9, 10,11],
     #  [12,13,14,15]]
     device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
-    entire_shape = torch.Size((4, 8, 6))
+    entire_shape = torch.Size((16, 8, 6))
     dim_partition_dict = {0: [0, 1]}
     # DistSpec:
     #     shard_sequence: S01,R,R
...