Commit 93b788b9 authored by binmakeswell's avatar binmakeswell
Browse files

Merge branch 'main' into fix/format

parents 2fd528b9 1dc003c1
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Dict, List, Optional

from colossalai.device.device_mesh import DeviceMesh
__all__ = ['IntermediateStrategy', 'StrategyGenerator']
@dataclass
class IntermediateStrategy:
    """
    IntermediateStrategy contains the subset of meta information for ShardingStrategy. It
    stores the essential information regarding the tensor sharding and leaves the other
    meta information to OperatorHandler.

    Args:
        name (str): name of the sharding strategy.
        dim_partition_dict (Dict[str, Dict[int, List[int]]]): stores the tensor to dim partition dict mapping.
        all_reduce_axis (Optional[List[int]]): stores the dimensions which require an all-reduce operation.
            Defaults to None.
    """
    name: str
    dim_partition_dict: Dict[str, Dict[int, List[int]]]
    # None is the default, so the annotation is Optional rather than a bare List[int].
    all_reduce_axis: Optional[List[int]] = None
class StrategyGenerator(ABC):
    """
    Abstract base class for objects that generate one group of sharding strategies.

    Concrete subclasses have to implement both ``generate`` and ``validate``.
    """

    def __init__(self, device_mesh: DeviceMesh):
        # Keep the device mesh on the instance.
        self.device_mesh = device_mesh

    @abstractmethod
    def generate(self) -> List[IntermediateStrategy]:
        """
        Build and return the list of IntermediateStrategy objects for this generator.
        """
        ...

    @abstractmethod
    def validate(self, *args, **kwargs) -> bool:
        """
        Check whether the operands have the desired shape.
        A return value of True means this generator can be used for the current operation.
        """
        ...
from dataclasses import dataclass
__all__ = ['SolverOptions']
@dataclass
class SolverOptions:
    """
    Dataclass holding the preferences that configure the parallel execution plan search.
    """
    # Disabled by default.
    # NOTE(review): presumably selects a faster search mode - confirm against the solver.
    fast: bool = False
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment