Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
cdb7d5e7
Unverified
Commit
cdb7d5e7
authored
Oct 20, 2022
by
YuliangLiu0306
Committed by
GitHub
Oct 20, 2022
Browse files
[hotfix] autoparallel unit test (#1752)
parent
a4ce180e
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
16 additions
and
14 deletions
+16
-14
colossalai/auto_parallel/tensor_shard/deprecated/op_handler/__init__.py
...o_parallel/tensor_shard/deprecated/op_handler/__init__.py
+6
-5
tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_solver.py
...st_tensor_shard/test_deprecated/test_deprecated_solver.py
+10
-9
No files found.
colossalai/auto_parallel/tensor_shard/deprecated/op_handler/__init__.py
View file @
cdb7d5e7
from
.operator_handler
import
OperatorHandler
from
.dot_handler
import
DotHandler
from
.conv_handler
import
ConvHandler
from
.batch_norm_handler
import
BatchNormHandler
from
.reshape_handler
import
ReshapeHandler
from
.bcast_op_handler
import
BcastOpHandler
from
.conv_handler
import
ConvHandler
from
.dot_handler
import
DotHandler
from
.embedding_handler
import
EmbeddingHandler
from
.layer_norm_handler
import
LayerNormHandler
from
.operator_handler
import
OperatorHandler
from
.reshape_handler
import
ReshapeHandler
from
.unary_elementwise_handler
import
UnaryElementwiseHandler
from
.where_handler
import
WhereHandler
__all__
=
[
'OperatorHandler'
,
'DotHandler'
,
'ConvHandler'
,
'BatchNormHandler'
,
'ReshapeHandler'
,
'BcastOpHandler'
,
'UnaryElementwiseHandler'
,
'EmbeddingHandler'
,
'WhereHandler'
'UnaryElementwiseHandler'
,
'EmbeddingHandler'
,
'WhereHandler'
,
'LayerNormHandler'
]
tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_solver.py
View file @
cdb7d5e7
from
copy
import
deepcopy
import
pytest
import
torch
from
torch.fx
import
GraphModule
import
torch.nn
as
nn
import
pytest
from
torch.fx
import
GraphModule
from
colossalai.fx.tracer.tracer
import
ColoTracer
from
colossalai.tensor.shape_consistency
import
ShapeConsistencyManager
from
colossalai.device.device_mesh
import
DeviceMesh
from
colossalai.auto_parallel.tensor_shard.deprecated.strategies_constructor
import
StrategiesConstructor
from
colossalai.auto_parallel.tensor_shard.deprecated
import
Solver
from
colossalai.auto_parallel.tensor_shard.deprecated.cost_graph
import
CostGraph
from
colossalai.auto_parallel.tensor_shard.deprecated.graph_analysis
import
GraphAnalyser
from
copy
import
deepcopy
from
colossalai.auto_parallel.tensor_shard.deprecated
import
Solver
from
colossalai.auto_parallel.tensor_shard.deprecated.options
import
SolverOptions
from
colossalai.auto_parallel.tensor_shard.deprecated.strategies_constructor
import
StrategiesConstructor
from
colossalai.device.device_mesh
import
DeviceMesh
from
colossalai.fx.tracer.tracer
import
ColoTracer
from
colossalai.tensor.shape_consistency
import
ShapeConsistencyManager
from
colossalai.testing.pytest_wrapper
import
run_on_environment_flag
...
...
@@ -60,7 +61,7 @@ def test_solver():
gm
=
GraphModule
(
model
,
graph
,
model
.
__class__
.
__name__
)
solver_options
=
SolverOptions
(
fast
=
True
)
strategies_constructor
=
StrategiesConstructor
(
graph
,
device_mesh
,
shape_consistency_manager
,
solver_options
)
strategies_constructor
=
StrategiesConstructor
(
graph
,
device_mesh
,
solver_options
)
strategies_constructor
.
build_strategies_and_cost
()
cost_graph
=
CostGraph
(
strategies_constructor
.
leaf_strategies
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment