Unverified Commit 6c331a5a authored by Frank Lee's avatar Frank Lee Committed by GitHub
Browse files

[autoparallel] refactored the autoparallel module for organization (#1706)

* [autoparallel] refactored the autoparallel module for organization

* polish code
parent 91cd34e6
from .batch_norm_handler import BatchNormModuleHandler
from .conv_handler import ConvFunctionHandler, ConvModuleHandler
from .dot_handler import LinearFunctionHandler, LinearModuleHandler
from .layer_norm_handler import LayerNormModuleHandler
from .batch_norm_handler import BatchNormModuleHandler
from .conv_handler import ConvModuleHandler, ConvFunctionHandler
from .where_handler import WhereHandler
from .unary_elementwise_handler import UnaryElementwiseHandler
from .reshape_handler import ReshapeHandler
from .placeholder_handler import PlacehodlerHandler
from .output_handler import OuputHandler
from .normal_pooling_handler import NormPoolingHandler
from .output_handler import OuputHandler
from .placeholder_handler import PlacehodlerHandler
from .registry import operator_registry
from .reshape_handler import ReshapeHandler
from .unary_elementwise_handler import UnaryElementwiseHandler
from .where_handler import WhereHandler
__all__ = [
'LinearFunctionHandler', 'LinearModuleHandler', 'LayerNormModuleHandler', 'BatchNormModuleHandler',
'ConvModuleHandler', 'ConvFunctionHandler', 'UnaryElementwiseHandler', 'ReshapeHandler', 'PlacehodlerHandler',
'OuputHandler', 'WhereHandler', 'NormPoolingHandler'
'OuputHandler', 'WhereHandler', 'NormPoolingHandler', 'operator_registry'
]
from typing import Dict, List
import torch
import torch.nn.functional as F
from .node_handler import ModuleHandler, NodeHandler
from ..sharding_strategy import ShardingStrategy, OperationDataType, OperationData
from ..strategy import BatchNormStrategyGenerator, StrategyGenerator
from typing import List, Dict
from ..sharding_strategy import OperationData, OperationDataType
from .node_handler import ModuleHandler
from .registry import operator_registry
from .strategy import BatchNormStrategyGenerator, StrategyGenerator
__all__ = ['BatchNormModuleHandler']
......
from typing import Dict, List
import torch
import torch.nn.functional as F
from ..sharding_strategy import (OperationData, OperationDataType, ShardingStrategy)
from .node_handler import ModuleHandler, NodeHandler
from ..sharding_strategy import ShardingStrategy, OperationDataType, OperationData
from ..strategy import ConvStrategyGenerator, StrategyGenerator
from typing import List, Dict
from .registry import operator_registry
from .strategy import ConvStrategyGenerator, StrategyGenerator
__all__ = ['LinearModuleHandler', 'LinearFunctionHandler']
__all__ = ['ConvModuleHandler', 'ConvFunctionHandler']
@operator_registry.register(torch.nn.Conv1d)
......
from copy import deepcopy
from typing import Dict, List, Union
import torch
import torch.nn.functional as F
from colossalai.auto_parallel.tensor_shard.utils import (switch_partition_dim, update_partition_dim)
from colossalai.tensor.sharding_spec import ShardingException
from ..sharding_strategy import (OperationData, OperationDataType, ShardingStrategy)
from .node_handler import ModuleHandler, NodeHandler
from ..sharding_strategy import ShardingStrategy, OperationDataType, OperationData
from ..strategy import LinearProjectionStrategyGenerator, StrategyGenerator, BatchedMatMulStrategyGenerator
from typing import List, Dict, Union
from .registry import operator_registry
from copy import deepcopy
from .utils import switch_partition_dim, update_partition_dim
from .strategy import (BatchedMatMulStrategyGenerator, LinearProjectionStrategyGenerator, StrategyGenerator)
__all__ = ['LinearModuleHandler', 'LinearFunctionHandler', 'BMMFunctionHandler']
......
import operator
from typing import Dict, List
import torch
from ..sharding_strategy import OperationData, OperationDataType
from .node_handler import NodeHandler
from ..sharding_strategy import ShardingStrategy, OperationDataType, OperationData, StrategiesVector
from ..strategy import TensorStrategyGenerator, TensorTupleStrategyGenerator, StrategyGenerator
from typing import List, Dict
from .registry import operator_registry
import operator
from .strategy import (StrategyGenerator, TensorStrategyGenerator, TensorTupleStrategyGenerator)
__all__ = ['GetItemHandler']
......
from typing import Dict, List
import torch
from ..sharding_strategy import OperationData, OperationDataType
from .node_handler import ModuleHandler
from ..sharding_strategy import ShardingStrategy, OperationDataType, OperationData
from ..strategy import LayerNormGenerator, StrategyGenerator
from typing import List, Dict
from .registry import operator_registry
from .strategy import LayerNormGenerator, StrategyGenerator
__all__ = ['LayerNormModuleHandler']
......
from abc import ABC, abstractmethod
from typing import Dict, List, Union
from torch.fx.node import Node
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (OperationData, ShardingStrategy, StrategiesVector,
TrainCycleItem)
from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from typing import Dict, List, Union
from ..sharding_strategy import ShardingStrategy, StrategiesVector, OperationData, TrainCycleItem
from ..strategy import StrategyGenerator
from .._utils import generate_resharding_costs
from .strategy import StrategyGenerator
class NodeHandler(ABC):
......
from typing import Dict, List
import torch
import torch.nn.functional as F
from .node_handler import ModuleHandler, NodeHandler
from ..sharding_strategy import ShardingStrategy, OperationDataType, OperationData
from ..strategy import NormalPoolStrategyGenerator, StrategyGenerator
from typing import List, Dict
from ..sharding_strategy import OperationData, OperationDataType
from .node_handler import ModuleHandler
from .registry import operator_registry
from .strategy import NormalPoolStrategyGenerator, StrategyGenerator
__all__ = ['NormPoolingHandler']
......
from typing import Dict, List
import torch
from ..sharding_strategy import OperationData, OperationDataType
from .node_handler import NodeHandler
from ..sharding_strategy import ShardingStrategy, OperationDataType, OperationData, StrategiesVector
from colossalai.auto_parallel.solver.strategy import StrategyGenerator
from colossalai.auto_parallel.solver.strategy.output_generator import OutputGenerator
from typing import List, Dict
from .registry import operator_registry
from .strategy import OutputGenerator, StrategyGenerator
__all__ = ['OuputHandler']
......
import torch
from typing import Dict, List
from ..sharding_strategy import OperationData, OperationDataType
from .node_handler import NodeHandler
from ..sharding_strategy import ShardingStrategy, OperationDataType, OperationData
from colossalai.auto_parallel.solver.strategy import StrategyGenerator
from colossalai.auto_parallel.solver.strategy.placeholder_generator import PlaceholderGenerator
from typing import List, Dict
from .registry import operator_registry
from .strategy import PlaceholderGenerator, StrategyGenerator
__all__ = ['PlacehodlerHandler']
......
from typing import Dict, List
import torch
from ..sharding_strategy import OperationData, OperationDataType
from .node_handler import NodeHandler
from ..sharding_strategy import ShardingStrategy, OperationDataType, OperationData, StrategiesVector
from ..strategy import ReshapeGenerator, StrategyGenerator
from typing import List, Dict
from .registry import operator_registry
import operator
from .strategy import ReshapeGenerator, StrategyGenerator
__all__ = ['ReshapeHandler']
......
from .strategy_generator import StrategyGenerator
from .matmul_strategy_generator import DotProductStrategyGenerator, MatVecStrategyGenerator, LinearProjectionStrategyGenerator, BatchedMatMulStrategyGenerator
from .conv_strategy_generator import ConvStrategyGenerator
from .batch_norm_generator import BatchNormStrategyGenerator
from .unary_elementwise_generator import UnaryElementwiseGenerator
from .getitem_generator import GetItemStrategyGenerator, TensorStrategyGenerator, TensorTupleStrategyGenerator
from .conv_strategy_generator import ConvStrategyGenerator
from .getitem_generator import (GetItemStrategyGenerator, TensorStrategyGenerator, TensorTupleStrategyGenerator)
from .layer_norm_generator import LayerNormGenerator
from .where_generator import WhereGenerator
from .reshape_generator import ReshapeGenerator
from .matmul_strategy_generator import (BatchedMatMulStrategyGenerator, DotProductStrategyGenerator,
LinearProjectionStrategyGenerator, MatVecStrategyGenerator)
from .normal_pooling_generator import NormalPoolStrategyGenerator
from .placeholder_generator import PlaceholderGenerator
from .output_generator import OutputGenerator
from .placeholder_generator import PlaceholderGenerator
from .reshape_generator import ReshapeGenerator
from .strategy_generator import StrategyGenerator
from .unary_elementwise_generator import UnaryElementwiseGenerator
from .where_generator import WhereGenerator
__all__ = [
'StrategyGenerator', 'DotProductStrategyGenerator', 'MatVecStrategyGenerator', 'LinearProjectionStrategyGenerator',
......
import copy
import operator
from functools import reduce
from ..sharding_strategy import ShardingStrategy, TrainCycleItem, MemoryCost
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (MemoryCost, ShardingStrategy, TrainCycleItem)
from colossalai.tensor.shape_consistency import CollectiveCommPattern
from .strategy_generator import StrategyGenerator
from typing import List
from .._utils import exception_handler
import copy
__all__ = ['BatchNormStrategyGenerator']
......
import copy
import operator
import warnings
from functools import reduce
from ..sharding_strategy import ShardingStrategy, TrainCycleItem, MemoryCost
from typing import List
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (MemoryCost, ShardingStrategy, TrainCycleItem)
from colossalai.auto_parallel.tensor_shard.utils import exception_handler
from colossalai.tensor.shape_consistency import CollectiveCommPattern
from .strategy_generator import StrategyGenerator
from typing import List
from .._utils import exception_handler
import warnings
import copy
class ConvStrategyGenerator(StrategyGenerator):
......
import operator
from functools import reduce
from ..sharding_strategy import ShardingStrategy, TrainCycleItem, MemoryCost
import copy
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (MemoryCost, ShardingStrategy, TrainCycleItem)
from colossalai.tensor.shape_consistency import CollectiveCommPattern
from .strategy_generator import FollowingStrategyGenerator
from typing import List
from .._utils import exception_handler
import copy
__all__ = ['GetItemStrategyGenerator', 'TensorStrategyGenerator', 'TensorTupleStrategyGenerator']
......
import copy
import operator
from functools import reduce
from ..sharding_strategy import ShardingStrategy, TrainCycleItem, MemoryCost
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (MemoryCost, ShardingStrategy, TrainCycleItem)
from colossalai.auto_parallel.tensor_shard.utils import (enumerate_all_possible_1d_sharding,
enumerate_all_possible_2d_sharding)
from colossalai.tensor.shape_consistency import CollectiveCommPattern
from .strategy_generator import StrategyGenerator
from typing import List
from .._utils import exception_handler, enumerate_all_possible_1d_sharding, enumerate_all_possible_2d_sharding
import copy
__all__ = ['LayerNormGenerator']
......
from audioop import bias
import operator
from functools import reduce
from ..sharding_strategy import ShardingStrategy, TrainCycleItem, MemoryCost
from typing import List
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (MemoryCost, ShardingStrategy, TrainCycleItem)
from colossalai.tensor.shape_consistency import CollectiveCommPattern
from .strategy_generator import StrategyGenerator
from typing import List
class MatMulStrategyGenerator(StrategyGenerator):
......
import copy
import operator
from functools import reduce
from ..sharding_strategy import ShardingStrategy, TrainCycleItem, MemoryCost
from colossalai.tensor.shape_consistency import CollectiveCommPattern
from .strategy_generator import StrategyGenerator
from typing import List
from .._utils import exception_handler, enumerate_all_possible_1d_sharding, enumerate_all_possible_2d_sharding
import copy
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (MemoryCost, ShardingStrategy, TrainCycleItem)
from colossalai.auto_parallel.tensor_shard.utils import (enumerate_all_possible_1d_sharding,
enumerate_all_possible_2d_sharding)
from .strategy_generator import StrategyGenerator
class NormalPoolStrategyGenerator(StrategyGenerator):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment