Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
cdb7d5e7
Unverified
Commit
cdb7d5e7
authored
Oct 20, 2022
by
YuliangLiu0306
Committed by
GitHub
Oct 20, 2022
Browse files
[hotfix] autoparallel unit test (#1752)
parent
a4ce180e
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
16 additions
and
14 deletions
+16
-14
colossalai/auto_parallel/tensor_shard/deprecated/op_handler/__init__.py
...o_parallel/tensor_shard/deprecated/op_handler/__init__.py
+6
-5
tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_solver.py
...st_tensor_shard/test_deprecated/test_deprecated_solver.py
+10
-9
No files found.
colossalai/auto_parallel/tensor_shard/deprecated/op_handler/__init__.py
View file @
cdb7d5e7
from
.operator_handler
import
OperatorHandler
from
.dot_handler
import
DotHandler
from
.conv_handler
import
ConvHandler
from
.batch_norm_handler
import
BatchNormHandler
from
.batch_norm_handler
import
BatchNormHandler
from
.reshape_handler
import
ReshapeHandler
from
.bcast_op_handler
import
BcastOpHandler
from
.bcast_op_handler
import
BcastOpHandler
from
.conv_handler
import
ConvHandler
from
.dot_handler
import
DotHandler
from
.embedding_handler
import
EmbeddingHandler
from
.embedding_handler
import
EmbeddingHandler
from
.layer_norm_handler
import
LayerNormHandler
from
.operator_handler
import
OperatorHandler
from
.reshape_handler
import
ReshapeHandler
from
.unary_elementwise_handler
import
UnaryElementwiseHandler
from
.unary_elementwise_handler
import
UnaryElementwiseHandler
from
.where_handler
import
WhereHandler
from
.where_handler
import
WhereHandler
__all__
=
[
__all__
=
[
'OperatorHandler'
,
'DotHandler'
,
'ConvHandler'
,
'BatchNormHandler'
,
'ReshapeHandler'
,
'BcastOpHandler'
,
'OperatorHandler'
,
'DotHandler'
,
'ConvHandler'
,
'BatchNormHandler'
,
'ReshapeHandler'
,
'BcastOpHandler'
,
'UnaryElementwiseHandler'
,
'EmbeddingHandler'
,
'WhereHandler'
'UnaryElementwiseHandler'
,
'EmbeddingHandler'
,
'WhereHandler'
,
'LayerNormHandler'
]
]
tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_solver.py
View file @
cdb7d5e7
from
copy
import
deepcopy
import
pytest
import
torch
import
torch
from
torch.fx
import
GraphModule
import
torch.nn
as
nn
import
torch.nn
as
nn
import
pytest
from
torch.fx
import
GraphModule
from
colossalai.fx.tracer.tracer
import
ColoTracer
from
colossalai.auto_parallel.tensor_shard.deprecated
import
Solver
from
colossalai.tensor.shape_consistency
import
ShapeConsistencyManager
from
colossalai.device.device_mesh
import
DeviceMesh
from
colossalai.auto_parallel.tensor_shard.deprecated.strategies_constructor
import
StrategiesConstructor
from
colossalai.auto_parallel.tensor_shard.deprecated.cost_graph
import
CostGraph
from
colossalai.auto_parallel.tensor_shard.deprecated.cost_graph
import
CostGraph
from
colossalai.auto_parallel.tensor_shard.deprecated.graph_analysis
import
GraphAnalyser
from
colossalai.auto_parallel.tensor_shard.deprecated.graph_analysis
import
GraphAnalyser
from
copy
import
deepcopy
from
colossalai.auto_parallel.tensor_shard.deprecated
import
Solver
from
colossalai.auto_parallel.tensor_shard.deprecated.options
import
SolverOptions
from
colossalai.auto_parallel.tensor_shard.deprecated.options
import
SolverOptions
from
colossalai.auto_parallel.tensor_shard.deprecated.strategies_constructor
import
StrategiesConstructor
from
colossalai.device.device_mesh
import
DeviceMesh
from
colossalai.fx.tracer.tracer
import
ColoTracer
from
colossalai.tensor.shape_consistency
import
ShapeConsistencyManager
from
colossalai.testing.pytest_wrapper
import
run_on_environment_flag
from
colossalai.testing.pytest_wrapper
import
run_on_environment_flag
...
@@ -60,7 +61,7 @@ def test_solver():
...
@@ -60,7 +61,7 @@ def test_solver():
gm
=
GraphModule
(
model
,
graph
,
model
.
__class__
.
__name__
)
gm
=
GraphModule
(
model
,
graph
,
model
.
__class__
.
__name__
)
solver_options
=
SolverOptions
(
fast
=
True
)
solver_options
=
SolverOptions
(
fast
=
True
)
strategies_constructor
=
StrategiesConstructor
(
graph
,
device_mesh
,
shape_consistency_manager
,
solver_options
)
strategies_constructor
=
StrategiesConstructor
(
graph
,
device_mesh
,
solver_options
)
strategies_constructor
.
build_strategies_and_cost
()
strategies_constructor
.
build_strategies_and_cost
()
cost_graph
=
CostGraph
(
strategies_constructor
.
leaf_strategies
)
cost_graph
=
CostGraph
(
strategies_constructor
.
leaf_strategies
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment