Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
93b788b9
Commit
93b788b9
authored
Feb 15, 2023
by
binmakeswell
Browse files
Merge branch 'main' into fix/format
parents
2fd528b9
1dc003c1
Changes
146
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
100 additions
and
106 deletions
+100
-106
tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_split_handler.py
...test_tensor_shard/test_node_handler/test_split_handler.py
+44
-44
tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_view_handler.py
.../test_tensor_shard/test_node_handler/test_view_handler.py
+49
-46
tests/test_auto_parallel/test_tensor_shard/test_node_handler/utils.py
...uto_parallel/test_tensor_shard/test_node_handler/utils.py
+2
-1
tests/test_auto_parallel/test_tensor_shard/test_param_resharding_cost.py
..._parallel/test_tensor_shard/test_param_resharding_cost.py
+2
-7
tests/test_auto_parallel/test_tensor_shard/test_solver_with_resnet_v2.py
..._parallel/test_tensor_shard/test_solver_with_resnet_v2.py
+2
-7
version.txt
version.txt
+1
-1
No files found.
tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_split_handler.py
View file @
93b788b9
...
...
@@ -198,54 +198,54 @@ def check_split_handler(rank, split_size, split_dim, model_cls, world_size, port
if
model_cls
.
__name__
==
'LinearSplitModel'
:
if
split_dim
==
0
:
assert
'[R, R, R, S1]_
0
'
in
strategy_name_list
assert
'[R, S0, R, S1]_1'
in
strategy_name_list
assert
'[R, R, S0, S1]_
2
'
in
strategy_name_list
assert
'[R, R, R, S0]_
3
'
in
strategy_name_list
assert
'[R, S1, R, S0]_
4
'
in
strategy_name_list
assert
'[R, R, S1, S0]_
5
'
in
strategy_name_list
assert
'[R, R, R, R]_
6
'
in
strategy_name_list
assert
'[R, S0, R, R]_
7
'
in
strategy_name_list
assert
'[R, R, S0, R]_
8
'
in
strategy_name_list
assert
'[R, R, R, R]_
9
'
in
strategy_name_list
assert
'[R, S1, R, R]_1
0
'
in
strategy_name_list
assert
'[R, R, S1, R]_
11
'
in
strategy_name_list
assert
'[R, R, R, S1]_1
2
'
in
strategy_name_list
assert
'[R, R, R, S0]_
13
'
in
strategy_name_list
assert
'[R, R, R, R]_
14
'
in
strategy_name_list
assert
'[R, R, R, R]_
15
'
in
strategy_name_list
assert
'[R, R, R, S0]_
1
6'
in
strategy_name_list
assert
'[R, R, R, S1]_
17
'
in
strategy_name_list
assert
'[R, R, R, R]_
18
'
in
strategy_name_list
assert
'[R, S01, R, R]_1
9
'
in
strategy_name_list
assert
'[R, R, S01, R]_2
0
'
in
strategy_name_list
assert
'[R, R, R, R]_
21
'
in
strategy_name_list
assert
'[R, R, R, S01]_
22
'
in
strategy_name_list
assert
'[R, R, R, S1]_
11
'
in
strategy_name_list
assert
'[R, S0, R, S1]_1
2
'
in
strategy_name_list
assert
'[R, R, S0, S1]_
13
'
in
strategy_name_list
assert
'[R, R, R, S0]_
14
'
in
strategy_name_list
assert
'[R, S1, R, S0]_
15
'
in
strategy_name_list
assert
'[R, R, S1, S0]_
16
'
in
strategy_name_list
assert
'[R, R, R, R]_
17
'
in
strategy_name_list
assert
'[R, S0, R, R]_
18
'
in
strategy_name_list
assert
'[R, R, S0, R]_
19
'
in
strategy_name_list
assert
'[R, R, R, R]_
20
'
in
strategy_name_list
assert
'[R, S1, R, R]_
2
1'
in
strategy_name_list
assert
'[R, R, S1, R]_
22
'
in
strategy_name_list
assert
'[R, R, R, S1]_1
0
'
in
strategy_name_list
assert
'[R, R, R, S0]_
9
'
in
strategy_name_list
assert
'[R, R, R, R]_
8
'
in
strategy_name_list
assert
'[R, R, R, R]_
7
'
in
strategy_name_list
assert
'[R, R, R, S0]_6'
in
strategy_name_list
assert
'[R, R, R, S1]_
5
'
in
strategy_name_list
assert
'[R, R, R, R]_
0
'
in
strategy_name_list
assert
'[R, S01, R, R]_1'
in
strategy_name_list
assert
'[R, R, S01, R]_2'
in
strategy_name_list
assert
'[R, R, R, R]_
3
'
in
strategy_name_list
assert
'[R, R, R, S01]_
4
'
in
strategy_name_list
if
split_dim
==
1
:
assert
'[S0, R, R, S1]_0'
in
strategy_name_list
assert
'[R, R, R, S1]_1'
in
strategy_name_list
assert
'[R, R, S0, S1]_2'
in
strategy_name_list
assert
'[S1, R, R, S0]_3'
in
strategy_name_list
assert
'[R, R, R, S0]_4'
in
strategy_name_list
assert
'[R, R, S1, S0]_5'
in
strategy_name_list
assert
'[S0, R, R, R]_6'
in
strategy_name_list
assert
'[R, R, R, R]_7'
in
strategy_name_list
assert
'[R, R, S0, R]_8'
in
strategy_name_list
assert
'[S1, R, R, R]_9'
in
strategy_name_list
assert
'[R, R, R, R]_10'
in
strategy_name_list
assert
'[R, R, S1, R]_11'
in
strategy_name_list
assert
'[S0, R, R, S1]_11'
in
strategy_name_list
assert
'[R, R, R, S1]_12'
in
strategy_name_list
assert
'[R, R,
R
, S
0
]_13'
in
strategy_name_list
assert
'[
R
, R, R,
R
]_14'
in
strategy_name_list
assert
'[R, R, R,
R
]_15'
in
strategy_name_list
assert
'[R, R,
R
, S0]_16'
in
strategy_name_list
assert
'[
R
, R, R,
S1
]_17'
in
strategy_name_list
assert
'[
S01
, R, R, R]_18'
in
strategy_name_list
assert
'[R, R,
R
, R]_19'
in
strategy_name_list
assert
'[
R
, R,
S01
, R]_20'
in
strategy_name_list
assert
'[R, R,
S0
, S
1
]_13'
in
strategy_name_list
assert
'[
S1
, R, R,
S0
]_14'
in
strategy_name_list
assert
'[R, R, R,
S0
]_15'
in
strategy_name_list
assert
'[R, R,
S1
, S0]_16'
in
strategy_name_list
assert
'[
S0
, R, R,
R
]_17'
in
strategy_name_list
assert
'[
R
, R, R, R]_18'
in
strategy_name_list
assert
'[R, R,
S0
, R]_19'
in
strategy_name_list
assert
'[
S1
, R,
R
, R]_20'
in
strategy_name_list
assert
'[R, R, R, R]_21'
in
strategy_name_list
assert
'[R, R, R, S01]_22'
in
strategy_name_list
assert
'[R, R, S1, R]_22'
in
strategy_name_list
assert
'[R, R, R, S1]_10'
in
strategy_name_list
assert
'[R, R, R, S0]_9'
in
strategy_name_list
assert
'[R, R, R, R]_8'
in
strategy_name_list
assert
'[R, R, R, R]_7'
in
strategy_name_list
assert
'[R, R, R, S0]_6'
in
strategy_name_list
assert
'[R, R, R, S1]_5'
in
strategy_name_list
assert
'[S01, R, R, R]_0'
in
strategy_name_list
assert
'[R, R, R, R]_1'
in
strategy_name_list
assert
'[R, R, S01, R]_2'
in
strategy_name_list
assert
'[R, R, R, R]_3'
in
strategy_name_list
assert
'[R, R, R, S01]_4'
in
strategy_name_list
@
run_on_environment_flag
(
name
=
'AUTO_PARALLEL'
)
...
...
tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_view_handler.py
View file @
93b788b9
...
...
@@ -196,54 +196,57 @@ def check_view_handler(rank, tgt_shape, model_cls, world_size, port):
if
model_cls
.
__name__
==
'LinearViewModel'
:
if
tgt_shape
==
(
32
,
4
,
64
,
16
,
4
):
assert
'[S0, R, R, S1] -> [S0, R, R, S1, R]_0'
in
strategy_name_list
assert
'[R, S0, R, S1] -> FULLY REPLICATED_1'
in
strategy_name_list
assert
'[R, R, S0, S1] -> [R, R, S0, S1, R]_2'
in
strategy_name_list
assert
'[S1, R, R, S0] -> [S1, R, R, S0, R]_3'
in
strategy_name_list
assert
'[R, S1, R, S0] -> FULLY REPLICATED_4'
in
strategy_name_list
assert
'[R, R, S1, S0] -> [R, R, S1, S0, R]_5'
in
strategy_name_list
assert
'[S0, R, R, R] -> [S0, R, R, R, R]_6'
in
strategy_name_list
assert
'[R, S0, R, R] -> FULLY REPLICATED_7'
in
strategy_name_list
assert
'[R, R, S0, R] -> [R, R, S0, R, R]_8'
in
strategy_name_list
assert
'[S1, R, R, R] -> [S1, R, R, R, R]_9'
in
strategy_name_list
assert
'[R, S1, R, R] -> FULLY REPLICATED_10'
in
strategy_name_list
assert
'[R, R, S1, R] -> [R, R, S1, R, R]_11'
in
strategy_name_list
assert
'[R, R, R, S1] -> [R, R, R, S1, R]_12'
in
strategy_name_list
assert
'[R, R, R, S0] -> [R, R, R, S0, R]_13'
in
strategy_name_list
assert
'[R, R, R, R] -> [R, R, R, R, R]_14'
in
strategy_name_list
assert
'[R, R, R, R] -> [R, R, R, R, R]_15'
in
strategy_name_list
assert
'[R, R, R, S0] -> [R, R, R, S0, R]_16'
in
strategy_name_list
assert
'[R, R, R, S1] -> [R, R, R, S1, R]_17'
in
strategy_name_list
assert
'[S01, R, R, R] -> [S01, R, R, R, R]_18'
in
strategy_name_list
assert
'[R, S01, R, R] -> FULLY REPLICATED_19'
in
strategy_name_list
assert
'[R, R, S01, R] -> [R, R, S01, R, R]_20'
in
strategy_name_list
assert
'[R, R, R, R] -> [R, R, R, R, R]_21'
in
strategy_name_list
assert
'[R, R, R, S01] -> [R, R, R, S01, R]_22'
in
strategy_name_list
for
strategy
in
strategy_name_list
:
print
(
strategy
)
# print(strategy_name_list)
assert
'[S0, R, R, S1] -> [S0, R, R, S1, R]_11'
in
strategy_name_list
assert
'[R, S0, R, S1] -> FULLY REPLICATED_12'
in
strategy_name_list
assert
'[R, R, S0, S1] -> [R, R, S0, S1, R]_13'
in
strategy_name_list
assert
'[S1, R, R, S0] -> [S1, R, R, S0, R]_14'
in
strategy_name_list
assert
'[R, S1, R, S0] -> FULLY REPLICATED_15'
in
strategy_name_list
assert
'[R, R, S1, S0] -> [R, R, S1, S0, R]_16'
in
strategy_name_list
assert
'[S0, R, R, R] -> [S0, R, R, R, R]_17'
in
strategy_name_list
assert
'[R, S0, R, R] -> FULLY REPLICATED_18'
in
strategy_name_list
assert
'[R, R, S0, R] -> [R, R, S0, R, R]_19'
in
strategy_name_list
assert
'[S1, R, R, R] -> [S1, R, R, R, R]_20'
in
strategy_name_list
assert
'[R, S1, R, R] -> FULLY REPLICATED_21'
in
strategy_name_list
assert
'[R, R, S1, R] -> [R, R, S1, R, R]_22'
in
strategy_name_list
assert
'[R, R, R, S1] -> [R, R, R, S1, R]_10'
in
strategy_name_list
assert
'[R, R, R, S0] -> [R, R, R, S0, R]_9'
in
strategy_name_list
assert
'[R, R, R, R] -> [R, R, R, R, R]_8'
in
strategy_name_list
assert
'[R, R, R, R] -> [R, R, R, R, R]_7'
in
strategy_name_list
assert
'[R, R, R, S0] -> [R, R, R, S0, R]_6'
in
strategy_name_list
assert
'[R, R, R, S1] -> [R, R, R, S1, R]_5'
in
strategy_name_list
assert
'[S01, R, R, R] -> [S01, R, R, R, R]_0'
in
strategy_name_list
assert
'[R, S01, R, R] -> FULLY REPLICATED_1'
in
strategy_name_list
assert
'[R, R, S01, R] -> [R, R, S01, R, R]_2'
in
strategy_name_list
assert
'[R, R, R, R] -> [R, R, R, R, R]_3'
in
strategy_name_list
assert
'[R, R, R, S01] -> [R, R, R, S01, R]_4'
in
strategy_name_list
if
tgt_shape
==
(
8
,
4
,
4
,
64
,
16
,
4
):
assert
'[S0, R, R, S1] -> [S0, R, R, R, S1, R]_
0
'
in
strategy_name_list
assert
'[R, S0, R, S1] -> [R, S0, R, R, S1, R]_1'
in
strategy_name_list
assert
'[R, R, S0, S1] -> [R, R, R, S0, S1, R]_
2
'
in
strategy_name_list
assert
'[S1, R, R, S0] -> [S1, R, R, R, S0, R]_
3
'
in
strategy_name_list
assert
'[R, S1, R, S0] -> [R, S1, R, R, S0, R]_
4
'
in
strategy_name_list
assert
'[R, R, S1, S0] -> [R, R, R, S1, S0, R]_
5
'
in
strategy_name_list
assert
'[S0, R, R, R] -> [S0, R, R, R, R, R]_
6
'
in
strategy_name_list
assert
'[R, S0, R, R] -> [R, S0, R, R, R, R]_
7
'
in
strategy_name_list
assert
'[R, R, S0, R] -> [R, R, R, S0, R, R]_
8
'
in
strategy_name_list
assert
'[S1, R, R, R] -> [S1, R, R, R, R, R]_
9
'
in
strategy_name_list
assert
'[R, S1, R, R] -> [R, S1, R, R, R, R]_1
0
'
in
strategy_name_list
assert
'[R, R, S1, R] -> [R, R, R, S1, R, R]_
11
'
in
strategy_name_list
assert
'[R, R, R, S1] -> [R, R, R, R, S1, R]_1
2
'
in
strategy_name_list
assert
'[R, R, R, S0] -> [R, R, R, R, S0, R]_
13
'
in
strategy_name_list
assert
'[R, R, R, R] -> [R, R, R, R, R, R]_
14
'
in
strategy_name_list
assert
'[R, R, R, R] -> [R, R, R, R, R, R]_
15
'
in
strategy_name_list
assert
'[R, R, R, S0] -> [R, R, R, R, S0, R]_
1
6'
in
strategy_name_list
assert
'[R, R, R, S1] -> [R, R, R, R, S1, R]_
17
'
in
strategy_name_list
assert
'[S01, R, R, R] -> [S01, R, R, R, R, R]_
18
'
in
strategy_name_list
assert
'[R, S01, R, R] -> [R, S01, R, R, R, R]_1
9
'
in
strategy_name_list
assert
'[R, R, S01, R] -> [R, R, R, S01, R, R]_2
0
'
in
strategy_name_list
assert
'[R, R, R, R] -> [R, R, R, R, R, R]_
21
'
in
strategy_name_list
assert
'[R, R, R, S01] -> [R, R, R, R, S01, R]_
22
'
in
strategy_name_list
assert
'[S0, R, R, S1] -> [S0, R, R, R, S1, R]_
11
'
in
strategy_name_list
assert
'[R, S0, R, S1] -> [R, S0, R, R, S1, R]_1
2
'
in
strategy_name_list
assert
'[R, R, S0, S1] -> [R, R, R, S0, S1, R]_
13
'
in
strategy_name_list
assert
'[S1, R, R, S0] -> [S1, R, R, R, S0, R]_
14
'
in
strategy_name_list
assert
'[R, S1, R, S0] -> [R, S1, R, R, S0, R]_
15
'
in
strategy_name_list
assert
'[R, R, S1, S0] -> [R, R, R, S1, S0, R]_
16
'
in
strategy_name_list
assert
'[S0, R, R, R] -> [S0, R, R, R, R, R]_
17
'
in
strategy_name_list
assert
'[R, S0, R, R] -> [R, S0, R, R, R, R]_
18
'
in
strategy_name_list
assert
'[R, R, S0, R] -> [R, R, R, S0, R, R]_
19
'
in
strategy_name_list
assert
'[S1, R, R, R] -> [S1, R, R, R, R, R]_
20
'
in
strategy_name_list
assert
'[R, S1, R, R] -> [R, S1, R, R, R, R]_
2
1'
in
strategy_name_list
assert
'[R, R, S1, R] -> [R, R, R, S1, R, R]_
22
'
in
strategy_name_list
assert
'[R, R, R, S1] -> [R, R, R, R, S1, R]_1
0
'
in
strategy_name_list
assert
'[R, R, R, S0] -> [R, R, R, R, S0, R]_
9
'
in
strategy_name_list
assert
'[R, R, R, R] -> [R, R, R, R, R, R]_
8
'
in
strategy_name_list
assert
'[R, R, R, R] -> [R, R, R, R, R, R]_
7
'
in
strategy_name_list
assert
'[R, R, R, S0] -> [R, R, R, R, S0, R]_6'
in
strategy_name_list
assert
'[R, R, R, S1] -> [R, R, R, R, S1, R]_
5
'
in
strategy_name_list
assert
'[S01, R, R, R] -> [S01, R, R, R, R, R]_
0
'
in
strategy_name_list
assert
'[R, S01, R, R] -> [R, S01, R, R, R, R]_1'
in
strategy_name_list
assert
'[R, R, S01, R] -> [R, R, R, S01, R, R]_2'
in
strategy_name_list
assert
'[R, R, R, R] -> [R, R, R, R, R, R]_
3
'
in
strategy_name_list
assert
'[R, R, R, S01] -> [R, R, R, R, S01, R]_
4
'
in
strategy_name_list
@
run_on_environment_flag
(
name
=
'AUTO_PARALLEL'
)
...
...
tests/test_auto_parallel/test_tensor_shard/test_node_handler/utils.py
View file @
93b788b9
...
...
@@ -6,7 +6,8 @@ from torch.fx import GraphModule
from
colossalai.auto_parallel.passes.runtime_apply_pass
import
runtime_apply_pass
from
colossalai.auto_parallel.passes.runtime_preparation_pass
import
runtime_preparation_pass
from
colossalai.auto_parallel.tensor_shard.solver
import
SolverOptions
,
StrategiesConstructor
from
colossalai.auto_parallel.tensor_shard.options
import
SolverOptions
from
colossalai.auto_parallel.tensor_shard.solver
import
StrategiesConstructor
from
colossalai.auto_parallel.tensor_shard.solver.cost_graph
import
CostGraph
from
colossalai.auto_parallel.tensor_shard.solver.graph_analysis
import
GraphAnalyser
from
colossalai.auto_parallel.tensor_shard.solver.solver
import
Solver
...
...
tests/test_auto_parallel/test_tensor_shard/test_param_resharding_cost.py
View file @
93b788b9
import
torch
from
colossalai.auto_parallel.tensor_shard.options
import
SolverOptions
from
colossalai.auto_parallel.tensor_shard.sharding_strategy
import
OperationDataType
from
colossalai.auto_parallel.tensor_shard.solver
import
(
CostGraph
,
GraphAnalyser
,
Solver
,
SolverOptions
,
StrategiesConstructor
,
)
from
colossalai.auto_parallel.tensor_shard.solver
import
CostGraph
,
GraphAnalyser
,
Solver
,
StrategiesConstructor
from
colossalai.device.device_mesh
import
DeviceMesh
from
colossalai.fx
import
ColoGraphModule
,
ColoTracer
from
colossalai.testing.pytest_wrapper
import
run_on_environment_flag
...
...
tests/test_auto_parallel/test_tensor_shard/test_solver_with_resnet_v2.py
View file @
93b788b9
...
...
@@ -3,13 +3,8 @@ from torch.fx import GraphModule
from
torchvision.models
import
resnet50
from
colossalai.auto_parallel.tensor_shard.constants
import
BATCHNORM_MODULE_OP
from
colossalai.auto_parallel.tensor_shard.solver
import
(
CostGraph
,
GraphAnalyser
,
Solver
,
SolverOptions
,
StrategiesConstructor
,
)
from
colossalai.auto_parallel.tensor_shard.options
import
SolverOptions
from
colossalai.auto_parallel.tensor_shard.solver
import
CostGraph
,
GraphAnalyser
,
Solver
,
StrategiesConstructor
from
colossalai.device.device_mesh
import
DeviceMesh
from
colossalai.fx.tracer.tracer
import
ColoTracer
from
colossalai.tensor.shape_consistency
import
ShapeConsistencyManager
...
...
version.txt
View file @
93b788b9
0.2.
3
0.2.
5
Prev
1
…
4
5
6
7
8
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment