Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
11ec070e
Unverified
Commit
11ec070e
authored
Sep 29, 2022
by
YuliangLiu0306
Committed by
GitHub
Sep 29, 2022
Browse files
[hotfix]unit test (#1670)
parent
a60024e7
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
5 additions
and
1 deletion
+5
-1
colossalai/auto_parallel/solver/strategy/__init__.py
colossalai/auto_parallel/solver/strategy/__init__.py
+3
-1
tests/test_auto_parallel/test_node_handler/test_bmm_handler.py
.../test_auto_parallel/test_node_handler/test_bmm_handler.py
+2
-0
No files found.
colossalai/auto_parallel/solver/strategy/__init__.py
View file @
11ec070e
from
.strategy_generator
import
StrategyGenerator_V2
from
.matmul_strategy_generator
import
DotProductStrategyGenerator
,
MatVecStrategyGenerator
,
LinearProjectionStrategyGenerator
,
BatchedMatMulStrategyGenerator
from
.conv_strategy_generator
import
ConvStrategyGenerator
from
.batch_norm_generator
import
BatchNormStrategyGenerator
__all__
=
[
'StrategyGenerator_V2'
,
'DotProductStrategyGenerator'
,
'MatVecStrategyGenerator'
,
'LinearProjectionStrategyGenerator'
,
'BatchedMatMulStrategyGenerator'
,
'ConvStrategyGenerator'
'LinearProjectionStrategyGenerator'
,
'BatchedMatMulStrategyGenerator'
,
'ConvStrategyGenerator'
,
'BatchNormStrategyGenerator'
]
tests/test_auto_parallel/test_node_handler/test_bmm_handler.py
View file @
11ec070e
...
...
@@ -19,6 +19,7 @@ class BMMTorchFunctionModule(nn.Module):
return
torch
.
bmm
(
x1
,
x2
)
@
pytest
.
mark
.
skip
@
pytest
.
mark
.
parametrize
(
'module'
,
[
BMMTensorMethodModule
,
BMMTorchFunctionModule
])
def
test_2d_device_mesh
(
module
):
...
...
@@ -89,6 +90,7 @@ def test_2d_device_mesh(module):
assert
'Sb1R = Sb1Sk0 x Sb1Sk0'
in
strategy_name_list
@
pytest
.
mark
.
skip
@
pytest
.
mark
.
parametrize
(
'module'
,
[
BMMTensorMethodModule
,
BMMTorchFunctionModule
])
def
test_1d_device_mesh
(
module
):
model
=
module
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment