Unverified Commit 079bf3cb authored by Hongxin Liu's avatar Hongxin Liu Committed by GitHub
Browse files

[misc] update pre-commit and run all files (#4752)

* [misc] update pre-commit

* [misc] run pre-commit

* [misc] remove useless configuration files

* [misc] ignore cuda for clang-format
parent 3c6b831c
...@@ -3,7 +3,6 @@ from typing import Dict, List, Tuple ...@@ -3,7 +3,6 @@ from typing import Dict, List, Tuple
import torch import torch
import torch.distributed as dist import torch.distributed as dist
import torch.nn as nn import torch.nn as nn
from torch.fx import GraphModule
from torch.fx.graph import Graph from torch.fx.graph import Graph
from colossalai._analyzer.fx.codegen import ActivationCheckpointCodeGen from colossalai._analyzer.fx.codegen import ActivationCheckpointCodeGen
...@@ -14,27 +13,32 @@ from colossalai.auto_parallel.passes.runtime_apply_pass import runtime_apply_pas ...@@ -14,27 +13,32 @@ from colossalai.auto_parallel.passes.runtime_apply_pass import runtime_apply_pas
from colossalai.auto_parallel.passes.runtime_preparation_pass import runtime_preparation_pass from colossalai.auto_parallel.passes.runtime_preparation_pass import runtime_preparation_pass
from colossalai.auto_parallel.tensor_shard.options import DataloaderOption, ShardOption, SolverOptions, SolverPerference from colossalai.auto_parallel.tensor_shard.options import DataloaderOption, ShardOption, SolverOptions, SolverPerference
from colossalai.auto_parallel.tensor_shard.sharding_strategy import CommAction from colossalai.auto_parallel.tensor_shard.sharding_strategy import CommAction
from colossalai.auto_parallel.tensor_shard.solver import CostGraph, GraphAnalyser, Solver, StrategiesConstructor from colossalai.auto_parallel.tensor_shard.solver import CostGraph, Solver, StrategiesConstructor
from colossalai.device.alpha_beta_profiler import AlphaBetaProfiler from colossalai.device.alpha_beta_profiler import AlphaBetaProfiler
from colossalai.device.device_mesh import DeviceMesh from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.sharding_spec import ShardingSpec from colossalai.tensor.sharding_spec import ShardingSpec
class ModuleWrapper(nn.Module): class ModuleWrapper(nn.Module):
''' """
This class is used to wrap the original module, and add the sharding_spec_dict, origin_spec_dict, comm_actions_dict This class is used to wrap the original module, and add the sharding_spec_dict, origin_spec_dict, comm_actions_dict
into the forward function. into the forward function.
''' """
def __init__(self, module: ColoGraphModule, sharding_spec_dict: Dict[int, List[ShardingSpec]], def __init__(
origin_spec_dict: Dict[int, ShardingSpec], comm_actions_dict: Dict[int, Dict[str, CommAction]]): self,
''' module: ColoGraphModule,
sharding_spec_dict: Dict[int, List[ShardingSpec]],
origin_spec_dict: Dict[int, ShardingSpec],
comm_actions_dict: Dict[int, Dict[str, CommAction]],
):
"""
Args: Args:
module: the original module module: the original module
sharding_spec_dict: The sharding_spec_dict is used to record the target sharding specs of each tensor required in user node. sharding_spec_dict: The sharding_spec_dict is used to record the target sharding specs of each tensor required in user node.
origin_spec_dict: The origin_spec_dict is used to record the original sharding spec of each tensor. origin_spec_dict: The origin_spec_dict is used to record the original sharding spec of each tensor.
comm_actions_dict: The comm_actions_dict is used to record the communication actions of each tensor. comm_actions_dict: The comm_actions_dict is used to record the communication actions of each tensor.
''' """
super(ModuleWrapper, self).__init__() super(ModuleWrapper, self).__init__()
self.module = module self.module = module
self.sharding_spec_dict = sharding_spec_dict self.sharding_spec_dict = sharding_spec_dict
...@@ -42,67 +46,68 @@ class ModuleWrapper(nn.Module): ...@@ -42,67 +46,68 @@ class ModuleWrapper(nn.Module):
self.comm_actions_dict = comm_actions_dict self.comm_actions_dict = comm_actions_dict
def forward(self, *args, **kwargs): def forward(self, *args, **kwargs):
return self.module(*args, return self.module(
sharding_spec_convert_dict=self.sharding_spec_dict, *args,
origin_node_sharding_spec_dict=self.origin_spec_dict, sharding_spec_convert_dict=self.sharding_spec_dict,
comm_actions_dict=self.comm_actions_dict, origin_node_sharding_spec_dict=self.origin_spec_dict,
**kwargs) comm_actions_dict=self.comm_actions_dict,
**kwargs,
)
def extract_meta_args_from_dataloader(data_loader: torch.utils.data.DataLoader, data_process_func: callable): def extract_meta_args_from_dataloader(data_loader: torch.utils.data.DataLoader, data_process_func: callable):
''' """
This method is used to extract the meta_args from the dataloader under the instruction of the data_process_func. This method is used to extract the meta_args from the dataloader under the instruction of the data_process_func.
''' """
# TODO: implement this function # TODO: implement this function
pass
def extract_alpha_beta_for_device_mesh(alpha_beta_dict: Dict[Tuple[int], Tuple[float]], logical_mesh_shape: Tuple[int]): def extract_alpha_beta_for_device_mesh(alpha_beta_dict: Dict[Tuple[int], Tuple[float]], logical_mesh_shape: Tuple[int]):
''' """
This method is used to extract the mesh_alpha and mesh_beta for the given logical_mesh_shape This method is used to extract the mesh_alpha and mesh_beta for the given logical_mesh_shape
from the alpha_beta_dict. These two values will be used to estimate the communication cost. from the alpha_beta_dict. These two values will be used to estimate the communication cost.
''' """
# TODO: implement this function # TODO: implement this function
pass
def build_strategy_constructor(graph: Graph, device_mesh: DeviceMesh, solver_preference: str, dataloader_option: str, def build_strategy_constructor(
shard_option: str): graph: Graph, device_mesh: DeviceMesh, solver_preference: str, dataloader_option: str, shard_option: str
''' ):
"""
This method is used to build the strategy_constructor for the given graph. This method is used to build the strategy_constructor for the given graph.
After this method, each node in the graph will have a strategies_vector which After this method, each node in the graph will have a strategies_vector which
is constructed by the related node handler. is constructed by the related node handler.
''' """
if solver_preference == 'standard': if solver_preference == "standard":
solver_preference = SolverPerference.STANDARD solver_preference = SolverPerference.STANDARD
elif solver_preference == 'tp': elif solver_preference == "tp":
solver_preference = SolverPerference.TP solver_preference = SolverPerference.TP
elif solver_preference == 'dp': elif solver_preference == "dp":
solver_preference = SolverPerference.DP solver_preference = SolverPerference.DP
else: else:
raise ValueError(f'Invalid solver_preference: {solver_preference}') raise ValueError(f"Invalid solver_preference: {solver_preference}")
if dataloader_option == 'replicated': if dataloader_option == "replicated":
dataloader_option = DataloaderOption.REPLICATED dataloader_option = DataloaderOption.REPLICATED
elif dataloader_option == 'distributed': elif dataloader_option == "distributed":
dataloader_option = DataloaderOption.DISTRIBUTED dataloader_option = DataloaderOption.DISTRIBUTED
else: else:
raise ValueError(f'Invalid dataloader_option: {dataloader_option}') raise ValueError(f"Invalid dataloader_option: {dataloader_option}")
if shard_option == 'standard': if shard_option == "standard":
shard_option = ShardOption.STANDARD shard_option = ShardOption.STANDARD
elif shard_option == 'shard': elif shard_option == "shard":
shard_option = ShardOption.SHARD shard_option = ShardOption.SHARD
elif shard_option == 'shard_last_axis': elif shard_option == "shard_last_axis":
shard_option = ShardOption.SHARD_LAST_AXIS shard_option = ShardOption.SHARD_LAST_AXIS
elif shard_option == 'full_shard': elif shard_option == "full_shard":
shard_option = ShardOption.FULL_SHARD shard_option = ShardOption.FULL_SHARD
else: else:
raise ValueError(f'Invalid shard_option: {shard_option}') raise ValueError(f"Invalid shard_option: {shard_option}")
solver_options = SolverOptions(solver_perference=solver_preference, solver_options = SolverOptions(
dataloader_option=dataloader_option, solver_perference=solver_preference, dataloader_option=dataloader_option, shard_option=shard_option
shard_option=shard_option) )
strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options) strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
strategies_constructor.build_strategies_and_cost() strategies_constructor.build_strategies_and_cost()
...@@ -110,10 +115,10 @@ def build_strategy_constructor(graph: Graph, device_mesh: DeviceMesh, solver_pre ...@@ -110,10 +115,10 @@ def build_strategy_constructor(graph: Graph, device_mesh: DeviceMesh, solver_pre
def solve_solution(gm: ColoGraphModule, strategy_constructor: StrategiesConstructor, memory_budget: float = -1.0): def solve_solution(gm: ColoGraphModule, strategy_constructor: StrategiesConstructor, memory_budget: float = -1.0):
''' """
This method is used to solve the best solution for the given graph. This method is used to solve the best solution for the given graph.
The solution is a list of integers, each integer represents the best strategy index of the corresponding node. The solution is a list of integers, each integer represents the best strategy index of the corresponding node.
''' """
# temporarily we use all nodes as liveness list, we count the backward memory cost together with # temporarily we use all nodes as liveness list, we count the backward memory cost together with
# forward memory cost into the node memory cost, and no activation checkpoint is used in this phase. # forward memory cost into the node memory cost, and no activation checkpoint is used in this phase.
# graph_analyser = GraphAnalyser(gm) # graph_analyser = GraphAnalyser(gm)
...@@ -127,23 +132,23 @@ def solve_solution(gm: ColoGraphModule, strategy_constructor: StrategiesConstruc ...@@ -127,23 +132,23 @@ def solve_solution(gm: ColoGraphModule, strategy_constructor: StrategiesConstruc
return solution return solution
def transform_to_sharded_model(gm: ColoGraphModule, def transform_to_sharded_model(
meta_args: Dict, gm: ColoGraphModule,
solution: List[int], meta_args: Dict,
device_mesh: DeviceMesh, solution: List[int],
strategies_constructor: StrategiesConstructor, device_mesh: DeviceMesh,
overlap: bool = False): strategies_constructor: StrategiesConstructor,
''' overlap: bool = False,
):
"""
This method is used to transform the original graph to the sharded graph. This method is used to transform the original graph to the sharded graph.
The model parameters will be sharded according to the solution and the grad hooks The model parameters will be sharded according to the solution and the grad hooks
will be added to the sharded graph using the runtime_preparation_pass. will be added to the sharded graph using the runtime_preparation_pass.
The communication node will be added into the graph using the runtime_apply_pass. The communication node will be added into the graph using the runtime_apply_pass.
''' """
gm, sharding_spec_dict, origin_spec_dict, comm_actions_dict = runtime_preparation_pass(gm, gm, sharding_spec_dict, origin_spec_dict, comm_actions_dict = runtime_preparation_pass(
solution, gm, solution, device_mesh, strategies_constructor, overlap=overlap
device_mesh, )
strategies_constructor,
overlap=overlap)
gm = runtime_apply_pass(gm) gm = runtime_apply_pass(gm)
shape_prop_pass(gm, *meta_args.values(), sharding_spec_dict, origin_spec_dict, comm_actions_dict) shape_prop_pass(gm, *meta_args.values(), sharding_spec_dict, origin_spec_dict, comm_actions_dict)
gm.recompile() gm.recompile()
...@@ -152,12 +157,14 @@ def transform_to_sharded_model(gm: ColoGraphModule, ...@@ -152,12 +157,14 @@ def transform_to_sharded_model(gm: ColoGraphModule,
return gm, sharding_spec_dicts return gm, sharding_spec_dicts
def initialize_device_mesh(world_size: int = -1, def initialize_device_mesh(
physical_devices: List[int] = None, world_size: int = -1,
alpha_beta_dict: Dict[Tuple[int], Tuple[float]] = None, physical_devices: List[int] = None,
logical_mesh_shape: Tuple[int] = None, alpha_beta_dict: Dict[Tuple[int], Tuple[float]] = None,
logical_mesh_id: torch.Tensor = None): logical_mesh_shape: Tuple[int] = None,
''' logical_mesh_id: torch.Tensor = None,
):
"""
This method is used to initialize the device mesh. This method is used to initialize the device mesh.
Args: Args:
...@@ -170,7 +177,7 @@ def initialize_device_mesh(world_size: int = -1, ...@@ -170,7 +177,7 @@ def initialize_device_mesh(world_size: int = -1,
logical_mesh_shape(optional): the logical_mesh_shape is used to specify the logical logical_mesh_shape(optional): the logical_mesh_shape is used to specify the logical
mesh shape. mesh shape.
logical_mesh_id(optional): the logical_mesh_id is used to specify the logical mesh id. logical_mesh_id(optional): the logical_mesh_id is used to specify the logical mesh id.
''' """
# if world_size is not set, use the world size from torch.distributed # if world_size is not set, use the world size from torch.distributed
if world_size == -1: if world_size == -1:
world_size = dist.get_world_size() world_size = dist.get_world_size()
...@@ -201,27 +208,31 @@ def initialize_device_mesh(world_size: int = -1, ...@@ -201,27 +208,31 @@ def initialize_device_mesh(world_size: int = -1,
# extract alpha and beta values for the chosen logical mesh shape # extract alpha and beta values for the chosen logical mesh shape
mesh_alpha, mesh_beta = extract_alpha_beta_for_device_mesh(alpha_beta_dict, logical_mesh_id) mesh_alpha, mesh_beta = extract_alpha_beta_for_device_mesh(alpha_beta_dict, logical_mesh_id)
device_mesh = DeviceMesh(physical_mesh_id=physical_mesh, device_mesh = DeviceMesh(
logical_mesh_id=logical_mesh_id, physical_mesh_id=physical_mesh,
mesh_alpha=mesh_alpha, logical_mesh_id=logical_mesh_id,
mesh_beta=mesh_beta, mesh_alpha=mesh_alpha,
init_process_group=True) mesh_beta=mesh_beta,
init_process_group=True,
)
return device_mesh return device_mesh
def initialize_model(model: nn.Module, def initialize_model(
meta_args: Dict[str, torch.Tensor], model: nn.Module,
device_mesh: DeviceMesh, meta_args: Dict[str, torch.Tensor],
memory_budget: float = -1.0, device_mesh: DeviceMesh,
overlap: bool = False, memory_budget: float = -1.0,
solver_preference: str = 'standard', overlap: bool = False,
dataloader_option: str = 'replicated', solver_preference: str = "standard",
shard_option: str = 'standard', dataloader_option: str = "replicated",
save_solver_solution: bool = False, shard_option: str = "standard",
load_solver_solution: bool = False, save_solver_solution: bool = False,
solution_path: str = None, load_solver_solution: bool = False,
return_solution: bool = False): solution_path: str = None,
''' return_solution: bool = False,
):
"""
This method is used to initialize the sharded model which could be used as normal pytorch model. This method is used to initialize the sharded model which could be used as normal pytorch model.
Args: Args:
...@@ -246,7 +257,7 @@ def initialize_model(model: nn.Module, ...@@ -246,7 +257,7 @@ def initialize_model(model: nn.Module,
return_solution(optional): if the return_solution is True, the solution will be returned. The returned return_solution(optional): if the return_solution is True, the solution will be returned. The returned
solution will be used to debug or help to analyze the sharding result. Therefore, we will not just solution will be used to debug or help to analyze the sharding result. Therefore, we will not just
return a series of integers, but return the best strategies. return a series of integers, but return the best strategies.
''' """
tracer = ColoTracer(trace_act_ckpt=True, bias_addition_split=True) tracer = ColoTracer(trace_act_ckpt=True, bias_addition_split=True)
graph = tracer.trace(root=model, meta_args=meta_args) graph = tracer.trace(root=model, meta_args=meta_args)
...@@ -256,11 +267,13 @@ def initialize_model(model: nn.Module, ...@@ -256,11 +267,13 @@ def initialize_model(model: nn.Module,
shape_prop_pass(gm, *meta_args.values()) shape_prop_pass(gm, *meta_args.values())
gm.recompile() gm.recompile()
strategies_constructor = build_strategy_constructor(graph, strategies_constructor = build_strategy_constructor(
device_mesh, graph,
solver_preference=solver_preference, device_mesh,
dataloader_option=dataloader_option, solver_preference=solver_preference,
shard_option=shard_option) dataloader_option=dataloader_option,
shard_option=shard_option,
)
if load_solver_solution: if load_solver_solution:
solution = torch.load(solution_path) solution = torch.load(solution_path)
else: else:
...@@ -268,8 +281,9 @@ def initialize_model(model: nn.Module, ...@@ -268,8 +281,9 @@ def initialize_model(model: nn.Module,
if save_solver_solution: if save_solver_solution:
torch.save(solution, solution_path) torch.save(solution, solution_path)
gm, sharding_spec_dicts = transform_to_sharded_model(gm, meta_args, solution, device_mesh, strategies_constructor, gm, sharding_spec_dicts = transform_to_sharded_model(
overlap) gm, meta_args, solution, device_mesh, strategies_constructor, overlap
)
model_to_return = ModuleWrapper(gm, *sharding_spec_dicts) model_to_return = ModuleWrapper(gm, *sharding_spec_dicts)
...@@ -277,28 +291,30 @@ def initialize_model(model: nn.Module, ...@@ -277,28 +291,30 @@ def initialize_model(model: nn.Module,
solution_to_return = [] solution_to_return = []
nodes = [strategies_vector.node for strategies_vector in strategies_constructor.leaf_strategies] nodes = [strategies_vector.node for strategies_vector in strategies_constructor.leaf_strategies]
for index, node in enumerate(nodes): for index, node in enumerate(nodes):
solution_to_return.append(f'{node.name} {node.strategies_vector[solution[index]].name}') solution_to_return.append(f"{node.name} {node.strategies_vector[solution[index]].name}")
return model_to_return, solution_to_return return model_to_return, solution_to_return
else: else:
return model_to_return return model_to_return
def autoparallelize(model: nn.Module, def autoparallelize(
meta_args: Dict[str, torch.Tensor] = None, model: nn.Module,
data_loader: torch.utils.data.DataLoader = None, meta_args: Dict[str, torch.Tensor] = None,
data_process_func: callable = None, data_loader: torch.utils.data.DataLoader = None,
alpha_beta_dict: Dict[Tuple[int], Tuple[float]] = None, data_process_func: callable = None,
logical_mesh_shape: Tuple[int] = None, alpha_beta_dict: Dict[Tuple[int], Tuple[float]] = None,
logical_mesh_id: torch.Tensor = None, logical_mesh_shape: Tuple[int] = None,
solver_preference: str = 'standard', logical_mesh_id: torch.Tensor = None,
dataloader_option: str = 'replicated', solver_preference: str = "standard",
shard_option: str = 'standard', dataloader_option: str = "replicated",
save_solver_solution: bool = False, shard_option: str = "standard",
load_solver_solution: bool = False, save_solver_solution: bool = False,
solver_solution_path: str = None, load_solver_solution: bool = False,
return_solution: bool = False, solver_solution_path: str = None,
memory_budget: float = -1.0): return_solution: bool = False,
''' memory_budget: float = -1.0,
):
"""
This method is used to initialize the device mesh, extract the meta_args, and This method is used to initialize the device mesh, extract the meta_args, and
use them to create a sharded model. use them to create a sharded model.
...@@ -329,24 +345,26 @@ def autoparallelize(model: nn.Module, ...@@ -329,24 +345,26 @@ def autoparallelize(model: nn.Module,
return_solution(optional): if the return_solution is True, the solution will be returned. return_solution(optional): if the return_solution is True, the solution will be returned.
memory_budget(optional): the max cuda memory could be used. If the memory budget is -1.0, memory_budget(optional): the max cuda memory could be used. If the memory budget is -1.0,
the memory budget will be infinity. the memory budget will be infinity.
''' """
device_mesh = initialize_device_mesh(alpha_beta_dict=alpha_beta_dict, device_mesh = initialize_device_mesh(
logical_mesh_shape=logical_mesh_shape, alpha_beta_dict=alpha_beta_dict, logical_mesh_shape=logical_mesh_shape, logical_mesh_id=logical_mesh_id
logical_mesh_id=logical_mesh_id) )
if meta_args is None: if meta_args is None:
meta_args = extract_meta_args_from_dataloader(data_loader, data_process_func) meta_args = extract_meta_args_from_dataloader(data_loader, data_process_func)
rst_to_unpack = initialize_model(model, rst_to_unpack = initialize_model(
meta_args, model,
device_mesh, meta_args,
solver_preference=solver_preference, device_mesh,
dataloader_option=dataloader_option, solver_preference=solver_preference,
shard_option=shard_option, dataloader_option=dataloader_option,
save_solver_solution=save_solver_solution, shard_option=shard_option,
load_solver_solution=load_solver_solution, save_solver_solution=save_solver_solution,
solution_path=solver_solution_path, load_solver_solution=load_solver_solution,
return_solution=return_solution, solution_path=solver_solution_path,
memory_budget=memory_budget) return_solution=return_solution,
memory_budget=memory_budget,
)
if return_solution: if return_solution:
model, solution = rst_to_unpack model, solution = rst_to_unpack
......
...@@ -25,11 +25,33 @@ from .view_handler import ViewHandler ...@@ -25,11 +25,33 @@ from .view_handler import ViewHandler
from .where_handler import WhereHandler from .where_handler import WhereHandler
__all__ = [ __all__ = [
'LinearFunctionHandler', 'LinearModuleHandler', 'BMMFunctionHandler', 'AddBMMFunctionHandler', "LinearFunctionHandler",
'LayerNormModuleHandler', 'BatchNormModuleHandler', 'ConvModuleHandler', 'ConvFunctionHandler', "LinearModuleHandler",
'UnaryElementwiseHandler', 'DefaultReshapeHandler', 'PlaceholderHandler', 'OutputHandler', 'WhereHandler', "BMMFunctionHandler",
'NormPoolingHandler', 'BinaryElementwiseHandler', 'MatMulHandler', 'operator_registry', 'ADDMMFunctionHandler', "AddBMMFunctionHandler",
'GetItemHandler', 'GetattrHandler', 'ViewHandler', 'PermuteHandler', 'TensorConstructorHandler', "LayerNormModuleHandler",
'EmbeddingModuleHandler', 'EmbeddingFunctionHandler', 'SumHandler', 'SoftmaxHandler', 'TransposeHandler', "BatchNormModuleHandler",
'SplitHandler' "ConvModuleHandler",
"ConvFunctionHandler",
"UnaryElementwiseHandler",
"DefaultReshapeHandler",
"PlaceholderHandler",
"OutputHandler",
"WhereHandler",
"NormPoolingHandler",
"BinaryElementwiseHandler",
"MatMulHandler",
"operator_registry",
"ADDMMFunctionHandler",
"GetItemHandler",
"GetattrHandler",
"ViewHandler",
"PermuteHandler",
"TensorConstructorHandler",
"EmbeddingModuleHandler",
"EmbeddingFunctionHandler",
"SumHandler",
"SoftmaxHandler",
"TransposeHandler",
"SplitHandler",
] ]
...@@ -2,15 +2,13 @@ from typing import Dict, List, Union ...@@ -2,15 +2,13 @@ from typing import Dict, List, Union
import torch import torch
from colossalai.tensor.shape_consistency import CollectiveCommPattern, CommSpec, ShapeConsistencyManager from ..sharding_strategy import OperationData, OperationDataType, ShardingStrategy
from ..sharding_strategy import CommAction, CommType, OperationData, OperationDataType, ShardingStrategy
from ..utils import comm_actions_for_oprands, recover_sharding_spec_for_broadcast_shape from ..utils import comm_actions_for_oprands, recover_sharding_spec_for_broadcast_shape
from .node_handler import NodeHandler from .node_handler import NodeHandler
from .registry import operator_registry from .registry import operator_registry
from .strategy import LinearProjectionStrategyGenerator, StrategyGenerator from .strategy import LinearProjectionStrategyGenerator, StrategyGenerator
__all__ = ['ADDMMFunctionHandler'] __all__ = ["ADDMMFunctionHandler"]
@operator_registry.register(torch.addmm) @operator_registry.register(torch.addmm)
...@@ -30,25 +28,26 @@ class ADDMMFunctionHandler(NodeHandler): ...@@ -30,25 +28,26 @@ class ADDMMFunctionHandler(NodeHandler):
return data_type return data_type
def get_operation_data_mapping(self) -> Dict[str, OperationData]: def get_operation_data_mapping(self) -> Dict[str, OperationData]:
# input operand # input operand
input_data = self.node.args[1]._meta_data input_data = self.node.args[1]._meta_data
physical_input_operand = OperationData(name=str(self.node.args[1]), physical_input_operand = OperationData(
type=self._infer_op_data_type(input_data), name=str(self.node.args[1]), type=self._infer_op_data_type(input_data), data=input_data
data=input_data) )
# other operand # other operand
other_data = self.node.args[2]._meta_data other_data = self.node.args[2]._meta_data
physical_other_operand = OperationData(name=str(self.node.args[2]), physical_other_operand = OperationData(
type=self._infer_op_data_type(other_data), name=str(self.node.args[2]), type=self._infer_op_data_type(other_data), data=other_data
data=other_data) )
# bias physical shape # bias physical shape
bias_logical_shape = self.node._meta_data.shape bias_logical_shape = self.node._meta_data.shape
bias_data = self.node.args[0]._meta_data bias_data = self.node.args[0]._meta_data
physical_bias_operand = OperationData(name=str(self.node.args[0]), physical_bias_operand = OperationData(
type=self._infer_op_data_type(bias_data), name=str(self.node.args[0]),
data=bias_data, type=self._infer_op_data_type(bias_data),
logical_shape=bias_logical_shape) data=bias_data,
logical_shape=bias_logical_shape,
)
# output # output
physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data) physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data)
...@@ -57,7 +56,7 @@ class ADDMMFunctionHandler(NodeHandler): ...@@ -57,7 +56,7 @@ class ADDMMFunctionHandler(NodeHandler):
"input": physical_input_operand, "input": physical_input_operand,
"other": physical_other_operand, "other": physical_other_operand,
"output": physical_output, "output": physical_output,
'bias': physical_bias_operand "bias": physical_bias_operand,
} }
return mapping return mapping
...@@ -66,26 +65,27 @@ class ADDMMFunctionHandler(NodeHandler): ...@@ -66,26 +65,27 @@ class ADDMMFunctionHandler(NodeHandler):
op_data_mapping = self.get_operation_data_mapping() op_data_mapping = self.get_operation_data_mapping()
generators = [] generators = []
generators.append( generators.append(
LinearProjectionStrategyGenerator(op_data_mapping, self.device_mesh, linear_projection_type='addmm')) LinearProjectionStrategyGenerator(op_data_mapping, self.device_mesh, linear_projection_type="addmm")
)
return generators return generators
def post_process(self, strategy: ShardingStrategy) -> Union[ShardingStrategy, List[ShardingStrategy]]: def post_process(self, strategy: ShardingStrategy) -> Union[ShardingStrategy, List[ShardingStrategy]]:
# convert bias from its logical sharding spec to its physical sharding spec # convert bias from its logical sharding spec to its physical sharding spec
op_data_mapping = self.get_operation_data_mapping() op_data_mapping = self.get_operation_data_mapping()
bias_op_data = op_data_mapping['bias'] bias_op_data = op_data_mapping["bias"]
bias_physical_shape = bias_op_data.data.shape bias_physical_shape = bias_op_data.data.shape
bias_logical_shape = bias_op_data.logical_shape bias_logical_shape = bias_op_data.logical_shape
bias_sharding_spec = strategy.get_sharding_spec_by_name(bias_op_data.name) bias_sharding_spec = strategy.get_sharding_spec_by_name(bias_op_data.name)
bias_sharding_spec, removed_dims = recover_sharding_spec_for_broadcast_shape( bias_sharding_spec, removed_dims = recover_sharding_spec_for_broadcast_shape(
bias_sharding_spec, bias_logical_shape, bias_physical_shape) bias_sharding_spec, bias_logical_shape, bias_physical_shape
)
strategy.sharding_specs[bias_op_data] = bias_sharding_spec strategy.sharding_specs[bias_op_data] = bias_sharding_spec
if len(removed_dims) > 0: if len(removed_dims) > 0:
comm_action = comm_actions_for_oprands(node=self.node, comm_action = comm_actions_for_oprands(
removed_dims=removed_dims, node=self.node, removed_dims=removed_dims, op_data=bias_op_data, sharding_spec=bias_sharding_spec
op_data=bias_op_data, )
sharding_spec=bias_sharding_spec)
strategy.communication_actions[bias_op_data] = comm_action strategy.communication_actions[bias_op_data] = comm_action
return strategy return strategy
...@@ -2,12 +2,12 @@ from typing import Dict, List ...@@ -2,12 +2,12 @@ from typing import Dict, List
import torch import torch
from ..sharding_strategy import OperationData, OperationDataType, StrategiesVector from ..sharding_strategy import OperationData, OperationDataType
from .node_handler import MetaInfoModuleHandler, ModuleHandler from .node_handler import MetaInfoModuleHandler
from .registry import operator_registry from .registry import operator_registry
from .strategy import BatchNormStrategyGenerator, StrategyGenerator from .strategy import BatchNormStrategyGenerator, StrategyGenerator
__all__ = ['BatchNormModuleHandler'] __all__ = ["BatchNormModuleHandler"]
@operator_registry.register(torch.nn.BatchNorm1d) @operator_registry.register(torch.nn.BatchNorm1d)
...@@ -27,30 +27,37 @@ class BatchNormModuleHandler(MetaInfoModuleHandler): ...@@ -27,30 +27,37 @@ class BatchNormModuleHandler(MetaInfoModuleHandler):
def get_operation_data_mapping(self) -> Dict[str, OperationData]: def get_operation_data_mapping(self) -> Dict[str, OperationData]:
# use transposed shape for strategies # use transposed shape for strategies
# the strategies will be transformed back to its original shape in self.post_process # the strategies will be transformed back to its original shape in self.post_process
physical_input_operand = OperationData(name=str(self.node.args[0]), physical_input_operand = OperationData(
type=OperationDataType.ARG, name=str(self.node.args[0]), type=OperationDataType.ARG, data=self.node.args[0]._meta_data
data=self.node.args[0]._meta_data) )
physical_other_operand = OperationData(name="weight", physical_other_operand = OperationData(
type=OperationDataType.PARAM, name="weight",
data=self.named_parameters['weight'], type=OperationDataType.PARAM,
logical_shape=self.named_parameters['weight'].shape) data=self.named_parameters["weight"],
logical_shape=self.named_parameters["weight"].shape,
)
physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data) physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data)
physical_running_mean_operand = OperationData(name="running_mean", physical_running_mean_operand = OperationData(
type=OperationDataType.BUFFER, name="running_mean",
data=self.named_buffers['running_mean'], type=OperationDataType.BUFFER,
logical_shape=self.named_buffers['running_mean'].shape) data=self.named_buffers["running_mean"],
logical_shape=self.named_buffers["running_mean"].shape,
)
physical_running_var_operand = OperationData(name="running_var", physical_running_var_operand = OperationData(
type=OperationDataType.BUFFER, name="running_var",
data=self.named_buffers['running_var'], type=OperationDataType.BUFFER,
logical_shape=self.named_buffers['running_var'].shape) data=self.named_buffers["running_var"],
logical_shape=self.named_buffers["running_var"].shape,
)
physical_num_batches_tracked_operand = OperationData( physical_num_batches_tracked_operand = OperationData(
name="num_batches_tracked", name="num_batches_tracked",
type=OperationDataType.BUFFER, type=OperationDataType.BUFFER,
data=self.named_buffers['num_batches_tracked'], data=self.named_buffers["num_batches_tracked"],
logical_shape=self.named_buffers['num_batches_tracked'].shape) logical_shape=self.named_buffers["num_batches_tracked"].shape,
)
mapping = { mapping = {
"input": physical_input_operand, "input": physical_input_operand,
...@@ -58,12 +65,12 @@ class BatchNormModuleHandler(MetaInfoModuleHandler): ...@@ -58,12 +65,12 @@ class BatchNormModuleHandler(MetaInfoModuleHandler):
"output": physical_output, "output": physical_output,
"running_mean": physical_running_mean_operand, "running_mean": physical_running_mean_operand,
"running_var": physical_running_var_operand, "running_var": physical_running_var_operand,
"num_batches_tracked": physical_num_batches_tracked_operand "num_batches_tracked": physical_num_batches_tracked_operand,
} }
if self.named_parameters['bias'] is not None: if self.named_parameters["bias"] is not None:
physical_bias_operand = OperationData(name="bias", physical_bias_operand = OperationData(
type=OperationDataType.PARAM, name="bias", type=OperationDataType.PARAM, data=self.named_parameters["bias"]
data=self.named_parameters['bias']) )
mapping['bias'] = physical_bias_operand mapping["bias"] = physical_bias_operand
return mapping return mapping
...@@ -4,15 +4,14 @@ import torch ...@@ -4,15 +4,14 @@ import torch
from torch.fx.node import Node from torch.fx.node import Node
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, ShardingStrategy from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, ShardingStrategy
from colossalai.tensor.shape_consistency import CollectiveCommPattern, CommSpec, ShapeConsistencyManager
from ..constants import BCAST_FUNC_OP from ..constants import BCAST_FUNC_OP
from ..utils import comm_actions_for_oprands, recover_sharding_spec_for_broadcast_shape from ..utils import comm_actions_for_oprands, recover_sharding_spec_for_broadcast_shape
from .node_handler import MetaInfoNodeHandler, NodeHandler from .node_handler import MetaInfoNodeHandler
from .registry import operator_registry from .registry import operator_registry
from .strategy import BinaryElementwiseStrategyGenerator, StrategyGenerator from .strategy import BinaryElementwiseStrategyGenerator, StrategyGenerator
__all__ = ['BinaryElementwiseHandler'] __all__ = ["BinaryElementwiseHandler"]
@operator_registry.register(BCAST_FUNC_OP) @operator_registry.register(BCAST_FUNC_OP)
...@@ -38,7 +37,7 @@ class BinaryElementwiseHandler(MetaInfoNodeHandler): ...@@ -38,7 +37,7 @@ class BinaryElementwiseHandler(MetaInfoNodeHandler):
# The meta_data of node type argument could also possibly be a non-tensor object. # The meta_data of node type argument could also possibly be a non-tensor object.
if not isinstance(meta_data, torch.Tensor): if not isinstance(meta_data, torch.Tensor):
assert isinstance(meta_data, (int, float)) assert isinstance(meta_data, (int, float))
meta_data = torch.Tensor([meta_data]).to('meta') meta_data = torch.Tensor([meta_data]).to("meta")
non_tensor = True non_tensor = True
else: else:
...@@ -46,7 +45,7 @@ class BinaryElementwiseHandler(MetaInfoNodeHandler): ...@@ -46,7 +45,7 @@ class BinaryElementwiseHandler(MetaInfoNodeHandler):
# but we can deem it as meta data # but we can deem it as meta data
# as it won't affect the strategy generation # as it won't affect the strategy generation
assert isinstance(self.node.args[idx], (int, float)) assert isinstance(self.node.args[idx], (int, float))
meta_data = torch.Tensor([self.node.args[idx]]).to('meta') meta_data = torch.Tensor([self.node.args[idx]]).to("meta")
non_tensor = True non_tensor = True
return meta_data, non_tensor return meta_data, non_tensor
...@@ -58,24 +57,27 @@ class BinaryElementwiseHandler(MetaInfoNodeHandler): ...@@ -58,24 +57,27 @@ class BinaryElementwiseHandler(MetaInfoNodeHandler):
# and filter the non-tensor op_data in post_process. # and filter the non-tensor op_data in post_process.
self.non_tensor_list = [] self.non_tensor_list = []
# assert False # assert False
input_op_data = OperationData(name=str(self.node.args[0]), input_op_data = OperationData(
type=_get_op_data_type(input_meta_data), name=str(self.node.args[0]),
data=input_meta_data, type=_get_op_data_type(input_meta_data),
logical_shape=bcast_shape) data=input_meta_data,
other_op_data = OperationData(name=str(self.node.args[1]), logical_shape=bcast_shape,
type=_get_op_data_type(other_meta_data), )
data=other_meta_data, other_op_data = OperationData(
logical_shape=bcast_shape) name=str(self.node.args[1]),
output_op_data = OperationData(name=str(self.node), type=_get_op_data_type(other_meta_data),
type=OperationDataType.OUTPUT, data=other_meta_data,
data=output_meta_data, logical_shape=bcast_shape,
logical_shape=bcast_shape) )
output_op_data = OperationData(
name=str(self.node), type=OperationDataType.OUTPUT, data=output_meta_data, logical_shape=bcast_shape
)
if non_tensor_input: if non_tensor_input:
self.non_tensor_list.append(input_op_data) self.non_tensor_list.append(input_op_data)
if non_tensor_other: if non_tensor_other:
self.non_tensor_list.append(other_op_data) self.non_tensor_list.append(other_op_data)
mapping = {'input': input_op_data, 'other': other_op_data, 'output': output_op_data} mapping = {"input": input_op_data, "other": other_op_data, "output": output_op_data}
return mapping return mapping
def get_strategy_generator(self) -> List[StrategyGenerator]: def get_strategy_generator(self) -> List[StrategyGenerator]:
...@@ -100,14 +102,14 @@ class BinaryElementwiseHandler(MetaInfoNodeHandler): ...@@ -100,14 +102,14 @@ class BinaryElementwiseHandler(MetaInfoNodeHandler):
logical_shape = op_data.logical_shape logical_shape = op_data.logical_shape
sharding_spec = strategy.get_sharding_spec_by_name(op_data.name) sharding_spec = strategy.get_sharding_spec_by_name(op_data.name)
sharding_spec, removed_dims = recover_sharding_spec_for_broadcast_shape( sharding_spec, removed_dims = recover_sharding_spec_for_broadcast_shape(
sharding_spec, logical_shape, physical_shape) sharding_spec, logical_shape, physical_shape
)
strategy.sharding_specs[op_data] = sharding_spec strategy.sharding_specs[op_data] = sharding_spec
if len(removed_dims) > 0: if len(removed_dims) > 0:
comm_action = comm_actions_for_oprands(node=self.node, comm_action = comm_actions_for_oprands(
removed_dims=removed_dims, node=self.node, removed_dims=removed_dims, op_data=op_data, sharding_spec=sharding_spec
op_data=op_data, )
sharding_spec=sharding_spec)
strategy.communication_actions[op_data] = comm_action strategy.communication_actions[op_data] = comm_action
return strategy return strategy
...@@ -2,15 +2,13 @@ from typing import Dict, List, Union ...@@ -2,15 +2,13 @@ from typing import Dict, List, Union
import torch import torch
from colossalai.tensor.shape_consistency import CollectiveCommPattern, CommSpec, ShapeConsistencyManager from ..sharding_strategy import OperationData, OperationDataType, ShardingStrategy
from ..sharding_strategy import CommAction, CommType, OperationData, OperationDataType, ShardingStrategy
from ..utils import comm_actions_for_oprands, recover_sharding_spec_for_broadcast_shape from ..utils import comm_actions_for_oprands, recover_sharding_spec_for_broadcast_shape
from .node_handler import NodeHandler from .node_handler import NodeHandler
from .registry import operator_registry from .registry import operator_registry
from .strategy import BatchedMatMulStrategyGenerator, StrategyGenerator from .strategy import BatchedMatMulStrategyGenerator, StrategyGenerator
__all__ = ['BMMFunctionHandler', 'AddBMMFunctionHandler'] __all__ = ["BMMFunctionHandler", "AddBMMFunctionHandler"]
def _get_data_mapping_for_bmm_op(node, input_idx, other_idx, bias_idx=None): def _get_data_mapping_for_bmm_op(node, input_idx, other_idx, bias_idx=None):
...@@ -19,14 +17,14 @@ def _get_data_mapping_for_bmm_op(node, input_idx, other_idx, bias_idx=None): ...@@ -19,14 +17,14 @@ def _get_data_mapping_for_bmm_op(node, input_idx, other_idx, bias_idx=None):
node handler to reduce code redundancy. node handler to reduce code redundancy.
""" """
# input operand # input operand
physical_input_operand = OperationData(name=str(node.args[input_idx]), physical_input_operand = OperationData(
type=OperationDataType.ARG, name=str(node.args[input_idx]), type=OperationDataType.ARG, data=node.args[input_idx]._meta_data
data=node.args[input_idx]._meta_data) )
# other operand # other operand
physical_other_operand = OperationData(name=str(node.args[other_idx]), physical_other_operand = OperationData(
type=OperationDataType.ARG, name=str(node.args[other_idx]), type=OperationDataType.ARG, data=node.args[other_idx]._meta_data
data=node.args[other_idx]._meta_data) )
# output # output
physical_output = OperationData(name=str(node), type=OperationDataType.OUTPUT, data=node._meta_data) physical_output = OperationData(name=str(node), type=OperationDataType.OUTPUT, data=node._meta_data)
...@@ -35,11 +33,13 @@ def _get_data_mapping_for_bmm_op(node, input_idx, other_idx, bias_idx=None): ...@@ -35,11 +33,13 @@ def _get_data_mapping_for_bmm_op(node, input_idx, other_idx, bias_idx=None):
if bias_idx is not None: if bias_idx is not None:
# bias physical shape # bias physical shape
bias_logical_shape = node._meta_data.shape bias_logical_shape = node._meta_data.shape
physical_bias_operand = OperationData(name=str(node.args[bias_idx]), physical_bias_operand = OperationData(
type=OperationDataType.ARG, name=str(node.args[bias_idx]),
data=node.args[bias_idx]._meta_data, type=OperationDataType.ARG,
logical_shape=bias_logical_shape) data=node.args[bias_idx]._meta_data,
mapping['bias'] = physical_bias_operand logical_shape=bias_logical_shape,
)
mapping["bias"] = physical_bias_operand
return mapping return mapping
...@@ -91,20 +91,20 @@ class AddBMMFunctionHandler(NodeHandler): ...@@ -91,20 +91,20 @@ class AddBMMFunctionHandler(NodeHandler):
# convert bias from its logical sharding spec to its physical sharding spec # convert bias from its logical sharding spec to its physical sharding spec
op_data_mapping = self.get_operation_data_mapping() op_data_mapping = self.get_operation_data_mapping()
if 'bias' in op_data_mapping: if "bias" in op_data_mapping:
bias_op_data = op_data_mapping['bias'] bias_op_data = op_data_mapping["bias"]
bias_physical_shape = bias_op_data.data.shape bias_physical_shape = bias_op_data.data.shape
bias_logical_shape = bias_op_data.logical_shape bias_logical_shape = bias_op_data.logical_shape
bias_sharding_spec = strategy.get_sharding_spec_by_name(bias_op_data.name) bias_sharding_spec = strategy.get_sharding_spec_by_name(bias_op_data.name)
bias_sharding_spec, removed_dims = recover_sharding_spec_for_broadcast_shape( bias_sharding_spec, removed_dims = recover_sharding_spec_for_broadcast_shape(
bias_sharding_spec, bias_logical_shape, bias_physical_shape) bias_sharding_spec, bias_logical_shape, bias_physical_shape
)
strategy.sharding_specs[bias_op_data] = bias_sharding_spec strategy.sharding_specs[bias_op_data] = bias_sharding_spec
if len(removed_dims) > 0: if len(removed_dims) > 0:
comm_action = comm_actions_for_oprands(node=self.node, comm_action = comm_actions_for_oprands(
removed_dims=removed_dims, node=self.node, removed_dims=removed_dims, op_data=bias_op_data, sharding_spec=bias_sharding_spec
op_data=bias_op_data, )
sharding_spec=bias_sharding_spec)
strategy.communication_actions[bias_op_data] = comm_action strategy.communication_actions[bias_op_data] = comm_action
return strategy return strategy
...@@ -3,13 +3,13 @@ from typing import Dict, List ...@@ -3,13 +3,13 @@ from typing import Dict, List
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from ..sharding_strategy import OperationData, OperationDataType, ShardingStrategy, StrategiesVector from ..sharding_strategy import OperationData, OperationDataType, ShardingStrategy
from ..utils import transpose_partition_dim from ..utils import transpose_partition_dim
from .node_handler import MetaInfoModuleHandler, MetaInfoNodeHandler, ModuleHandler, NodeHandler from .node_handler import MetaInfoModuleHandler, MetaInfoNodeHandler
from .registry import operator_registry from .registry import operator_registry
from .strategy import ConvStrategyGenerator, StrategyGenerator from .strategy import ConvStrategyGenerator, StrategyGenerator
__all__ = ['ConvModuleHandler', 'ConvFunctionHandler'] __all__ = ["ConvModuleHandler", "ConvFunctionHandler"]
@operator_registry.register(torch.nn.Conv1d) @operator_registry.register(torch.nn.Conv1d)
...@@ -29,25 +29,29 @@ class ConvModuleHandler(MetaInfoModuleHandler): ...@@ -29,25 +29,29 @@ class ConvModuleHandler(MetaInfoModuleHandler):
def get_operation_data_mapping(self) -> Dict[str, OperationData]: def get_operation_data_mapping(self) -> Dict[str, OperationData]:
# use transposed shape for strategies # use transposed shape for strategies
# the strategies will be transformed back to its original shape in self.post_process # the strategies will be transformed back to its original shape in self.post_process
physical_input_operand = OperationData(name=str(self.node.args[0]), physical_input_operand = OperationData(
type=OperationDataType.ARG, name=str(self.node.args[0]), type=OperationDataType.ARG, data=self.node.args[0]._meta_data
data=self.node.args[0]._meta_data) )
logical_shape_for_weight = list(self.named_parameters["weight"].shape) logical_shape_for_weight = list(self.named_parameters["weight"].shape)
logical_shape_for_weight[0], logical_shape_for_weight[1] = logical_shape_for_weight[ logical_shape_for_weight[0], logical_shape_for_weight[1] = (
1], logical_shape_for_weight[0] logical_shape_for_weight[1],
physical_other_operand = OperationData(name="weight", logical_shape_for_weight[0],
type=OperationDataType.PARAM, )
data=self.named_parameters['weight'], physical_other_operand = OperationData(
logical_shape=torch.Size(logical_shape_for_weight)) name="weight",
type=OperationDataType.PARAM,
data=self.named_parameters["weight"],
logical_shape=torch.Size(logical_shape_for_weight),
)
physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data) physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data)
mapping = {"input": physical_input_operand, "other": physical_other_operand, "output": physical_output} mapping = {"input": physical_input_operand, "other": physical_other_operand, "output": physical_output}
if "bias" in self.named_parameters: if "bias" in self.named_parameters:
physical_bias_operand = OperationData(name="bias", physical_bias_operand = OperationData(
type=OperationDataType.PARAM, name="bias", type=OperationDataType.PARAM, data=self.named_parameters["bias"]
data=self.named_parameters['bias']) )
mapping['bias'] = physical_bias_operand mapping["bias"] = physical_bias_operand
return mapping return mapping
def post_process(self, strategy: ShardingStrategy): def post_process(self, strategy: ShardingStrategy):
...@@ -77,9 +81,9 @@ class ConvFunctionHandler(MetaInfoNodeHandler): ...@@ -77,9 +81,9 @@ class ConvFunctionHandler(MetaInfoNodeHandler):
def get_operation_data_mapping(self) -> Dict[str, OperationData]: def get_operation_data_mapping(self) -> Dict[str, OperationData]:
# use transposed shape for strategies # use transposed shape for strategies
# the strategies will be transformed back to its original shape in self.post_process # the strategies will be transformed back to its original shape in self.post_process
physical_input_operand = OperationData(name=str(self.node.args[0]), physical_input_operand = OperationData(
type=OperationDataType.ARG, name=str(self.node.args[0]), type=OperationDataType.ARG, data=self.node.args[0]._meta_data
data=self.node.args[0]._meta_data) )
# check if the other operand is a parameter # check if the other operand is a parameter
if isinstance(self.node.args[1]._meta_data, torch.nn.parameter.Parameter): if isinstance(self.node.args[1]._meta_data, torch.nn.parameter.Parameter):
...@@ -88,26 +92,30 @@ class ConvFunctionHandler(MetaInfoNodeHandler): ...@@ -88,26 +92,30 @@ class ConvFunctionHandler(MetaInfoNodeHandler):
data_type = OperationDataType.ARG data_type = OperationDataType.ARG
logical_shape_for_weight = list(self.node.args[1]._meta_data.shape) logical_shape_for_weight = list(self.node.args[1]._meta_data.shape)
logical_shape_for_weight[0], logical_shape_for_weight[1] = logical_shape_for_weight[ logical_shape_for_weight[0], logical_shape_for_weight[1] = (
1], logical_shape_for_weight[0] logical_shape_for_weight[1],
physical_other_operand = OperationData(name=str(self.node.args[1]), logical_shape_for_weight[0],
type=data_type, )
data=self.node.args[1]._meta_data, physical_other_operand = OperationData(
logical_shape=torch.Size(logical_shape_for_weight)) name=str(self.node.args[1]),
type=data_type,
data=self.node.args[1]._meta_data,
logical_shape=torch.Size(logical_shape_for_weight),
)
physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data) physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data)
mapping = {"input": physical_input_operand, "other": physical_other_operand, "output": physical_output} mapping = {"input": physical_input_operand, "other": physical_other_operand, "output": physical_output}
if "bias" in self.node.kwargs and self.node.kwargs['bias'] is not None: if "bias" in self.node.kwargs and self.node.kwargs["bias"] is not None:
# check if the other operand is a parameter # check if the other operand is a parameter
if isinstance(self.node.kwargs["bias"]._meta_data, torch.nn.parameter.Parameter): if isinstance(self.node.kwargs["bias"]._meta_data, torch.nn.parameter.Parameter):
data_type = OperationDataType.PARAM data_type = OperationDataType.PARAM
else: else:
data_type = OperationDataType.ARG data_type = OperationDataType.ARG
physical_bias_operand = OperationData(name=str(self.node.kwargs["bias"]), physical_bias_operand = OperationData(
type=data_type, name=str(self.node.kwargs["bias"]), type=data_type, data=self.node.kwargs["bias"]._meta_data
data=self.node.kwargs["bias"]._meta_data) )
mapping['bias'] = physical_bias_operand mapping["bias"] = physical_bias_operand
return mapping return mapping
def post_process(self, strategy: ShardingStrategy): def post_process(self, strategy: ShardingStrategy):
......
...@@ -3,11 +3,11 @@ from typing import Dict, List ...@@ -3,11 +3,11 @@ from typing import Dict, List
import torch import torch
from ..sharding_strategy import OperationData, OperationDataType from ..sharding_strategy import OperationData, OperationDataType
from .node_handler import MetaInfoNodeHandler, NodeHandler from .node_handler import MetaInfoNodeHandler
from .registry import operator_registry from .registry import operator_registry
from .strategy import DefaultReshapeGenerator, StrategyGenerator from .strategy import DefaultReshapeGenerator, StrategyGenerator
__all__ = ['DefaultReshapeHandler'] __all__ = ["DefaultReshapeHandler"]
@operator_registry.register(torch.flatten) @operator_registry.register(torch.flatten)
...@@ -54,17 +54,15 @@ class DefaultReshapeHandler(MetaInfoNodeHandler): ...@@ -54,17 +54,15 @@ class DefaultReshapeHandler(MetaInfoNodeHandler):
input_data = self.node.args[0]._meta_data input_data = self.node.args[0]._meta_data
input_logical_shape = self.infer_logical_shape(input_data) input_logical_shape = self.infer_logical_shape(input_data)
physical_input_operand = OperationData(name=str(self.node.args[0]), physical_input_operand = OperationData(
type=data_type, name=str(self.node.args[0]), type=data_type, data=input_data, logical_shape=input_logical_shape
data=input_data, )
logical_shape=input_logical_shape)
output_data = self.node._meta_data output_data = self.node._meta_data
output_logical_shape = self.infer_logical_shape(output_data) output_logical_shape = self.infer_logical_shape(output_data)
physical_output = OperationData(name=str(self.node), physical_output = OperationData(
type=OperationDataType.OUTPUT, name=str(self.node), type=OperationDataType.OUTPUT, data=output_data, logical_shape=output_logical_shape
data=output_data, )
logical_shape=output_logical_shape)
mapping = {"input": physical_input_operand, "output": physical_output} mapping = {"input": physical_input_operand, "output": physical_output}
......
...@@ -12,11 +12,12 @@ from .node_handler import ModuleHandler, NodeHandler ...@@ -12,11 +12,12 @@ from .node_handler import ModuleHandler, NodeHandler
from .registry import operator_registry from .registry import operator_registry
from .strategy import EmbeddingStrategyGenerator, StrategyGenerator from .strategy import EmbeddingStrategyGenerator, StrategyGenerator
__all__ = ['EmbeddingModuleHandler', 'EmbeddingFunctionHandler'] __all__ = ["EmbeddingModuleHandler", "EmbeddingFunctionHandler"]
def _convert_logical_sharding_to_physical_sharding_spec_for_embedding(strategy: ShardingStrategy, input_name: str, def _convert_logical_sharding_to_physical_sharding_spec_for_embedding(
output_name: str) -> List[ShardingStrategy]: strategy: ShardingStrategy, input_name: str, output_name: str
) -> List[ShardingStrategy]:
""" """
This function converts the logical sharding spec to the physical sharding spec for both the input and output This function converts the logical sharding spec to the physical sharding spec for both the input and output
of the embedding operation. of the embedding operation.
...@@ -56,27 +57,31 @@ def _convert_logical_sharding_to_physical_sharding_spec_for_embedding(strategy: ...@@ -56,27 +57,31 @@ def _convert_logical_sharding_to_physical_sharding_spec_for_embedding(strategy:
output_sharding_spec = strategy_copy.get_sharding_spec_by_name(output_op_data.name) output_sharding_spec = strategy_copy.get_sharding_spec_by_name(output_op_data.name)
try: try:
# replace the 0th dimension in the logical sharding with ith dimension in the physical sharding # replace the 0th dimension in the logical sharding with ith dimension in the physical sharding
update_partition_dim(sharding_spec=input_sharding_spec, update_partition_dim(
dim_mapping={0: i}, sharding_spec=input_sharding_spec,
physical_shape=input_op_data.data.shape, dim_mapping={0: i},
inplace=True) physical_shape=input_op_data.data.shape,
inplace=True,
)
if last_logical_output_dims in output_sharding_spec.dim_partition_dict: if last_logical_output_dims in output_sharding_spec.dim_partition_dict:
dim_mapping = {0: i, last_logical_output_dims: last_physical_output_dims} dim_mapping = {0: i, last_logical_output_dims: last_physical_output_dims}
else: else:
dim_mapping = {0: i} dim_mapping = {0: i}
update_partition_dim(sharding_spec=output_sharding_spec, update_partition_dim(
dim_mapping=dim_mapping, sharding_spec=output_sharding_spec,
physical_shape=output_op_data.data.shape, dim_mapping=dim_mapping,
inplace=True) physical_shape=output_op_data.data.shape,
inplace=True,
)
strategy_copy.name = f'{strategy.name}_{i}' strategy_copy.name = f"{strategy.name}_{i}"
sharding_strategies.append(strategy_copy) sharding_strategies.append(strategy_copy)
except ShardingNotDivisibleError as e: except ShardingNotDivisibleError as e:
logger.debug( logger.debug(
f'Errored occurred when converting the logical sharding spec to the physical one. Error details: {e}' f"Errored occurred when converting the logical sharding spec to the physical one. Error details: {e}"
) )
else: else:
# the generated sharding strategy does not shard the non-matrix dimension, # the generated sharding strategy does not shard the non-matrix dimension,
...@@ -87,20 +92,21 @@ def _convert_logical_sharding_to_physical_sharding_spec_for_embedding(strategy: ...@@ -87,20 +92,21 @@ def _convert_logical_sharding_to_physical_sharding_spec_for_embedding(strategy:
output_sharding_spec = strategy_copy.get_sharding_spec_by_name(output_op_data.name) output_sharding_spec = strategy_copy.get_sharding_spec_by_name(output_op_data.name)
# after updating, the logical shape will be replaced by the physical shape # after updating, the logical shape will be replaced by the physical shape
update_partition_dim(sharding_spec=input_sharding_spec, update_partition_dim(
dim_mapping={}, sharding_spec=input_sharding_spec, dim_mapping={}, physical_shape=input_op_data.data.shape, inplace=True
physical_shape=input_op_data.data.shape, )
inplace=True)
if last_logical_output_dims in output_sharding_spec.dim_partition_dict: if last_logical_output_dims in output_sharding_spec.dim_partition_dict:
dim_mapping = {last_logical_output_dims: last_physical_output_dims} dim_mapping = {last_logical_output_dims: last_physical_output_dims}
else: else:
dim_mapping = {} dim_mapping = {}
update_partition_dim(sharding_spec=output_sharding_spec, update_partition_dim(
dim_mapping=dim_mapping, sharding_spec=output_sharding_spec,
physical_shape=output_op_data.data.shape, dim_mapping=dim_mapping,
inplace=True) physical_shape=output_op_data.data.shape,
inplace=True,
)
sharding_strategies.append(strategy_copy) sharding_strategies.append(strategy_copy)
return sharding_strategies return sharding_strategies
...@@ -125,14 +131,16 @@ class EmbeddingModuleHandler(ModuleHandler): ...@@ -125,14 +131,16 @@ class EmbeddingModuleHandler(ModuleHandler):
# Finally, the input will be transformed back to its original shape in self.post_process # Finally, the input will be transformed back to its original shape in self.post_process
input_meta_data = self.node.args[0]._meta_data input_meta_data = self.node.args[0]._meta_data
input_logical_shape = input_meta_data.view(-1).shape input_logical_shape = input_meta_data.view(-1).shape
physical_input_operand = OperationData(name=str(self.node.args[0]), physical_input_operand = OperationData(
type=OperationDataType.ARG, name=str(self.node.args[0]),
data=input_meta_data, type=OperationDataType.ARG,
logical_shape=input_logical_shape) data=input_meta_data,
logical_shape=input_logical_shape,
)
physical_other_operand = OperationData(name="weight", physical_other_operand = OperationData(
type=OperationDataType.PARAM, name="weight", type=OperationDataType.PARAM, data=self.named_parameters["weight"]
data=self.named_parameters['weight']) )
# Same as input, in nn.Embedding operation, all the dimensions of output will be treated as # Same as input, in nn.Embedding operation, all the dimensions of output will be treated as
# (batch dimension, embedding dimension), and then the sharding spec will be generated based # (batch dimension, embedding dimension), and then the sharding spec will be generated based
...@@ -141,10 +149,12 @@ class EmbeddingModuleHandler(ModuleHandler): ...@@ -141,10 +149,12 @@ class EmbeddingModuleHandler(ModuleHandler):
# Finally, the output will be transformed back to its original shape in self.post_process # Finally, the output will be transformed back to its original shape in self.post_process
output_meta_data = self.node._meta_data output_meta_data = self.node._meta_data
output_logical_shape = output_meta_data.view(-1, output_meta_data.shape[-1]).shape output_logical_shape = output_meta_data.view(-1, output_meta_data.shape[-1]).shape
physical_output = OperationData(name=str(self.node), physical_output = OperationData(
type=OperationDataType.OUTPUT, name=str(self.node),
data=output_meta_data, type=OperationDataType.OUTPUT,
logical_shape=output_logical_shape) data=output_meta_data,
logical_shape=output_logical_shape,
)
mapping = {"input": physical_input_operand, "other": physical_other_operand, "output": physical_output} mapping = {"input": physical_input_operand, "other": physical_other_operand, "output": physical_output}
...@@ -157,10 +167,9 @@ class EmbeddingModuleHandler(ModuleHandler): ...@@ -157,10 +167,9 @@ class EmbeddingModuleHandler(ModuleHandler):
# create multiple sharding strategies for the inputs # create multiple sharding strategies for the inputs
# as input can be multi-dimensional and the partition dim is only 2D, # as input can be multi-dimensional and the partition dim is only 2D,
# we need to map the partition at logical dim 0 to one of the first few dimensions of the input and output # we need to map the partition at logical dim 0 to one of the first few dimensions of the input and output
strategies = _convert_logical_sharding_to_physical_sharding_spec_for_embedding(strategy=strategy, strategies = _convert_logical_sharding_to_physical_sharding_spec_for_embedding(
input_name=str( strategy=strategy, input_name=str(self.node.args[0]), output_name=str(self.node)
self.node.args[0]), )
output_name=str(self.node))
return strategies return strategies
...@@ -183,10 +192,12 @@ class EmbeddingFunctionHandler(NodeHandler): ...@@ -183,10 +192,12 @@ class EmbeddingFunctionHandler(NodeHandler):
# Finally, the input will be transformed back to its original shape in self.post_process # Finally, the input will be transformed back to its original shape in self.post_process
input_meta_data = self.node.args[0]._meta_data input_meta_data = self.node.args[0]._meta_data
input_logical_shape = input_meta_data.view(-1).shape input_logical_shape = input_meta_data.view(-1).shape
physical_input_operand = OperationData(name=str(self.node.args[0]), physical_input_operand = OperationData(
type=OperationDataType.ARG, name=str(self.node.args[0]),
data=self.node.args[0]._meta_data, type=OperationDataType.ARG,
logical_shape=input_logical_shape) data=self.node.args[0]._meta_data,
logical_shape=input_logical_shape,
)
# check if the other operand is a parameter # check if the other operand is a parameter
if isinstance(self.node.args[1]._meta_data, torch.nn.parameter.Parameter): if isinstance(self.node.args[1]._meta_data, torch.nn.parameter.Parameter):
...@@ -194,9 +205,9 @@ class EmbeddingFunctionHandler(NodeHandler): ...@@ -194,9 +205,9 @@ class EmbeddingFunctionHandler(NodeHandler):
else: else:
data_type = OperationDataType.ARG data_type = OperationDataType.ARG
physical_other_operand = OperationData(name=str(self.node.args[1]), physical_other_operand = OperationData(
type=data_type, name=str(self.node.args[1]), type=data_type, data=self.node.args[1]._meta_data
data=self.node.args[1]._meta_data) )
# Same as input, in F.embedding operation, all the dimensions of output will be treated as # Same as input, in F.embedding operation, all the dimensions of output will be treated as
# (batch dimension, embedding dimension), and then the sharding spec will be generated based # (batch dimension, embedding dimension), and then the sharding spec will be generated based
...@@ -223,8 +234,7 @@ class EmbeddingFunctionHandler(NodeHandler): ...@@ -223,8 +234,7 @@ class EmbeddingFunctionHandler(NodeHandler):
# create multiple sharding strategies for the inputs # create multiple sharding strategies for the inputs
# as input can be multi-dimensional and the partition dim is only 2D, # as input can be multi-dimensional and the partition dim is only 2D,
# we need to map the partition at logical dim 0 to one of the first few dimensions of the input and output # we need to map the partition at logical dim 0 to one of the first few dimensions of the input and output
strategies = _convert_logical_sharding_to_physical_sharding_spec_for_embedding(strategy=strategy, strategies = _convert_logical_sharding_to_physical_sharding_spec_for_embedding(
input_name=str( strategy=strategy, input_name=str(self.node.args[0]), output_name=str(self.node)
self.node.args[0]), )
output_name=str(self.node))
return strategies return strategies
...@@ -4,7 +4,7 @@ from ..sharding_strategy import OperationData, OperationDataType ...@@ -4,7 +4,7 @@ from ..sharding_strategy import OperationData, OperationDataType
from .node_handler import NodeHandler from .node_handler import NodeHandler
from .strategy import GetattrGenerator, StrategyGenerator from .strategy import GetattrGenerator, StrategyGenerator
__all__ = ['GetattrHandler'] __all__ = ["GetattrHandler"]
class GetattrHandler(NodeHandler): class GetattrHandler(NodeHandler):
......
...@@ -8,7 +8,7 @@ from .node_handler import NodeHandler ...@@ -8,7 +8,7 @@ from .node_handler import NodeHandler
from .registry import operator_registry from .registry import operator_registry
from .strategy import StrategyGenerator, TensorStrategyGenerator, TensorTupleStrategyGenerator from .strategy import StrategyGenerator, TensorStrategyGenerator, TensorTupleStrategyGenerator
__all__ = ['GetItemHandler'] __all__ = ["GetItemHandler"]
@operator_registry.register(operator.getitem) @operator_registry.register(operator.getitem)
...@@ -30,9 +30,9 @@ class GetItemHandler(NodeHandler): ...@@ -30,9 +30,9 @@ class GetItemHandler(NodeHandler):
def get_operation_data_mapping(self) -> Dict[str, OperationData]: def get_operation_data_mapping(self) -> Dict[str, OperationData]:
# use transposed shape for strategies # use transposed shape for strategies
# the strategies will be transformed back to its original shape in self.post_process # the strategies will be transformed back to its original shape in self.post_process
physical_input_operand = OperationData(name=str(self.node.args[0]), physical_input_operand = OperationData(
type=OperationDataType.ARG, name=str(self.node.args[0]), type=OperationDataType.ARG, data=self.node.args[0]._meta_data
data=self.node.args[0]._meta_data) )
physical_other_operand = OperationData(name="index", type=OperationDataType.ARG, data=self.node.args[1]) physical_other_operand = OperationData(name="index", type=OperationDataType.ARG, data=self.node.args[1])
physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data) physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data)
......
...@@ -3,11 +3,11 @@ from typing import Dict, List ...@@ -3,11 +3,11 @@ from typing import Dict, List
import torch import torch
from ..sharding_strategy import OperationData, OperationDataType from ..sharding_strategy import OperationData, OperationDataType
from .node_handler import MetaInfoModuleHandler, ModuleHandler from .node_handler import MetaInfoModuleHandler
from .registry import operator_registry from .registry import operator_registry
from .strategy import LayerNormGenerator, StrategyGenerator from .strategy import LayerNormGenerator, StrategyGenerator
__all__ = ['LayerNormModuleHandler'] __all__ = ["LayerNormModuleHandler"]
@operator_registry.register(torch.nn.LayerNorm) @operator_registry.register(torch.nn.LayerNorm)
...@@ -25,20 +25,22 @@ class LayerNormModuleHandler(MetaInfoModuleHandler): ...@@ -25,20 +25,22 @@ class LayerNormModuleHandler(MetaInfoModuleHandler):
def get_operation_data_mapping(self) -> Dict[str, OperationData]: def get_operation_data_mapping(self) -> Dict[str, OperationData]:
# use transposed shape for strategies # use transposed shape for strategies
# the strategies will be transformed back to its original shape in self.post_process # the strategies will be transformed back to its original shape in self.post_process
physical_input_operand = OperationData(name=str(self.node.args[0]), physical_input_operand = OperationData(
type=OperationDataType.ARG, name=str(self.node.args[0]), type=OperationDataType.ARG, data=self.node.args[0]._meta_data
data=self.node.args[0]._meta_data) )
physical_other_operand = OperationData(name="weight", physical_other_operand = OperationData(
type=OperationDataType.PARAM, name="weight",
data=self.named_parameters['weight'], type=OperationDataType.PARAM,
logical_shape=self.named_parameters['weight'].shape) data=self.named_parameters["weight"],
logical_shape=self.named_parameters["weight"].shape,
)
physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data) physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data)
mapping = {"input": physical_input_operand, "other": physical_other_operand, "output": physical_output} mapping = {"input": physical_input_operand, "other": physical_other_operand, "output": physical_output}
if self.named_parameters['bias'] is not None: if self.named_parameters["bias"] is not None:
physical_bias_operand = OperationData(name="bias", physical_bias_operand = OperationData(
type=OperationDataType.PARAM, name="bias", type=OperationDataType.PARAM, data=self.named_parameters["bias"]
data=self.named_parameters['bias']) )
mapping['bias'] = physical_bias_operand mapping["bias"] = physical_bias_operand
return mapping return mapping
...@@ -3,24 +3,21 @@ from typing import Dict, List, Union ...@@ -3,24 +3,21 @@ from typing import Dict, List, Union
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from colossalai.auto_parallel.tensor_shard.utils import ( from colossalai.auto_parallel.tensor_shard.utils import transpose_partition_dim, update_partition_dim
check_sharding_spec_validity,
transpose_partition_dim,
update_partition_dim,
)
from colossalai.logging import get_dist_logger from colossalai.logging import get_dist_logger
from colossalai.tensor.sharding_spec import ShardingNotDivisibleError from colossalai.tensor.sharding_spec import ShardingNotDivisibleError
from ..sharding_strategy import OperationData, OperationDataType, ShardingStrategy, StrategiesVector from ..sharding_strategy import OperationData, OperationDataType, ShardingStrategy
from .node_handler import MetaInfoModuleHandler, MetaInfoNodeHandler, ModuleHandler, NodeHandler from .node_handler import MetaInfoModuleHandler, MetaInfoNodeHandler
from .registry import operator_registry from .registry import operator_registry
from .strategy import LinearProjectionStrategyGenerator, StrategyGenerator from .strategy import LinearProjectionStrategyGenerator, StrategyGenerator
__all__ = ['LinearModuleHandler', 'LinearFunctionHandler'] __all__ = ["LinearModuleHandler", "LinearFunctionHandler"]
def _update_sharding_spec_for_transposed_weight_for_linear(strategy: ShardingStrategy, def _update_sharding_spec_for_transposed_weight_for_linear(
weight_name: str) -> ShardingStrategy: strategy: ShardingStrategy, weight_name: str
) -> ShardingStrategy:
""" """
This function is a helper function used by both module node handler and function node handler. This function will This function is a helper function used by both module node handler and function node handler. This function will
convert the sharding spec for the transposed weight to the correct partition spec. convert the sharding spec for the transposed weight to the correct partition spec.
...@@ -32,16 +29,17 @@ def _update_sharding_spec_for_transposed_weight_for_linear(strategy: ShardingStr ...@@ -32,16 +29,17 @@ def _update_sharding_spec_for_transposed_weight_for_linear(strategy: ShardingStr
# switch the dimensions of the transposed weight # switch the dimensions of the transposed weight
sharding_spec = strategy.get_sharding_spec_by_name(weight_name) sharding_spec = strategy.get_sharding_spec_by_name(weight_name)
op_data = strategy.get_op_data_by_name(weight_name) op_data = strategy.get_op_data_by_name(weight_name)
assert op_data.logical_shape[0] == op_data.data.shape[1] and \ assert (
op_data.logical_shape[1] == op_data.data.shape[0], \ op_data.logical_shape[0] == op_data.data.shape[1] and op_data.logical_shape[1] == op_data.data.shape[0]
"Expected the logical shape of the linear operator's weight is equal to transposed physical shape" ), "Expected the logical shape of the linear operator's weight is equal to transposed physical shape"
dim_size = len(op_data.logical_shape) dim_size = len(op_data.logical_shape)
transpose_partition_dim(sharding_spec, 0, dim_size - 1) transpose_partition_dim(sharding_spec, 0, dim_size - 1)
return strategy return strategy
def _convert_logical_sharding_to_physical_sharding_spec_for_linear(strategy: ShardingStrategy, input_name: str, def _convert_logical_sharding_to_physical_sharding_spec_for_linear(
output_name: str) -> List[ShardingStrategy]: strategy: ShardingStrategy, input_name: str, output_name: str
) -> List[ShardingStrategy]:
""" """
This function converts the logical sharding spec to the physical sharding spec for both the input and output of the linear operation. The input and output This function converts the logical sharding spec to the physical sharding spec for both the input and output of the linear operation. The input and output
should have the same sharding spec. should have the same sharding spec.
...@@ -99,22 +97,26 @@ def _convert_logical_sharding_to_physical_sharding_spec_for_linear(strategy: Sha ...@@ -99,22 +97,26 @@ def _convert_logical_sharding_to_physical_sharding_spec_for_linear(strategy: Sha
input_dim_mapping = {0: i} input_dim_mapping = {0: i}
input_dim_mapping.update(input_last_dim_mapping) input_dim_mapping.update(input_last_dim_mapping)
update_partition_dim(sharding_spec=input_sharding_spec, update_partition_dim(
dim_mapping=input_dim_mapping, sharding_spec=input_sharding_spec,
physical_shape=input_op_data.data.shape, dim_mapping=input_dim_mapping,
inplace=True) physical_shape=input_op_data.data.shape,
inplace=True,
)
output_dim_mapping = {0: i} output_dim_mapping = {0: i}
output_dim_mapping.update(output_last_dim_mapping) output_dim_mapping.update(output_last_dim_mapping)
update_partition_dim(sharding_spec=output_sharding_spec, update_partition_dim(
dim_mapping=output_dim_mapping, sharding_spec=output_sharding_spec,
physical_shape=output_op_data.data.shape, dim_mapping=output_dim_mapping,
inplace=True) physical_shape=output_op_data.data.shape,
strategy_copy.name = f'{strategy.name}_{i}' inplace=True,
)
strategy_copy.name = f"{strategy.name}_{i}"
sharding_strategies.append(strategy_copy) sharding_strategies.append(strategy_copy)
except ShardingNotDivisibleError as e: except ShardingNotDivisibleError as e:
logger.debug( logger.debug(
f'Errored occurred when converting the logical sharding spec to the physical one. Error details: {e}' f"Errored occurred when converting the logical sharding spec to the physical one. Error details: {e}"
) )
else: else:
# the generated sharding strategy does not shard the non-matrix dimension, # the generated sharding strategy does not shard the non-matrix dimension,
...@@ -127,17 +129,21 @@ def _convert_logical_sharding_to_physical_sharding_spec_for_linear(strategy: Sha ...@@ -127,17 +129,21 @@ def _convert_logical_sharding_to_physical_sharding_spec_for_linear(strategy: Sha
# after updating, the logical shape will be replaced by the physical shape # after updating, the logical shape will be replaced by the physical shape
input_dim_mapping = {} input_dim_mapping = {}
input_dim_mapping.update(input_last_dim_mapping) input_dim_mapping.update(input_last_dim_mapping)
update_partition_dim(sharding_spec=input_sharding_spec, update_partition_dim(
dim_mapping=input_dim_mapping, sharding_spec=input_sharding_spec,
physical_shape=input_op_data.data.shape, dim_mapping=input_dim_mapping,
inplace=True) physical_shape=input_op_data.data.shape,
inplace=True,
)
output_dim_mapping = {} output_dim_mapping = {}
output_dim_mapping.update(output_last_dim_mapping) output_dim_mapping.update(output_last_dim_mapping)
update_partition_dim(sharding_spec=output_sharding_spec, update_partition_dim(
dim_mapping=output_dim_mapping, sharding_spec=output_sharding_spec,
physical_shape=output_op_data.data.shape, dim_mapping=output_dim_mapping,
inplace=True) physical_shape=output_op_data.data.shape,
inplace=True,
)
sharding_strategies.append(strategy_copy) sharding_strategies.append(strategy_copy)
return sharding_strategies return sharding_strategies
...@@ -152,10 +158,13 @@ class LinearModuleHandler(MetaInfoModuleHandler): ...@@ -152,10 +158,13 @@ class LinearModuleHandler(MetaInfoModuleHandler):
op_data_mapping = self.get_operation_data_mapping() op_data_mapping = self.get_operation_data_mapping()
generators = [] generators = []
generators.append( generators.append(
LinearProjectionStrategyGenerator(op_data_mapping, LinearProjectionStrategyGenerator(
self.device_mesh, op_data_mapping,
linear_projection_type='linear', self.device_mesh,
solver_perference=self.solver_perference)) linear_projection_type="linear",
solver_perference=self.solver_perference,
)
)
return generators return generators
def get_operation_data_mapping(self) -> Dict[str, OperationData]: def get_operation_data_mapping(self) -> Dict[str, OperationData]:
...@@ -163,28 +172,34 @@ class LinearModuleHandler(MetaInfoModuleHandler): ...@@ -163,28 +172,34 @@ class LinearModuleHandler(MetaInfoModuleHandler):
# the strategies will be transformed back to its original shape in self.post_process # the strategies will be transformed back to its original shape in self.post_process
input_meta_data = self.node.args[0]._meta_data input_meta_data = self.node.args[0]._meta_data
input_logical_shape = input_meta_data.view(-1, input_meta_data.shape[-1]).shape input_logical_shape = input_meta_data.view(-1, input_meta_data.shape[-1]).shape
physical_input_operand = OperationData(name=str(self.node.args[0]), physical_input_operand = OperationData(
type=OperationDataType.ARG, name=str(self.node.args[0]),
data=input_meta_data, type=OperationDataType.ARG,
logical_shape=input_logical_shape) data=input_meta_data,
physical_other_operand = OperationData(name="weight", logical_shape=input_logical_shape,
type=OperationDataType.PARAM, )
data=self.named_parameters['weight'], physical_other_operand = OperationData(
logical_shape=self.named_parameters['weight'].shape[::-1]) name="weight",
type=OperationDataType.PARAM,
data=self.named_parameters["weight"],
logical_shape=self.named_parameters["weight"].shape[::-1],
)
output_meta_data = self.node._meta_data output_meta_data = self.node._meta_data
output_logical_shape = output_meta_data.view(-1, output_meta_data.shape[-1]).shape output_logical_shape = output_meta_data.view(-1, output_meta_data.shape[-1]).shape
physical_output = OperationData(name=str(self.node), physical_output = OperationData(
type=OperationDataType.OUTPUT, name=str(self.node),
data=output_meta_data, type=OperationDataType.OUTPUT,
logical_shape=output_logical_shape) data=output_meta_data,
logical_shape=output_logical_shape,
)
mapping = {"input": physical_input_operand, "other": physical_other_operand, "output": physical_output} mapping = {"input": physical_input_operand, "other": physical_other_operand, "output": physical_output}
if 'bias' in self.named_parameters is not None: if "bias" in self.named_parameters is not None:
physical_bias_operand = OperationData(name="bias", physical_bias_operand = OperationData(
type=OperationDataType.PARAM, name="bias", type=OperationDataType.PARAM, data=self.named_parameters["bias"]
data=self.named_parameters['bias']) )
mapping['bias'] = physical_bias_operand mapping["bias"] = physical_bias_operand
return mapping return mapping
def post_process(self, strategy: ShardingStrategy) -> Union[ShardingStrategy, List[ShardingStrategy]]: def post_process(self, strategy: ShardingStrategy) -> Union[ShardingStrategy, List[ShardingStrategy]]:
...@@ -194,14 +209,14 @@ class LinearModuleHandler(MetaInfoModuleHandler): ...@@ -194,14 +209,14 @@ class LinearModuleHandler(MetaInfoModuleHandler):
2. the input and output sharding specs are updated to physical shape. 2. the input and output sharding specs are updated to physical shape.
""" """
# switch the dimensions of the transposed weight # switch the dimensions of the transposed weight
strategy = _update_sharding_spec_for_transposed_weight_for_linear(strategy=strategy, weight_name='weight') strategy = _update_sharding_spec_for_transposed_weight_for_linear(strategy=strategy, weight_name="weight")
# create multiple sharding strategies for the inputs # create multiple sharding strategies for the inputs
# as input can be multi-dimensional and the partition dim is only 2D, # as input can be multi-dimensional and the partition dim is only 2D,
# we need to map the partition at dim 0 to one of the first few dimensions of the input # we need to map the partition at dim 0 to one of the first few dimensions of the input
strategies = _convert_logical_sharding_to_physical_sharding_spec_for_linear(strategy=strategy, strategies = _convert_logical_sharding_to_physical_sharding_spec_for_linear(
input_name=str(self.node.args[0]), strategy=strategy, input_name=str(self.node.args[0]), output_name=str(self.node)
output_name=str(self.node)) )
return strategies return strategies
...@@ -215,7 +230,8 @@ class LinearFunctionHandler(MetaInfoNodeHandler): ...@@ -215,7 +230,8 @@ class LinearFunctionHandler(MetaInfoNodeHandler):
op_data_mapping = self.get_operation_data_mapping() op_data_mapping = self.get_operation_data_mapping()
generators = [] generators = []
generators.append( generators.append(
LinearProjectionStrategyGenerator(op_data_mapping, self.device_mesh, linear_projection_type='linear')) LinearProjectionStrategyGenerator(op_data_mapping, self.device_mesh, linear_projection_type="linear")
)
return generators return generators
def get_operation_data_mapping(self) -> Dict[str, OperationData]: def get_operation_data_mapping(self) -> Dict[str, OperationData]:
...@@ -223,10 +239,12 @@ class LinearFunctionHandler(MetaInfoNodeHandler): ...@@ -223,10 +239,12 @@ class LinearFunctionHandler(MetaInfoNodeHandler):
# the strategies will be transformed back to its original shape in self.post_process # the strategies will be transformed back to its original shape in self.post_process
input_meta_data = self.node.args[0]._meta_data input_meta_data = self.node.args[0]._meta_data
input_logical_shape = input_meta_data.view(-1, input_meta_data.shape[-1]).shape input_logical_shape = input_meta_data.view(-1, input_meta_data.shape[-1]).shape
physical_input_operand = OperationData(name=str(self.node.args[0]), physical_input_operand = OperationData(
type=OperationDataType.ARG, name=str(self.node.args[0]),
data=self.node.args[0]._meta_data, type=OperationDataType.ARG,
logical_shape=input_logical_shape) data=self.node.args[0]._meta_data,
logical_shape=input_logical_shape,
)
# check if the other operand is a parameter # check if the other operand is a parameter
if isinstance(self.node.args[1]._meta_data, torch.nn.parameter.Parameter): if isinstance(self.node.args[1]._meta_data, torch.nn.parameter.Parameter):
...@@ -234,10 +252,12 @@ class LinearFunctionHandler(MetaInfoNodeHandler): ...@@ -234,10 +252,12 @@ class LinearFunctionHandler(MetaInfoNodeHandler):
else: else:
data_type = OperationDataType.ARG data_type = OperationDataType.ARG
physical_other_operand = OperationData(name=str(self.node.args[1]), physical_other_operand = OperationData(
type=data_type, name=str(self.node.args[1]),
data=self.node.args[1]._meta_data, type=data_type,
logical_shape=self.node.args[1]._meta_data.shape[::-1]) data=self.node.args[1]._meta_data,
logical_shape=self.node.args[1]._meta_data.shape[::-1],
)
output_meta_data = self.node._meta_data output_meta_data = self.node._meta_data
output_logical_shape = output_meta_data.view(-1, output_meta_data.shape[-1]).shape output_logical_shape = output_meta_data.view(-1, output_meta_data.shape[-1]).shape
physical_output = OperationData( physical_output = OperationData(
...@@ -249,27 +269,28 @@ class LinearFunctionHandler(MetaInfoNodeHandler): ...@@ -249,27 +269,28 @@ class LinearFunctionHandler(MetaInfoNodeHandler):
mapping = {"input": physical_input_operand, "other": physical_other_operand, "output": physical_output} mapping = {"input": physical_input_operand, "other": physical_other_operand, "output": physical_output}
if 'bias' in self.node.kwargs and self.node.kwargs['bias'] is not None: if "bias" in self.node.kwargs and self.node.kwargs["bias"] is not None:
# check if the other operand is a parameter # check if the other operand is a parameter
if isinstance(self.node.kwargs["bias"]._meta_data, torch.nn.parameter.Parameter): if isinstance(self.node.kwargs["bias"]._meta_data, torch.nn.parameter.Parameter):
data_type = OperationDataType.PARAM data_type = OperationDataType.PARAM
else: else:
data_type = OperationDataType.ARG data_type = OperationDataType.ARG
physical_bias_operand = OperationData(name=str(self.node.kwargs["bias"]), physical_bias_operand = OperationData(
type=data_type, name=str(self.node.kwargs["bias"]), type=data_type, data=self.node.kwargs["bias"]._meta_data
data=self.node.kwargs["bias"]._meta_data) )
mapping['bias'] = physical_bias_operand mapping["bias"] = physical_bias_operand
return mapping return mapping
def post_process(self, strategy: ShardingStrategy): def post_process(self, strategy: ShardingStrategy):
# switch the dimensions of the transposed weight # switch the dimensions of the transposed weight
strategy = _update_sharding_spec_for_transposed_weight_for_linear(strategy=strategy, strategy = _update_sharding_spec_for_transposed_weight_for_linear(
weight_name=str(self.node.args[1])) strategy=strategy, weight_name=str(self.node.args[1])
)
# create multiple sharding strategies for the inputs # create multiple sharding strategies for the inputs
# as input can be multi-dimensional and the partition dim is only 2D, # as input can be multi-dimensional and the partition dim is only 2D,
# we need to map the partition at dim 0 to one of the first few dimensions of the input # we need to map the partition at dim 0 to one of the first few dimensions of the input
strategies = _convert_logical_sharding_to_physical_sharding_spec_for_linear(strategy=strategy, strategies = _convert_logical_sharding_to_physical_sharding_spec_for_linear(
input_name=str(self.node.args[0]), strategy=strategy, input_name=str(self.node.args[0]), output_name=str(self.node)
output_name=str(self.node)) )
return strategies return strategies
...@@ -16,7 +16,7 @@ from colossalai.tensor.sharding_spec import ShardingSpecException ...@@ -16,7 +16,7 @@ from colossalai.tensor.sharding_spec import ShardingSpecException
from ..sharding_strategy import OperationData, OperationDataType, ShardingStrategy from ..sharding_strategy import OperationData, OperationDataType, ShardingStrategy
from ..utils import recover_sharding_spec_for_broadcast_shape from ..utils import recover_sharding_spec_for_broadcast_shape
from .node_handler import MetaInfoNodeHandler, NodeHandler from .node_handler import MetaInfoNodeHandler
from .registry import operator_registry from .registry import operator_registry
from .strategy import ( from .strategy import (
BatchedMatMulStrategyGenerator, BatchedMatMulStrategyGenerator,
...@@ -37,6 +37,7 @@ class MatMulType(Enum): ...@@ -37,6 +37,7 @@ class MatMulType(Enum):
MV: matrix-vector product: the 1st tensor is 2D and the 2nd tensor is 1D MV: matrix-vector product: the 1st tensor is 2D and the 2nd tensor is 1D
BMM: batched matrix-matrix multiplication, one tensor is at least 1D and the other is at least 3D BMM: batched matrix-matrix multiplication, one tensor is at least 1D and the other is at least 3D
""" """
DOT = 0 DOT = 0
MM = 1 MM = 1
MV = 2 MV = 2
...@@ -92,26 +93,26 @@ class Padder(BmmTransform): ...@@ -92,26 +93,26 @@ class Padder(BmmTransform):
def apply(self, shape_mapping: Dict[str, List[int]]): def apply(self, shape_mapping: Dict[str, List[int]]):
mapping_copy = deepcopy(shape_mapping) mapping_copy = deepcopy(shape_mapping)
input_shape = mapping_copy['input'] input_shape = mapping_copy["input"]
other_shape = mapping_copy['other'] other_shape = mapping_copy["other"]
if len(input_shape) == 1: if len(input_shape) == 1:
# if the input is a 1D tensor, 1 is prepended to its shape # if the input is a 1D tensor, 1 is prepended to its shape
# and it will be removed afterwards # and it will be removed afterwards
input_shape.insert(0, 1) input_shape.insert(0, 1)
self.padded_dim_mapping['input'] = -2 self.padded_dim_mapping["input"] = -2
self.padded_dim_mapping['output'] = -2 self.padded_dim_mapping["output"] = -2
elif len(other_shape) == 1: elif len(other_shape) == 1:
# if the other is a 1D tensor, 1 is appended to its shape # if the other is a 1D tensor, 1 is appended to its shape
# and it will be removed afterwards # and it will be removed afterwards
other_shape = other_shape.append(1) other_shape = other_shape.append(1)
self.padded_dim_mapping['other'] = -1 self.padded_dim_mapping["other"] = -1
self.padded_dim_mapping['output'] = -1 self.padded_dim_mapping["output"] = -1
return mapping_copy return mapping_copy
def recover(self, op_data_mapping: Dict[str, OperationData], strategy: ShardingStrategy): def recover(self, op_data_mapping: Dict[str, OperationData], strategy: ShardingStrategy):
input_op_data = op_data_mapping['input'] op_data_mapping["input"]
other_op_data = op_data_mapping['other'] op_data_mapping["other"]
def _remove_padded_dim(key, strategy): def _remove_padded_dim(key, strategy):
op_data = op_data_mapping[key] op_data = op_data_mapping[key]
...@@ -131,7 +132,7 @@ class Padder(BmmTransform): ...@@ -131,7 +132,7 @@ class Padder(BmmTransform):
# compute unpadded tensor shape # compute unpadded tensor shape
tensor_shape.pop(padded_dim) tensor_shape.pop(padded_dim)
assert tensor_shape == list(op_data.data.shape), f'{tensor_shape} vs {list(op_data.data.shape)}' assert tensor_shape == list(op_data.data.shape), f"{tensor_shape} vs {list(op_data.data.shape)}"
# update sharding spec # update sharding spec
sharding_spec.__init__(sharding_spec.device_mesh, tensor_shape, unpadded_dim_partition_list) sharding_spec.__init__(sharding_spec.device_mesh, tensor_shape, unpadded_dim_partition_list)
...@@ -142,15 +143,15 @@ class Padder(BmmTransform): ...@@ -142,15 +143,15 @@ class Padder(BmmTransform):
strategy_copy = strategy.clone() strategy_copy = strategy.clone()
# only one of input and other will be padded # only one of input and other will be padded
if 'input' in self.padded_dim_mapping: if "input" in self.padded_dim_mapping:
_remove_padded_dim('input', strategy_copy) _remove_padded_dim("input", strategy_copy)
_remove_padded_dim('output', strategy_copy) _remove_padded_dim("output", strategy_copy)
elif 'other' in self.padded_dim_mapping: elif "other" in self.padded_dim_mapping:
_remove_padded_dim('other', strategy_copy) _remove_padded_dim("other", strategy_copy)
_remove_padded_dim('output', strategy_copy) _remove_padded_dim("output", strategy_copy)
strategies.append(strategy_copy) strategies.append(strategy_copy)
except ShardingSpecException as e: except ShardingSpecException:
pass pass
return strategies return strategies
...@@ -167,8 +168,8 @@ class Broadcaster(BmmTransform): ...@@ -167,8 +168,8 @@ class Broadcaster(BmmTransform):
mapping_copy = shape_mapping.copy() mapping_copy = shape_mapping.copy()
# get shapes # get shapes
input_shape = mapping_copy['input'] input_shape = mapping_copy["input"]
other_shape = mapping_copy['other'] other_shape = mapping_copy["other"]
# sanity check # sanity check
assert len(input_shape) > 1 and len(other_shape) > 1 assert len(input_shape) > 1 and len(other_shape) > 1
...@@ -179,16 +180,16 @@ class Broadcaster(BmmTransform): ...@@ -179,16 +180,16 @@ class Broadcaster(BmmTransform):
# store the broadcast dim info # store the broadcast dim info
input_broadcast_dim_info = get_broadcast_dim_info(bcast_non_matrix_dims, input_shape[:-2]) input_broadcast_dim_info = get_broadcast_dim_info(bcast_non_matrix_dims, input_shape[:-2])
other_broadcast_dim_info = get_broadcast_dim_info(bcast_non_matrix_dims, other_shape[:-2]) other_broadcast_dim_info = get_broadcast_dim_info(bcast_non_matrix_dims, other_shape[:-2])
self.broadcast_dim_info['input'] = input_broadcast_dim_info self.broadcast_dim_info["input"] = input_broadcast_dim_info
self.broadcast_dim_info['other'] = other_broadcast_dim_info self.broadcast_dim_info["other"] = other_broadcast_dim_info
# create the full logical shape # create the full logical shape
input_shape = bcast_non_matrix_dims + input_shape[-2:] input_shape = bcast_non_matrix_dims + input_shape[-2:]
other_shape = bcast_non_matrix_dims + other_shape[-2:] other_shape = bcast_non_matrix_dims + other_shape[-2:]
assert len(input_shape) == len(other_shape) assert len(input_shape) == len(other_shape)
mapping_copy['input'] = input_shape mapping_copy["input"] = input_shape
mapping_copy['other'] = other_shape mapping_copy["other"] = other_shape
return mapping_copy return mapping_copy
...@@ -216,17 +217,18 @@ class Broadcaster(BmmTransform): ...@@ -216,17 +217,18 @@ class Broadcaster(BmmTransform):
physical_sharding_spec, removed_dims = recover_sharding_spec_for_broadcast_shape( physical_sharding_spec, removed_dims = recover_sharding_spec_for_broadcast_shape(
logical_sharding_spec=sharding_spec, logical_sharding_spec=sharding_spec,
logical_shape=sharding_spec.entire_shape, logical_shape=sharding_spec.entire_shape,
physical_shape=tensor_shape_before_broadcast) physical_shape=tensor_shape_before_broadcast,
)
strategy.sharding_specs[op_data] = physical_sharding_spec strategy.sharding_specs[op_data] = physical_sharding_spec
# enumerate all sharding strategies # enumerate all sharding strategies
strategies = [] strategies = []
try: try:
strategy_copy = strategy.clone() strategy_copy = strategy.clone()
_remove_sharding_on_broadcast_dim('input', strategy_copy) _remove_sharding_on_broadcast_dim("input", strategy_copy)
_remove_sharding_on_broadcast_dim('other', strategy_copy) _remove_sharding_on_broadcast_dim("other", strategy_copy)
strategies.append(strategy_copy) strategies.append(strategy_copy)
except ShardingSpecException as e: except ShardingSpecException:
pass pass
return strategies return strategies
...@@ -241,20 +243,20 @@ class Viewer(BmmTransform): ...@@ -241,20 +243,20 @@ class Viewer(BmmTransform):
def apply(self, shape_mapping: Dict[str, List[int]]): def apply(self, shape_mapping: Dict[str, List[int]]):
mapping_copy = shape_mapping.copy() mapping_copy = shape_mapping.copy()
self.batch_dims_before_view = list(mapping_copy['input'][:-2]) self.batch_dims_before_view = list(mapping_copy["input"][:-2])
# get shapes # get shapes
input_shape = shape_mapping['input'] input_shape = shape_mapping["input"]
other_shape = shape_mapping['other'] other_shape = shape_mapping["other"]
# view to 3d tensor # view to 3d tensor
assert len(input_shape) >= 3 and len(other_shape) >= 3 assert len(input_shape) >= 3 and len(other_shape) >= 3
input_shape = [reduce(operator.mul, input_shape[:-2])] + input_shape[-2:] input_shape = [reduce(operator.mul, input_shape[:-2])] + input_shape[-2:]
other_shape = [reduce(operator.mul, other_shape[:-2])] + other_shape[-2:] other_shape = [reduce(operator.mul, other_shape[:-2])] + other_shape[-2:]
output_shape = input_shape[:2] + other_shape[2:] output_shape = input_shape[:2] + other_shape[2:]
mapping_copy['input'] = input_shape mapping_copy["input"] = input_shape
mapping_copy['other'] = other_shape mapping_copy["other"] = other_shape
mapping_copy['output'] = output_shape mapping_copy["output"] = output_shape
return mapping_copy return mapping_copy
def recover(self, op_data_mapping: Dict[str, OperationData], strategy: ShardingStrategy): def recover(self, op_data_mapping: Dict[str, OperationData], strategy: ShardingStrategy):
...@@ -291,11 +293,11 @@ class Viewer(BmmTransform): ...@@ -291,11 +293,11 @@ class Viewer(BmmTransform):
# create a new strategy # create a new strategy
strategy_copy = strategy.clone() strategy_copy = strategy.clone()
try: try:
_update_sharding_spec('input', strategy_copy, i) _update_sharding_spec("input", strategy_copy, i)
_update_sharding_spec('other', strategy_copy, i) _update_sharding_spec("other", strategy_copy, i)
_update_sharding_spec('output', strategy_copy, i) _update_sharding_spec("output", strategy_copy, i)
strategies.append(strategy_copy) strategies.append(strategy_copy)
except ShardingSpecException as e: except ShardingSpecException:
continue continue
return strategies return strategies
...@@ -312,14 +314,14 @@ def _get_bmm_logical_shape(input_shape, other_shape, transforms): ...@@ -312,14 +314,14 @@ def _get_bmm_logical_shape(input_shape, other_shape, transforms):
3. reshape to 3 dimensions 3. reshape to 3 dimensions
""" """
shape_mapping = {'input': input_shape, 'other': other_shape} shape_mapping = {"input": input_shape, "other": other_shape}
for transform in transforms: for transform in transforms:
shape_mapping = transform.apply(shape_mapping) shape_mapping = transform.apply(shape_mapping)
input_shape = shape_mapping.get('input', None) input_shape = shape_mapping.get("input", None)
other_shape = shape_mapping.get('other', None) other_shape = shape_mapping.get("other", None)
output_shape = shape_mapping.get('output', None) output_shape = shape_mapping.get("output", None)
return input_shape, other_shape, output_shape return input_shape, other_shape, output_shape
...@@ -364,7 +366,8 @@ class MatMulHandler(MetaInfoNodeHandler): ...@@ -364,7 +366,8 @@ class MatMulHandler(MetaInfoNodeHandler):
generators.append(MatVecStrategyGenerator(op_data_mapping, self.device_mesh)) generators.append(MatVecStrategyGenerator(op_data_mapping, self.device_mesh))
elif self.matmul_type == MatMulType.MM: elif self.matmul_type == MatMulType.MM:
generators.append( generators.append(
LinearProjectionStrategyGenerator(op_data_mapping, self.device_mesh, linear_projection_type='linear')) LinearProjectionStrategyGenerator(op_data_mapping, self.device_mesh, linear_projection_type="linear")
)
return generators return generators
def get_operation_data_mapping(self) -> Dict[str, OperationData]: def get_operation_data_mapping(self) -> Dict[str, OperationData]:
...@@ -372,7 +375,7 @@ class MatMulHandler(MetaInfoNodeHandler): ...@@ -372,7 +375,7 @@ class MatMulHandler(MetaInfoNodeHandler):
MatMulType.DOT: self._get_logical_shape_for_dot, MatMulType.DOT: self._get_logical_shape_for_dot,
MatMulType.MM: self._get_logical_shape_for_mm, MatMulType.MM: self._get_logical_shape_for_mm,
MatMulType.MV: self._get_logical_shape_for_mv, MatMulType.MV: self._get_logical_shape_for_mv,
MatMulType.BMM: self._get_logical_shape_for_bmm MatMulType.BMM: self._get_logical_shape_for_bmm,
} }
logical_shapes = logical_shape_func[self.matmul_type]() logical_shapes = logical_shape_func[self.matmul_type]()
op_data_mapping = self._get_op_data_mapping(*logical_shapes) op_data_mapping = self._get_op_data_mapping(*logical_shapes)
...@@ -390,20 +393,26 @@ class MatMulHandler(MetaInfoNodeHandler): ...@@ -390,20 +393,26 @@ class MatMulHandler(MetaInfoNodeHandler):
output_logical_shape = torch.Size(output_logical_shape) output_logical_shape = torch.Size(output_logical_shape)
# create op data # create op data
input_op_data = OperationData(name=str(self.node.args[0]), input_op_data = OperationData(
type=OperationDataType.ARG, name=str(self.node.args[0]),
data=self.input_meta_data, type=OperationDataType.ARG,
logical_shape=input_logical_shape) data=self.input_meta_data,
other_op_data = OperationData(name=str(self.node.args[1]), logical_shape=input_logical_shape,
type=OperationDataType.ARG, )
data=self.other_meta_data, other_op_data = OperationData(
logical_shape=other_logical_shape) name=str(self.node.args[1]),
output_op_data = OperationData(name=str(self.node), type=OperationDataType.ARG,
type=OperationDataType.OUTPUT, data=self.other_meta_data,
data=self.output_meta_data, logical_shape=other_logical_shape,
logical_shape=output_logical_shape) )
output_op_data = OperationData(
mapping = {'input': input_op_data, 'other': other_op_data, 'output': output_op_data} name=str(self.node),
type=OperationDataType.OUTPUT,
data=self.output_meta_data,
logical_shape=output_logical_shape,
)
mapping = {"input": input_op_data, "other": other_op_data, "output": output_op_data}
return mapping return mapping
def _get_logical_shape_for_dot(self): def _get_logical_shape_for_dot(self):
...@@ -460,9 +469,11 @@ class MatMulHandler(MetaInfoNodeHandler): ...@@ -460,9 +469,11 @@ class MatMulHandler(MetaInfoNodeHandler):
dim_partition_dict[0] = shard dim_partition_dict[0] = shard
# re-init the sharding spec # re-init the sharding spec
input_sharding_spec.__init__(input_sharding_spec.device_mesh, input_sharding_spec.__init__(
entire_shape=input_physical_shape, input_sharding_spec.device_mesh,
dim_partition_dict=dim_partition_dict) entire_shape=input_physical_shape,
dim_partition_dict=dim_partition_dict,
)
return strategy return strategy
else: else:
return strategy return strategy
...@@ -481,7 +492,8 @@ class MatMulHandler(MetaInfoNodeHandler): ...@@ -481,7 +492,8 @@ class MatMulHandler(MetaInfoNodeHandler):
recovered_stragies.extend(output) recovered_stragies.extend(output)
else: else:
raise TypeError( raise TypeError(
f"Found unexpected output type {type(output)} from the recover method of BmmTransform") f"Found unexpected output type {type(output)} from the recover method of BmmTransform"
)
strategies = recovered_stragies strategies = recovered_stragies
for index, strategies in enumerate(strategies): for index, strategies in enumerate(strategies):
strategies.name = f"{strategies.name}_{index}" strategies.name = f"{strategies.name}_{index}"
......
...@@ -8,7 +8,6 @@ from colossalai.auto_parallel.meta_profiler.shard_metainfo import ShardMetaInfo, ...@@ -8,7 +8,6 @@ from colossalai.auto_parallel.meta_profiler.shard_metainfo import ShardMetaInfo,
from colossalai.auto_parallel.tensor_shard.options import ShardOption, SolverPerference from colossalai.auto_parallel.tensor_shard.options import ShardOption, SolverPerference
from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
OperationData, OperationData,
OperationDataType,
ShardingSpec, ShardingSpec,
ShardingStrategy, ShardingStrategy,
StrategiesVector, StrategiesVector,
...@@ -23,21 +22,23 @@ from .strategy import StrategyGenerator ...@@ -23,21 +22,23 @@ from .strategy import StrategyGenerator
class NodeHandler(ABC): class NodeHandler(ABC):
''' """
The NodeHandler is an abstract class used to generate every possible strategies for an operator node. The NodeHandler is an abstract class used to generate every possible strategies for an operator node.
Args: Args:
node (Node): the input node in node argument list. node (Node): the input node in node argument list.
device_mesh (DeviceMesh): A logical view of a physical mesh. device_mesh (DeviceMesh): A logical view of a physical mesh.
strategies_vector (StrategiesVector): all the strategies generated in this handler will be recorded into the strategies_vector. strategies_vector (StrategiesVector): all the strategies generated in this handler will be recorded into the strategies_vector.
''' """
def __init__(self, def __init__(
node: Node, self,
device_mesh: DeviceMesh, node: Node,
strategies_vector: StrategiesVector, device_mesh: DeviceMesh,
shard_option: ShardOption = ShardOption.STANDARD, strategies_vector: StrategiesVector,
solver_perference: SolverPerference = SolverPerference.STANDARD) -> None: shard_option: ShardOption = ShardOption.STANDARD,
solver_perference: SolverPerference = SolverPerference.STANDARD,
) -> None:
self.node = node self.node = node
self.predecessor_node = list(node._input_nodes.keys()) self.predecessor_node = list(node._input_nodes.keys())
self.successor_node = list(node.users.keys()) self.successor_node = list(node.users.keys())
...@@ -68,8 +69,9 @@ class NodeHandler(ABC): ...@@ -68,8 +69,9 @@ class NodeHandler(ABC):
current_sharding_spec = strategy.sharding_specs[op_data] current_sharding_spec = strategy.sharding_specs[op_data]
# get the sharding specs for this node generated # get the sharding specs for this node generated
# in its own node handler # in its own node handler
assert hasattr(node, 'strategies_vector'), \ assert hasattr(
f'The predecessor node {node_name} has no strategy vector to compute the resharding cost.' node, "strategies_vector"
), f"The predecessor node {node_name} has no strategy vector to compute the resharding cost."
prev_strategy_vector = node.strategies_vector prev_strategy_vector = node.strategies_vector
prev_sharding_specs = [ prev_sharding_specs = [
prev_strategy.get_sharding_spec_by_name(node_name) for prev_strategy in prev_strategy_vector prev_strategy.get_sharding_spec_by_name(node_name) for prev_strategy in prev_strategy_vector
...@@ -80,10 +82,10 @@ class NodeHandler(ABC): ...@@ -80,10 +82,10 @@ class NodeHandler(ABC):
resharding_costs[node] = [] resharding_costs[node] = []
def _compute_resharding_cost( def _compute_resharding_cost(
prev_sharding_spec: Union[ShardingSpec, prev_sharding_spec: Union[ShardingSpec, List[ShardingSpec]],
List[ShardingSpec]], current_sharding_spec: Union[ShardingSpec, current_sharding_spec: Union[ShardingSpec, List[ShardingSpec]],
List[ShardingSpec]], data: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]],
data: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]]) -> TrainCycleItem: ) -> TrainCycleItem:
""" """
This is a helper function to compute the resharding cost for a specific strategy of a node. This is a helper function to compute the resharding cost for a specific strategy of a node.
""" """
...@@ -94,30 +96,35 @@ class NodeHandler(ABC): ...@@ -94,30 +96,35 @@ class NodeHandler(ABC):
dtype = data.dtype dtype = data.dtype
size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size() size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
_, _, consistency_cost = shape_consistency_manager.shape_consistency( _, _, consistency_cost = shape_consistency_manager.shape_consistency(
prev_sharding_spec, current_sharding_spec) prev_sharding_spec, current_sharding_spec
)
resharding_cost = TrainCycleItem(fwd=consistency_cost["forward"] * size_per_elem_bytes,
bwd=consistency_cost["backward"] * size_per_elem_bytes, resharding_cost = TrainCycleItem(
total=consistency_cost["total"] * size_per_elem_bytes) fwd=consistency_cost["forward"] * size_per_elem_bytes,
bwd=consistency_cost["backward"] * size_per_elem_bytes,
total=consistency_cost["total"] * size_per_elem_bytes,
)
return resharding_cost return resharding_cost
else: else:
# This raise is used to check if we have missed any type of data. # This raise is used to check if we have missed any type of data.
# It could be merged into Parameter branch, which means we won't handle # It could be merged into Parameter branch, which means we won't handle
# non-tensor arguments. # non-tensor arguments.
raise ValueError(f'Unsupported data type {type(data)}') raise ValueError(f"Unsupported data type {type(data)}")
else: else:
assert isinstance(prev_sharding_spec, (tuple, list)), \ assert isinstance(
f'prev_sharding_spec should be in type of ShardingSpec, List[ShardingSpec], \ prev_sharding_spec, (tuple, list)
or Tuple[ShardingSpec], but got {type(prev_sharding_spec)}' ), f"prev_sharding_spec should be in type of ShardingSpec, List[ShardingSpec], \
or Tuple[ShardingSpec], but got {type(prev_sharding_spec)}"
fwd_cost = 0 fwd_cost = 0
bwd_cost = 0 bwd_cost = 0
total_cost = 0 total_cost = 0
for index, (prev_sharding_spec_item, for index, (prev_sharding_spec_item, current_sharding_spec_item) in enumerate(
current_sharding_spec_item) in enumerate(zip(prev_sharding_spec, zip(prev_sharding_spec, current_sharding_spec)
current_sharding_spec)): ):
item_cost = _compute_resharding_cost(prev_sharding_spec_item, current_sharding_spec_item, item_cost = _compute_resharding_cost(
data[index]) prev_sharding_spec_item, current_sharding_spec_item, data[index]
)
fwd_cost += item_cost.fwd fwd_cost += item_cost.fwd
bwd_cost += item_cost.bwd bwd_cost += item_cost.bwd
total_cost += item_cost.total total_cost += item_cost.total
...@@ -138,17 +145,17 @@ class NodeHandler(ABC): ...@@ -138,17 +145,17 @@ class NodeHandler(ABC):
This function is used to get the target function for the node handler. This function is used to get the target function for the node handler.
The target function is used to analyze the costs of strategies. The target function is used to analyze the costs of strategies.
""" """
if self.node.op in ('placeholder', 'get_attr', 'output'): if self.node.op in ("placeholder", "get_attr", "output"):
return None return None
if self.node.op == 'call_module': if self.node.op == "call_module":
target = self.node.graph.owning_module.get_submodule(self.node.target) target = self.node.graph.owning_module.get_submodule(self.node.target)
elif self.node.op == 'call_function': elif self.node.op == "call_function":
target = self.node.target target = self.node.target
elif self.node.op == 'call_method': elif self.node.op == "call_method":
target = getattr(self.node.args[0]._meta_data.__class__, self.node.target) target = getattr(self.node.args[0]._meta_data.__class__, self.node.target)
else: else:
raise ValueError(f'Unsupported node type: {self.node.op}') raise ValueError(f"Unsupported node type: {self.node.op}")
return target return target
...@@ -221,7 +228,6 @@ class NodeHandler(ABC): ...@@ -221,7 +228,6 @@ class NodeHandler(ABC):
""" """
Define which generators should be used by this NodeHandler object. Define which generators should be used by this NodeHandler object.
""" """
pass
@abstractmethod @abstractmethod
def get_operation_data_mapping(self) -> Dict[str, OperationData]: def get_operation_data_mapping(self) -> Dict[str, OperationData]:
...@@ -244,7 +250,6 @@ class NodeHandler(ABC): ...@@ -244,7 +250,6 @@ class NodeHandler(ABC):
"output": Operand(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data), "output": Operand(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data),
} }
""" """
pass
class MetaInfoNodeHandler(NodeHandler): class MetaInfoNodeHandler(NodeHandler):
...@@ -278,19 +283,19 @@ class MetaInfoNodeHandler(NodeHandler): ...@@ -278,19 +283,19 @@ class MetaInfoNodeHandler(NodeHandler):
else: else:
logger = get_dist_logger() logger = get_dist_logger()
logger.warning(f'The target function {target} is not patched yet, ') logger.warning(f"The target function {target} is not patched yet, ")
return self.strategies_vector return self.strategies_vector
class ModuleHandler(NodeHandler): class ModuleHandler(NodeHandler):
def __init__(self, *args, **kwargs) -> None: def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
# set attributes to access module parameters for convenience # set attributes to access module parameters for convenience
assert self.node.graph.owning_module is not None, \ assert (
f'The graph is not associated with a module, please make sure it can be used to instantiate a GraphModule object.' self.node.graph.owning_module is not None
), f"The graph is not associated with a module, please make sure it can be used to instantiate a GraphModule object."
module = self.node.graph.owning_module.get_submodule(self.node.target) module = self.node.graph.owning_module.get_submodule(self.node.target)
named_parameters = list(module.named_parameters(recurse=False)) named_parameters = list(module.named_parameters(recurse=False))
named_buffers = list(module.named_buffers(recurse=False)) named_buffers = list(module.named_buffers(recurse=False))
...@@ -333,6 +338,6 @@ class MetaInfoModuleHandler(ModuleHandler): ...@@ -333,6 +338,6 @@ class MetaInfoModuleHandler(ModuleHandler):
else: else:
logger = get_dist_logger() logger = get_dist_logger()
logger.warning(f'The target function {target} is not patched yet') logger.warning(f"The target function {target} is not patched yet")
return self.strategies_vector return self.strategies_vector
...@@ -3,11 +3,11 @@ from typing import Dict, List ...@@ -3,11 +3,11 @@ from typing import Dict, List
import torch import torch
from ..sharding_strategy import OperationData, OperationDataType from ..sharding_strategy import OperationData, OperationDataType
from .node_handler import MetaInfoModuleHandler, ModuleHandler from .node_handler import MetaInfoModuleHandler
from .registry import operator_registry from .registry import operator_registry
from .strategy import NormalPoolStrategyGenerator, StrategyGenerator from .strategy import NormalPoolStrategyGenerator, StrategyGenerator
__all__ = ['NormPoolingHandler'] __all__ = ["NormPoolingHandler"]
@operator_registry.register(torch.nn.MaxPool1d) @operator_registry.register(torch.nn.MaxPool1d)
...@@ -30,9 +30,9 @@ class NormPoolingHandler(MetaInfoModuleHandler): ...@@ -30,9 +30,9 @@ class NormPoolingHandler(MetaInfoModuleHandler):
def get_operation_data_mapping(self) -> Dict[str, OperationData]: def get_operation_data_mapping(self) -> Dict[str, OperationData]:
# use transposed shape for strategies # use transposed shape for strategies
# the strategies will be transformed back to its original shape in self.post_process # the strategies will be transformed back to its original shape in self.post_process
physical_input_operand = OperationData(name=str(self.node.args[0]), physical_input_operand = OperationData(
type=OperationDataType.ARG, name=str(self.node.args[0]), type=OperationDataType.ARG, data=self.node.args[0]._meta_data
data=self.node.args[0]._meta_data) )
physical_weight_operand = OperationData(name="kernel", type=OperationDataType.ARG, data=self.module.kernel_size) physical_weight_operand = OperationData(name="kernel", type=OperationDataType.ARG, data=self.module.kernel_size)
physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data) physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data)
......
...@@ -8,7 +8,7 @@ from ..sharding_strategy import OperationData, OperationDataType, StrategiesVect ...@@ -8,7 +8,7 @@ from ..sharding_strategy import OperationData, OperationDataType, StrategiesVect
from .node_handler import NodeHandler from .node_handler import NodeHandler
from .strategy import OutputGenerator, StrategyGenerator from .strategy import OutputGenerator, StrategyGenerator
__all__ = ['OutputHandler'] __all__ = ["OutputHandler"]
class OutputHandler(NodeHandler): class OutputHandler(NodeHandler):
...@@ -16,8 +16,9 @@ class OutputHandler(NodeHandler): ...@@ -16,8 +16,9 @@ class OutputHandler(NodeHandler):
A OutputHandler which deals with the sharding strategies for Output Node. A OutputHandler which deals with the sharding strategies for Output Node.
""" """
def __init__(self, node: torch.fx.node.Node, device_mesh: DeviceMesh, strategies_vector: StrategiesVector, def __init__(
output_option: str) -> None: self, node: torch.fx.node.Node, device_mesh: DeviceMesh, strategies_vector: StrategiesVector, output_option: str
) -> None:
super().__init__(node, device_mesh, strategies_vector) super().__init__(node, device_mesh, strategies_vector)
self.output_option = output_option self.output_option = output_option
...@@ -35,11 +36,11 @@ class OutputHandler(NodeHandler): ...@@ -35,11 +36,11 @@ class OutputHandler(NodeHandler):
for index, input_node in enumerate(self.predecessor_node): for index, input_node in enumerate(self.predecessor_node):
input_meta_data = input_node._meta_data input_meta_data = input_node._meta_data
physical_inputs = OperationData(name=str(input_node), type=OperationDataType.ARG, data=input_meta_data) physical_inputs = OperationData(name=str(input_node), type=OperationDataType.ARG, data=input_meta_data)
name_key = f'input_{index}' name_key = f"input_{index}"
mapping[name_key] = physical_inputs mapping[name_key] = physical_inputs
output_meta_data.append(input_meta_data) output_meta_data.append(input_meta_data)
assert len(output_meta_data) > 0, f'Output node {self.node} has no input node.' assert len(output_meta_data) > 0, f"Output node {self.node} has no input node."
if len(output_meta_data) == 1: if len(output_meta_data) == 1:
output_meta_data = output_meta_data[0] output_meta_data = output_meta_data[0]
else: else:
......
...@@ -7,7 +7,7 @@ from .node_handler import NodeHandler ...@@ -7,7 +7,7 @@ from .node_handler import NodeHandler
from .registry import operator_registry from .registry import operator_registry
from .strategy import PermuteGenerator, StrategyGenerator from .strategy import PermuteGenerator, StrategyGenerator
__all__ = ['PermuteHandler'] __all__ = ["PermuteHandler"]
@operator_registry.register(torch.Tensor.permute) @operator_registry.register(torch.Tensor.permute)
...@@ -34,14 +34,14 @@ class PermuteHandler(NodeHandler): ...@@ -34,14 +34,14 @@ class PermuteHandler(NodeHandler):
physical_input_operand = OperationData(name=str(self.node.args[0]), type=data_type, data=input_data) physical_input_operand = OperationData(name=str(self.node.args[0]), type=data_type, data=input_data)
permute_dims = [] permute_dims = []
if self.node.op == 'call_method': if self.node.op == "call_method":
# torch.Tensor.permute (input, *dims) # torch.Tensor.permute (input, *dims)
for arg in self.node.args: for arg in self.node.args:
if isinstance(arg, torch.fx.Node): if isinstance(arg, torch.fx.Node):
if isinstance(arg._meta_data, int): if isinstance(arg._meta_data, int):
permute_dims.append(arg._meta_data) permute_dims.append(arg._meta_data)
else: else:
assert isinstance(arg, int), 'The argument in permute node should be either type of Node or int.' assert isinstance(arg, int), "The argument in permute node should be either type of Node or int."
permute_dims.append(arg) permute_dims.append(arg)
else: else:
# torch.permute (input, dims) # torch.permute (input, dims)
...@@ -51,8 +51,8 @@ class PermuteHandler(NodeHandler): ...@@ -51,8 +51,8 @@ class PermuteHandler(NodeHandler):
permute_dims.extend(arg._meta_data) permute_dims.extend(arg._meta_data)
else: else:
assert isinstance( assert isinstance(
arg, arg, (tuple, list)
(tuple, list)), 'The argument in permute node should be type of Node, Tuple[int] or List[int].' ), "The argument in permute node should be type of Node, Tuple[int] or List[int]."
permute_dims.extend(arg) permute_dims.extend(arg)
num_dims = self.node._meta_data.dim() num_dims = self.node._meta_data.dim()
...@@ -61,7 +61,7 @@ class PermuteHandler(NodeHandler): ...@@ -61,7 +61,7 @@ class PermuteHandler(NodeHandler):
if permute_dims[i] < 0: if permute_dims[i] < 0:
permute_dims[i] += num_dims permute_dims[i] += num_dims
physical_shape_operand = OperationData(name='permute_dims', type=OperationDataType.ARG, data=list(permute_dims)) physical_shape_operand = OperationData(name="permute_dims", type=OperationDataType.ARG, data=list(permute_dims))
output_data = self.node._meta_data output_data = self.node._meta_data
physical_output_operand = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=output_data) physical_output_operand = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=output_data)
...@@ -69,7 +69,7 @@ class PermuteHandler(NodeHandler): ...@@ -69,7 +69,7 @@ class PermuteHandler(NodeHandler):
mapping = { mapping = {
"input": physical_input_operand, "input": physical_input_operand,
"permute_dims": physical_shape_operand, "permute_dims": physical_shape_operand,
"output": physical_output_operand "output": physical_output_operand,
} }
return mapping return mapping
...@@ -8,7 +8,7 @@ from ..sharding_strategy import OperationData, OperationDataType, StrategiesVect ...@@ -8,7 +8,7 @@ from ..sharding_strategy import OperationData, OperationDataType, StrategiesVect
from .node_handler import NodeHandler from .node_handler import NodeHandler
from .strategy import PlaceholderGenerator, StrategyGenerator from .strategy import PlaceholderGenerator, StrategyGenerator
__all__ = ['PlaceholderHandler'] __all__ = ["PlaceholderHandler"]
class PlaceholderHandler(NodeHandler): class PlaceholderHandler(NodeHandler):
...@@ -16,8 +16,9 @@ class PlaceholderHandler(NodeHandler): ...@@ -16,8 +16,9 @@ class PlaceholderHandler(NodeHandler):
A PlaceholderHandler which deals with the sharding strategies for Placeholder Node. A PlaceholderHandler which deals with the sharding strategies for Placeholder Node.
""" """
def __init__(self, node: Node, device_mesh: DeviceMesh, strategies_vector: StrategiesVector, def __init__(
placeholder_option: str) -> None: self, node: Node, device_mesh: DeviceMesh, strategies_vector: StrategiesVector, placeholder_option: str
) -> None:
super().__init__(node, device_mesh, strategies_vector) super().__init__(node, device_mesh, strategies_vector)
self.placeholder_option = placeholder_option self.placeholder_option = placeholder_option
...@@ -25,7 +26,8 @@ class PlaceholderHandler(NodeHandler): ...@@ -25,7 +26,8 @@ class PlaceholderHandler(NodeHandler):
op_data_mapping = self.get_operation_data_mapping() op_data_mapping = self.get_operation_data_mapping()
generators = [] generators = []
generators.append( generators.append(
PlaceholderGenerator(op_data_mapping, self.device_mesh, placeholder_option=self.placeholder_option)) PlaceholderGenerator(op_data_mapping, self.device_mesh, placeholder_option=self.placeholder_option)
)
return generators return generators
def get_operation_data_mapping(self) -> Dict[str, OperationData]: def get_operation_data_mapping(self) -> Dict[str, OperationData]:
......
class Registry: class Registry:
def __init__(self, name): def __init__(self, name):
self.name = name self.name = name
self.store = {} self.store = {}
def register(self, source): def register(self, source):
def wrapper(func): def wrapper(func):
if isinstance(source, (list, tuple)): if isinstance(source, (list, tuple)):
# support register a list of items for this func # support register a list of items for this func
...@@ -18,7 +16,7 @@ class Registry: ...@@ -18,7 +16,7 @@ class Registry:
return wrapper return wrapper
def get(self, source): def get(self, source):
assert source in self.store, f'{source} not found in the {self.name} registry' assert source in self.store, f"{source} not found in the {self.name} registry"
target = self.store[source] target = self.store[source]
return target return target
...@@ -26,4 +24,4 @@ class Registry: ...@@ -26,4 +24,4 @@ class Registry:
return source in self.store return source in self.store
operator_registry = Registry('operator') operator_registry = Registry("operator")
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment