OpenDAS / ColossalAI · Commit 21d6a48f (unverified)
Authored Feb 15, 2023 by YuliangLiu0306; committed by GitHub on Feb 15, 2023
Parent: 5b24987f

[autoparallel] add shard option (#2696)

* [autoparallel] add shard option
* polish
Showing 14 changed files with 176 additions and 74 deletions (+176 −74):
colossalai/auto_parallel/tensor_shard/initialize.py                                  +60 −10
colossalai/auto_parallel/tensor_shard/node_handler/__init__.py                        +2  −3
colossalai/auto_parallel/tensor_shard/node_handler/node_handler.py                   +20 −11
colossalai/auto_parallel/tensor_shard/node_handler/option.py                          +0 −17
colossalai/auto_parallel/tensor_shard/options.py                                     +49  −0
colossalai/auto_parallel/tensor_shard/solver/__init__.py                              +1  −2
colossalai/auto_parallel/tensor_shard/solver/solver.py                                +1  −1
colossalai/auto_parallel/tensor_shard/solver/strategies_constructor.py               +21  −5
tests/test_auto_parallel/test_tensor_shard/test_gpt/test_solver_with_gpt_module.py    +2  −7
tests/test_auto_parallel/test_tensor_shard/test_metainfo/utils.py                     +2  −1
tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_shard_option.py    +12  −2
tests/test_auto_parallel/test_tensor_shard/test_node_handler/utils.py                 +2  −1
tests/test_auto_parallel/test_tensor_shard/test_param_resharding_cost.py              +2  −7
tests/test_auto_parallel/test_tensor_shard/test_solver_with_resnet_v2.py              +2  −7
colossalai/auto_parallel/tensor_shard/initialize.py

@@ -8,14 +8,9 @@ from torch.fx.graph import Graph
 from colossalai.auto_parallel.passes.runtime_apply_pass import runtime_apply_pass
 from colossalai.auto_parallel.passes.runtime_preparation_pass import runtime_preparation_pass
+from colossalai.auto_parallel.tensor_shard.options import DataloaderOption, ShardOption, SolverOptions, SolverPerference
 from colossalai.auto_parallel.tensor_shard.sharding_strategy import CommAction
-from colossalai.auto_parallel.tensor_shard.solver import (
-    CostGraph,
-    GraphAnalyser,
-    Solver,
-    SolverOptions,
-    StrategiesConstructor,
-)
+from colossalai.auto_parallel.tensor_shard.solver import CostGraph, GraphAnalyser, Solver, StrategiesConstructor
 from colossalai.device.alpha_beta_profiler import AlphaBetaProfiler
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.fx.graph_module import ColoGraphModule
@@ -69,13 +64,43 @@ def extract_alpha_beta_for_device_mesh(alpha_beta_dict: Dict[Tuple[int], Tuple[f
     pass


-def build_strategy_constructor(graph: Graph, device_mesh: DeviceMesh):
+def build_strategy_constructor(graph: Graph, device_mesh: DeviceMesh, solver_preference: str, dataloader_option: str,
+                               shard_option: str):
     '''
     This method is used to build the strategy_constructor for the given graph.
     After this method, each node in the graph will have a strategies_vector which
     is constructed by the related node handler.
     '''
-    solver_options = SolverOptions()
+    if solver_preference == 'standard':
+        solver_preference = SolverPerference.STANDARD
+    elif solver_preference == 'tp':
+        solver_preference = SolverPerference.TP
+    elif solver_preference == 'dp':
+        solver_preference = SolverPerference.DP
+    else:
+        raise ValueError(f'Invalid solver_preference: {solver_preference}')
+
+    if dataloader_option == 'replicated':
+        dataloader_option = DataloaderOption.REPLICATED
+    elif dataloader_option == 'distributed':
+        dataloader_option = DataloaderOption.DISTRIBUTED
+    else:
+        raise ValueError(f'Invalid dataloader_option: {dataloader_option}')
+
+    if shard_option == 'standard':
+        shard_option = ShardOption.STANDARD
+    elif shard_option == 'shard':
+        shard_option = ShardOption.SHARD
+    elif shard_option == 'shard_last_axis':
+        shard_option = ShardOption.SHARD_LAST_AXIS
+    elif shard_option == 'full_shard':
+        shard_option = ShardOption.FULL_SHARD
+    else:
+        raise ValueError(f'Invalid shard_option: {shard_option}')
+
+    solver_options = SolverOptions(solver_perference=solver_preference,
+                                   dataloader_option=dataloader_option,
+                                   shard_option=shard_option)
     strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
     strategies_constructor.build_strategies_and_cost()
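The string-to-enum dispatch above can also be read as a table lookup. A minimal, hypothetical sketch of the same mapping (not part of this commit; parse_shard_option and _SHARD_OPTIONS are names invented here to summarize the accepted values):

from colossalai.auto_parallel.tensor_shard.options import ShardOption

# Hypothetical lookup table equivalent to the shard_option if/elif chain above.
_SHARD_OPTIONS = {
    'standard': ShardOption.STANDARD,
    'shard': ShardOption.SHARD,
    'shard_last_axis': ShardOption.SHARD_LAST_AXIS,
    'full_shard': ShardOption.FULL_SHARD,
}

def parse_shard_option(name: str) -> ShardOption:
    # Raises the same ValueError as the original chain for unknown names.
    try:
        return _SHARD_OPTIONS[name]
    except KeyError:
        raise ValueError(f'Invalid shard_option: {name}') from None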
@@ -183,6 +208,9 @@ def initialize_model(model: nn.Module,
                      device_mesh: DeviceMesh,
                      memory_budget: float = -1.0,
                      overlap: bool = False,
+                     solver_preference: str = 'standard',
+                     dataloader_option: str = 'replicated',
+                     shard_option: str = 'standard',
                      save_solver_solution: bool = False,
                      load_solver_solution: bool = False,
                      solution_path: str = None,
@@ -198,6 +226,12 @@ def initialize_model(model: nn.Module,
             the memory budget will be infinity.
         overlap(optional): the overlap is used to specify whether to overlap gradient communication and
             backward computing.
+        solver_preference(optional): the solver_preference is used to specify which parallelism algorithm
+            has higher priority. The valid solver_preference could be 'standard', 'tp', or 'dp'.
+        dataloader_option(optional): the dataloader_option is used to specify which kind of data_loader will
+            be used. The valid dataloader_option could be 'replicated' or 'distributed'.
+        shard_option(optional): the shard_option is used to specify how many axes will be used to shard the
+            model. The valid shard_option could be 'standard', 'shard', 'shard_last_axis', or 'full_shard'.
         save_solver_solution(optional): if the save_solver_solution is True, the solution will be saved
             to the solution_path.
         load_solver_solution(optional): if the load_solver_solution is True, the solution will be loaded
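A usage sketch of the options documented above; model, meta_args and device_mesh stand in for objects built as in the rest of this file, and the return value is left unpacked since its structure is outside this diff:

# Hedged example: prefer tensor-parallel solutions with a distributed dataloader,
# requiring every node to be sharded over the last device mesh axis.
rst_to_unpack = initialize_model(model,
                                 meta_args,
                                 device_mesh,
                                 solver_preference='tp',
                                 dataloader_option='distributed',
                                 shard_option='shard_last_axis')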
@@ -212,7 +246,12 @@ def initialize_model(model: nn.Module,
     graph = tracer.trace(root=model, meta_args=meta_args)
     gm = ColoGraphModule(model, graph, model.__class__.__name__)
     gm.recompile()
-    strategies_constructor = build_strategy_constructor(graph, device_mesh)
+    strategies_constructor = build_strategy_constructor(graph,
+                                                        device_mesh,
+                                                        solver_preference=solver_preference,
+                                                        dataloader_option=dataloader_option,
+                                                        shard_option=shard_option)
     if load_solver_solution:
         solution = torch.load(solution_path)
     else:
@@ -240,6 +279,9 @@ def autoparallelize(model: nn.Module,
                     alpha_beta_dict: Dict[Tuple[int], Tuple[float]] = None,
                     logical_mesh_shape: Tuple[int] = None,
                     logical_mesh_id: torch.Tensor = None,
+                    solver_preference: str = 'standard',
+                    dataloader_option: str = 'replicated',
+                    shard_option: str = 'standard',
                     save_solver_solution: bool = False,
                     load_solver_solution: bool = False,
                     solver_solution_path: str = None,
@@ -262,6 +304,12 @@ def autoparallelize(model: nn.Module,
             mesh shape. If the logical_mesh_shape is None, the logical_mesh_shape will be
             generated by search_best_logical_mesh_shape function.
         logical_mesh_id(optional): the logical_mesh_id is used to specify the logical mesh id.
+        solver_preference(optional): the solver_preference is used to specify which parallelism algorithm
+            has higher priority. The valid solver_preference could be 'standard', 'tp', or 'dp'.
+        dataloader_option(optional): the dataloader_option is used to specify which kind of data_loader will
+            be used. The valid dataloader_option could be 'replicated' or 'distributed'.
+        shard_option(optional): the shard_option is used to specify how many axes will be used to shard the
+            model. The valid shard_option could be 'standard', 'shard', 'shard_last_axis', or 'full_shard'.
         save_solver_solution(optional): if the save_solver_solution is True, the solution will be saved
             to the solution_path.
         load_solver_solution(optional): if the load_solver_solution is True, the solution will be loaded
@@ -280,6 +328,8 @@ def autoparallelize(model: nn.Module,
     rst_to_unpack = initialize_model(model,
                                      meta_args,
                                      device_mesh,
+                                     solver_preference=solver_preference,
+                                     dataloader_option=dataloader_option,
                                      save_solver_solution=save_solver_solution,
                                      load_solver_solution=load_solver_solution,
                                      solution_path=solver_solution_path,
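Note that this call site forwards solver_preference and dataloader_option but not shard_option (the hunk adds exactly two lines), so in this commit autoparallelize still reaches build_strategy_constructor with initialize_model's default shard_option='standard' regardless of the value passed to autoparallelize.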
colossalai/auto_parallel/tensor_shard/node_handler/__init__.py

@@ -11,7 +11,6 @@ from .layer_norm_handler import LayerNormModuleHandler
 from .linear_handler import LinearFunctionHandler, LinearModuleHandler
 from .matmul_handler import MatMulHandler
 from .normal_pooling_handler import NormPoolingHandler
-from .option import ShardOption
 from .output_handler import OutputHandler
 from .permute_handler import PermuteHandler
 from .placeholder_handler import PlaceholderHandler

@@ -31,6 +30,6 @@ __all__ = [
     'UnaryElementwiseHandler', 'DefaultReshapeHandler', 'PlaceholderHandler', 'OutputHandler', 'WhereHandler',
     'NormPoolingHandler', 'BinaryElementwiseHandler', 'MatMulHandler', 'operator_registry', 'ADDMMFunctionHandler',
     'GetItemHandler', 'GetattrHandler', 'ViewHandler', 'PermuteHandler', 'TensorConstructorHandler',
-    'EmbeddingModuleHandler', 'EmbeddingFunctionHandler', 'SumHandler', 'SoftmaxHandler', 'ShardOption',
-    'TransposeHandler', 'SplitHandler'
+    'EmbeddingModuleHandler', 'EmbeddingFunctionHandler', 'SumHandler', 'SoftmaxHandler', 'TransposeHandler',
+    'SplitHandler'
 ]
colossalai/auto_parallel/tensor_shard/node_handler/node_handler.py

@@ -5,7 +5,7 @@ import torch
 from torch.fx.node import Node

 from colossalai.auto_parallel.meta_profiler.metainfo import MetaInfo, meta_register
-from colossalai.auto_parallel.tensor_shard.node_handler.option import ShardOption
+from colossalai.auto_parallel.tensor_shard.options import ShardOption, SolverPerference
 from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
     OperationData,
     OperationDataType,
@@ -32,19 +32,19 @@ class NodeHandler(ABC):
         strategies_vector (StrategiesVector): all the strategies generated in this handler will be recorded into the strategies_vector.
     '''

-    def __init__(
-        self,
-        node: Node,
-        device_mesh: DeviceMesh,
-        strategies_vector: StrategiesVector,
-        shard_option: ShardOption = ShardOption.STANDARD,
-    ) -> None:
+    def __init__(self,
+                 node: Node,
+                 device_mesh: DeviceMesh,
+                 strategies_vector: StrategiesVector,
+                 shard_option: ShardOption = ShardOption.STANDARD,
+                 solver_perference: SolverPerference = SolverPerference.STANDARD) -> None:
         self.node = node
         self.predecessor_node = list(node._input_nodes.keys())
         self.successor_node = list(node.users.keys())
         self.device_mesh = device_mesh
         self.strategies_vector = strategies_vector
         self.shard_option = shard_option
+        self.solver_perference = solver_perference

     def update_resharding_cost(self, strategy: ShardingStrategy) -> None:
         """
@@ -187,15 +187,24 @@ class NodeHandler(ABC):
         remove_strategy_list = []
         for strategy in self.strategies_vector:
-            shard_level = 0
+            shard_axis_list = []
+            last_axis = len(self.device_mesh.mesh_shape) - 1
             for op_data, sharding_spec in strategy.sharding_specs.items():
                 if op_data.data is not None and isinstance(op_data.data, torch.Tensor):
-                    for dim, shard_axis in sharding_spec.dim_partition_dict.items():
-                        shard_level += len(shard_axis)
+                    for dim, shard_axes in sharding_spec.dim_partition_dict.items():
+                        for shard_axis in shard_axes:
+                            if shard_axis not in shard_axis_list:
+                                shard_axis_list.append(shard_axis)
+            shard_level = len(shard_axis_list)
+            using_last_axis = last_axis in shard_axis_list or -1 in shard_axis_list
             if self.shard_option == ShardOption.SHARD and shard_level == 0:
                 remove_strategy_list.append(strategy)
             if self.shard_option == ShardOption.FULL_SHARD and shard_level <= 1:
                 remove_strategy_list.append(strategy)
+            if self.shard_option == ShardOption.SHARD_LAST_AXIS:
+                if shard_level != 1 or using_last_axis == False:
+                    remove_strategy_list.append(strategy)

         for strategy in remove_strategy_list:
             self.strategies_vector.remove(strategy)
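To make the filter above easier to follow, here is a self-contained sketch of the same conditions with plain dicts standing in for ShardingStrategy objects; keep_strategy and its inputs are names local to this example, not part of the codebase:

from enum import Enum

class ShardOption(Enum):
    STANDARD = 0
    SHARD = 1
    SHARD_LAST_AXIS = 2
    FULL_SHARD = 3

def keep_strategy(dim_partition_dicts, mesh_shape, shard_option):
    # dim_partition_dicts: one dict per tensor in the strategy, mapping a
    # tensor dim -> list of mesh axes that dim is sharded over.
    last_axis = len(mesh_shape) - 1
    shard_axis_list = []
    for dim_partition_dict in dim_partition_dicts:
        for axes in dim_partition_dict.values():
            for axis in axes:
                if axis not in shard_axis_list:
                    shard_axis_list.append(axis)
    shard_level = len(shard_axis_list)    # number of distinct mesh axes in use
    using_last_axis = last_axis in shard_axis_list or -1 in shard_axis_list
    if shard_option == ShardOption.SHARD and shard_level == 0:
        return False                      # must shard over at least one mesh axis
    if shard_option == ShardOption.FULL_SHARD and shard_level <= 1:
        return False                      # must shard over more than one mesh axis
    if shard_option == ShardOption.SHARD_LAST_AXIS and (shard_level != 1 or not using_last_axis):
        return False                      # must shard over exactly the last mesh axis
    return True

# On a 2x2 mesh, a tensor sharded only over mesh axis 1 satisfies SHARD_LAST_AXIS...
assert keep_strategy([{0: [1]}], (2, 2), ShardOption.SHARD_LAST_AXIS)
# ...while a fully replicated strategy does not.
assert not keep_strategy([{}], (2, 2), ShardOption.SHARD_LAST_AXIS)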
colossalai/auto_parallel/tensor_shard/node_handler/option.py (deleted, 100644 → 0)

-from enum import Enum
-
-__all__ = ['ShardOption']
-
-
-class ShardOption(Enum):
-    """
-    This enum class is to define the shard level required in node strategies.
-
-    Notes:
-        STANDARD: We do not add any extra shard requirements.
-        SHARD: We require the node to be shard using at least one device mesh axis.
-        FULL_SHARD: We require the node to be shard using all device mesh axes.
-    """
-    STANDARD = 0
-    SHARD = 1
-    FULL_SHARD = 2
colossalai/auto_parallel/tensor_shard/solver/options.py → colossalai/auto_parallel/tensor_shard/options.py

 from dataclasses import dataclass
 from enum import Enum

-__all__ = ['SolverOptions']
+__all__ = ['SolverOptions', 'SolverPerference', 'DataloaderOption', 'ShardOption']


 class SolverPerference(Enum):
@@ -13,6 +13,24 @@ class SolverPerference(Enum):
     TP = 2


+class ShardOption(Enum):
+    """
+    This enum class is to define the shard level required in node strategies.
+
+    Notes:
+        STANDARD: We do not add any extra shard requirements.
+        SHARD: We require the node to be shard using at least one device mesh axis.
+        SHARD_LAST_AXIS: We require the node to be shard using the last device mesh axis.
+        FULL_SHARD: We require the node to be shard using all device mesh axes.
+        TP_SHARD: We require the node to be shard using tensor parallel strategies on last device mesh axis.
+        TP_FULL_SHARD: We require the node to be shard using tensor parallel strategies on all device mesh axes.
+    """
+    STANDARD = 0
+    SHARD = 1
+    SHARD_LAST_AXIS = 2
+    FULL_SHARD = 3
+
+
 class DataloaderOption(Enum):
     """
     This enum class is to define the dataloader option.
@@ -28,3 +46,4 @@ class SolverOptions:
     """
     solver_perference: SolverPerference = SolverPerference.STANDARD
     dataloader_option: DataloaderOption = DataloaderOption.REPLICATED
+    shard_option: ShardOption = ShardOption.STANDARD
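A short sketch of constructing the dataclass after this change, with the field names exactly as in this diff (note the solver_perference spelling):

from colossalai.auto_parallel.tensor_shard.options import (DataloaderOption, ShardOption, SolverOptions,
                                                           SolverPerference)

# All three fields are optional and default to the values shown in the hunk above.
options = SolverOptions(solver_perference=SolverPerference.STANDARD,
                        dataloader_option=DataloaderOption.REPLICATED,
                        shard_option=ShardOption.SHARD_LAST_AXIS)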
colossalai/auto_parallel/tensor_shard/solver/__init__.py

 from .cost_graph import CostGraph
 from .graph_analysis import GraphAnalyser
-from .options import SolverOptions
 from .solver import Solver
 from .strategies_constructor import StrategiesConstructor

-__all__ = ['GraphAnalyser', 'Solver', 'StrategiesConstructor', 'CostGraph', 'SolverOptions']
+__all__ = ['GraphAnalyser', 'Solver', 'StrategiesConstructor', 'CostGraph']
colossalai/auto_parallel/tensor_shard/solver/solver.py

@@ -33,7 +33,7 @@ class Solver:
                  solution_numbers: int = 1,
                  forward_only: bool = False,
                  memory_increasing_coefficient: float = 1.3,
-                 verbose=True):
+                 verbose=False):
         '''
         Solver class will integrate information provided by the components and use ILP solver to find a possible optimal strategies combination for target computing graph.

         Argument:
colossalai/auto_parallel/tensor_shard/solver/strategies_constructor.py

@@ -17,7 +17,7 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import StrategiesVe
 from colossalai.auto_parallel.tensor_shard.utils import generate_resharding_costs, generate_sharding_spec
 from colossalai.device.device_mesh import DeviceMesh

-from .options import DataloaderOption, SolverOptions
+from ..options import DataloaderOption, SolverOptions

 __all__ = ['StrategiesConstructor']

@@ -101,7 +101,11 @@ class StrategiesConstructor:
             # get_attr node
             elif node.op == 'get_attr':
-                getattr_handler = GetattrHandler(node, self.device_mesh, strategies_vector)
+                getattr_handler = GetattrHandler(node,
+                                                 self.device_mesh,
+                                                 strategies_vector,
+                                                 shard_option=self.solver_options.shard_option,
+                                                 solver_perference=self.solver_options.solver_perference)
                 getattr_handler.register_strategy()

             # call_module node
@@ -109,7 +113,11 @@ class StrategiesConstructor:
                 target = node.target
                 submod = self.root_module.get_submodule(target)
                 submod_type = type(submod)
-                handler = operator_registry.get(submod_type)(node, self.device_mesh, strategies_vector)
+                handler = operator_registry.get(submod_type)(node,
+                                                             self.device_mesh,
+                                                             strategies_vector,
+                                                             shard_option=self.solver_options.shard_option,
+                                                             solver_perference=self.solver_options.solver_perference)
                 handler.register_strategy()

                 # attach metainfo_vector to node
                 if hasattr(handler, 'metainfo_vector'):
@@ -118,7 +126,11 @@ class StrategiesConstructor:
             # call_function node
             elif node.op == 'call_function':
                 target = node.target
-                handler = operator_registry.get(target)(node, self.device_mesh, strategies_vector)
+                handler = operator_registry.get(target)(node,
+                                                        self.device_mesh,
+                                                        strategies_vector,
+                                                        shard_option=self.solver_options.shard_option,
+                                                        solver_perference=self.solver_options.solver_perference)
                 handler.register_strategy()

                 # attach metainfo_vector to node
                 if hasattr(handler, 'metainfo_vector'):
@@ -127,7 +139,11 @@ class StrategiesConstructor:
             # call_method node
             elif node.op == 'call_method':
                 method = getattr(node.args[0]._meta_data.__class__, node.target)
-                handler = operator_registry.get(method)(node, self.device_mesh, strategies_vector)
+                handler = operator_registry.get(method)(node,
+                                                        self.device_mesh,
+                                                        strategies_vector,
+                                                        shard_option=self.solver_options.shard_option,
+                                                        solver_perference=self.solver_options.solver_perference)
                 handler.register_strategy()

                 # attach metainfo_vector to node
                 if hasattr(handler, 'metainfo_vector'):
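All four branches above construct their handler with the same two extra keyword arguments; a hypothetical helper (not part of this commit, _build_handler is a name invented here) that captures the repeated pattern:

# Sketch of a possible refactoring of the duplicated keyword plumbing above.
def _build_handler(handler_cls, node, device_mesh, strategies_vector, solver_options):
    return handler_cls(node,
                       device_mesh,
                       strategies_vector,
                       shard_option=solver_options.shard_option,
                       solver_perference=solver_options.solver_perference)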
tests/test_auto_parallel/test_tensor_shard/test_gpt/test_solver_with_gpt_module.py

@@ -4,13 +4,8 @@ import transformers
 from torch.fx import GraphModule

 from colossalai.auto_parallel.tensor_shard.constants import BATCHNORM_MODULE_OP
-from colossalai.auto_parallel.tensor_shard.solver import (
-    CostGraph,
-    GraphAnalyser,
-    Solver,
-    SolverOptions,
-    StrategiesConstructor,
-)
+from colossalai.auto_parallel.tensor_shard.options import SolverOptions
+from colossalai.auto_parallel.tensor_shard.solver import CostGraph, GraphAnalyser, Solver, StrategiesConstructor
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.fx.tracer.tracer import ColoTracer
 from colossalai.tensor.shape_consistency import ShapeConsistencyManager
tests/test_auto_parallel/test_tensor_shard/test_metainfo/utils.py

@@ -7,8 +7,9 @@ from torch.fx import GraphModule
 from colossalai.auto_parallel.passes.runtime_apply_pass import runtime_apply_pass
 from colossalai.auto_parallel.passes.runtime_preparation_pass import runtime_preparation_pass
+from colossalai.auto_parallel.tensor_shard.options import SolverOptions
 from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationDataType, TrainCycleItem
-from colossalai.auto_parallel.tensor_shard.solver import SolverOptions, StrategiesConstructor
+from colossalai.auto_parallel.tensor_shard.solver import StrategiesConstructor
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.fx.tracer.tracer import ColoTracer
tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_shard_option.py

@@ -5,7 +5,7 @@ import torch.multiprocessing as mp
 import torch.nn as nn

 from colossalai.auto_parallel.tensor_shard.node_handler import LinearFunctionHandler
-from colossalai.auto_parallel.tensor_shard.node_handler.option import ShardOption
+from colossalai.auto_parallel.tensor_shard.options import ShardOption
 from colossalai.auto_parallel.tensor_shard.sharding_strategy import StrategiesVector
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.fx import ColoGraphModule, ColoTracer
@@ -49,6 +49,15 @@ def check_shard_option(shard_option):
     strategies_vector = handler.register_strategy(compute_resharding_cost=False)
     strategy_name_list = [val.name for val in strategies_vector]

+    if shard_option == ShardOption.SHARD_LAST_AXIS:
+        # RR = RS x SR
+        assert 'RR = RS1 x S1R' in strategy_name_list
+
+        # RS = RR x RS
+        assert 'RS1 = RR x RS1' in strategy_name_list
+        return
+
     # SS = SR x RS
     assert 'S1S0 = S1R x RS0_0' in strategy_name_list
     assert 'S0S1 = S0R x RS1_1' in strategy_name_list
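In these strategy names, R appears to mark a replicated tensor dimension and S<i> a dimension sharded over device mesh axis i, with the output spec on the left and the operand specs on the right; the two new assertions then check that under SHARD_LAST_AXIS only strategies sharding exclusively over mesh axis 1, the last axis of the mesh in this test, survive the filter.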
@@ -104,7 +113,8 @@
 @run_on_environment_flag(name='AUTO_PARALLEL')
 def test_shard_option():
-    for shard_option in [ShardOption.STANDARD, ShardOption.SHARD, ShardOption.FULL_SHARD]:
+    # for shard_option in [ShardOption.STANDARD, ShardOption.SHARD, ShardOption.FULL_SHARD, ShardOption.SHARD_LAST_AXIS]:
+    for shard_option in [ShardOption.SHARD_LAST_AXIS]:
         check_shard_option(shard_option)
tests/test_auto_parallel/test_tensor_shard/test_node_handler/utils.py

@@ -6,7 +6,8 @@ from torch.fx import GraphModule
 from colossalai.auto_parallel.passes.runtime_apply_pass import runtime_apply_pass
 from colossalai.auto_parallel.passes.runtime_preparation_pass import runtime_preparation_pass
-from colossalai.auto_parallel.tensor_shard.solver import SolverOptions, StrategiesConstructor
+from colossalai.auto_parallel.tensor_shard.options import SolverOptions
+from colossalai.auto_parallel.tensor_shard.solver import StrategiesConstructor
 from colossalai.auto_parallel.tensor_shard.solver.cost_graph import CostGraph
 from colossalai.auto_parallel.tensor_shard.solver.graph_analysis import GraphAnalyser
 from colossalai.auto_parallel.tensor_shard.solver.solver import Solver
tests/test_auto_parallel/test_tensor_shard/test_param_resharding_cost.py

 import torch

+from colossalai.auto_parallel.tensor_shard.options import SolverOptions
 from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationDataType
-from colossalai.auto_parallel.tensor_shard.solver import (
-    CostGraph,
-    GraphAnalyser,
-    Solver,
-    SolverOptions,
-    StrategiesConstructor,
-)
+from colossalai.auto_parallel.tensor_shard.solver import CostGraph, GraphAnalyser, Solver, StrategiesConstructor
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.fx import ColoGraphModule, ColoTracer
 from colossalai.testing.pytest_wrapper import run_on_environment_flag
tests/test_auto_parallel/test_tensor_shard/test_solver_with_resnet_v2.py

@@ -3,13 +3,8 @@ from torch.fx import GraphModule
 from torchvision.models import resnet50

 from colossalai.auto_parallel.tensor_shard.constants import BATCHNORM_MODULE_OP
-from colossalai.auto_parallel.tensor_shard.solver import (
-    CostGraph,
-    GraphAnalyser,
-    Solver,
-    SolverOptions,
-    StrategiesConstructor,
-)
+from colossalai.auto_parallel.tensor_shard.options import SolverOptions
+from colossalai.auto_parallel.tensor_shard.solver import CostGraph, GraphAnalyser, Solver, StrategiesConstructor
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.fx.tracer.tracer import ColoTracer
 from colossalai.tensor.shape_consistency import ShapeConsistencyManager