Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
d39e11df
Unverified
Commit
d39e11df
authored
Aug 24, 2022
by
Frank Lee
Committed by
GitHub
Aug 24, 2022
Browse files
[autoparallel] added namespace constraints (#1490)
parent
a6c87491
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
10 additions
and
0 deletions
+10
-0
colossalai/auto_parallel/solver/conv_handler.py
colossalai/auto_parallel/solver/conv_handler.py
+2
-0
colossalai/auto_parallel/solver/dot_handler.py
colossalai/auto_parallel/solver/dot_handler.py
+2
-0
colossalai/auto_parallel/solver/operator_handler.py
colossalai/auto_parallel/solver/operator_handler.py
+6
-0
No files found.
colossalai/auto_parallel/solver/conv_handler.py
View file @
d39e11df
...
@@ -4,6 +4,8 @@ import torch
...
@@ -4,6 +4,8 @@ import torch
from
colossalai.auto_parallel.solver.sharding_strategy
import
ShardingStrategy
,
StrategiesVector
from
colossalai.auto_parallel.solver.sharding_strategy
import
ShardingStrategy
,
StrategiesVector
from
.operator_handler
import
OperatorHandler
from
.operator_handler
import
OperatorHandler
__all__
=
[
'ConvHandler'
]
class
ConvHandler
(
OperatorHandler
):
class
ConvHandler
(
OperatorHandler
):
"""
"""
...
...
colossalai/auto_parallel/solver/dot_handler.py
View file @
d39e11df
...
@@ -4,6 +4,8 @@ from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy,
...
@@ -4,6 +4,8 @@ from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy,
from
.operator_handler
import
OperatorHandler
from
.operator_handler
import
OperatorHandler
from
functools
import
reduce
from
functools
import
reduce
__all__
=
[
'DotHandler'
]
class
DotHandler
(
OperatorHandler
):
class
DotHandler
(
OperatorHandler
):
"""
"""
...
...
colossalai/auto_parallel/solver/operator_handler.py
View file @
d39e11df
from
webbrowser
import
Opera
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
from
abc
import
ABC
,
abstractmethod
from
abc
import
ABC
,
abstractmethod
...
@@ -9,6 +10,8 @@ from colossalai.tensor.sharding_spec import ShardingSpec
...
@@ -9,6 +10,8 @@ from colossalai.tensor.sharding_spec import ShardingSpec
from
.sharding_strategy
import
StrategiesVector
from
.sharding_strategy
import
StrategiesVector
__all__
=
[
'OperatorHandler'
]
class
OperatorHandler
(
ABC
):
class
OperatorHandler
(
ABC
):
'''
'''
...
@@ -48,6 +51,9 @@ class OperatorHandler(ABC):
...
@@ -48,6 +51,9 @@ class OperatorHandler(ABC):
@
abstractmethod
@
abstractmethod
def
register_strategy
(
self
)
->
StrategiesVector
:
def
register_strategy
(
self
)
->
StrategiesVector
:
"""
Register
"""
pass
pass
def
_generate_sharding_spec
(
self
,
tensor
:
torch
.
Tensor
,
dim_partition_dict
:
Dict
[
int
,
int
])
->
ShardingSpec
:
def
_generate_sharding_spec
(
self
,
tensor
:
torch
.
Tensor
,
dim_partition_dict
:
Dict
[
int
,
int
])
->
ShardingSpec
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment