"git@developer.sourcefind.cn:OpenDAS/colossalai.git" did not exist on "8993c8a8170ac116a551840e3d442af78bedc53e"
Unverified Commit d39e11df authored by Frank Lee's avatar Frank Lee Committed by GitHub
Browse files

[autoparallel] added namespace constraints (#1490)

parent a6c87491
...@@ -4,6 +4,8 @@ import torch ...@@ -4,6 +4,8 @@ import torch
from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector
from .operator_handler import OperatorHandler from .operator_handler import OperatorHandler
__all__ = ['ConvHandler']
class ConvHandler(OperatorHandler): class ConvHandler(OperatorHandler):
""" """
......
...@@ -4,6 +4,8 @@ from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, ...@@ -4,6 +4,8 @@ from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy,
from .operator_handler import OperatorHandler from .operator_handler import OperatorHandler
from functools import reduce from functools import reduce
__all__ = ['DotHandler']
class DotHandler(OperatorHandler): class DotHandler(OperatorHandler):
""" """
......
from webbrowser import Opera
import torch import torch
import torch.nn as nn import torch.nn as nn
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
...@@ -9,6 +10,8 @@ from colossalai.tensor.sharding_spec import ShardingSpec ...@@ -9,6 +10,8 @@ from colossalai.tensor.sharding_spec import ShardingSpec
from .sharding_strategy import StrategiesVector from .sharding_strategy import StrategiesVector
__all__ = ['OperatorHandler']
class OperatorHandler(ABC): class OperatorHandler(ABC):
''' '''
...@@ -48,6 +51,9 @@ class OperatorHandler(ABC): ...@@ -48,6 +51,9 @@ class OperatorHandler(ABC):
@abstractmethod @abstractmethod
def register_strategy(self) -> StrategiesVector: def register_strategy(self) -> StrategiesVector:
"""
Register
"""
pass pass
def _generate_sharding_spec(self, tensor: torch.Tensor, dim_partition_dict: Dict[int, int]) -> ShardingSpec: def _generate_sharding_spec(self, tensor: torch.Tensor, dim_partition_dict: Dict[int, int]) -> ShardingSpec:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment