Unverified commit 079bf3cb, authored by Hongxin Liu and committed by GitHub
Browse files

[misc] update pre-commit and run all files (#4752)

* [misc] update pre-commit

* [misc] run pre-commit

* [misc] remove useless configuration files

* [misc] ignore cuda for clang-format
parent 3c6b831c
......@@ -3,7 +3,6 @@ from typing import Dict, List, Tuple
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.fx import GraphModule
from torch.fx.graph import Graph
from colossalai._analyzer.fx.codegen import ActivationCheckpointCodeGen
......@@ -14,27 +13,32 @@ from colossalai.auto_parallel.passes.runtime_apply_pass import runtime_apply_pas
from colossalai.auto_parallel.passes.runtime_preparation_pass import runtime_preparation_pass
from colossalai.auto_parallel.tensor_shard.options import DataloaderOption, ShardOption, SolverOptions, SolverPerference
from colossalai.auto_parallel.tensor_shard.sharding_strategy import CommAction
from colossalai.auto_parallel.tensor_shard.solver import CostGraph, GraphAnalyser, Solver, StrategiesConstructor
from colossalai.auto_parallel.tensor_shard.solver import CostGraph, Solver, StrategiesConstructor
from colossalai.device.alpha_beta_profiler import AlphaBetaProfiler
from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.sharding_spec import ShardingSpec
class ModuleWrapper(nn.Module):
'''
"""
This class is used to wrap the original module, and add the sharding_spec_dict, origin_spec_dict, comm_actions_dict
into the forward function.
'''
def __init__(self, module: ColoGraphModule, sharding_spec_dict: Dict[int, List[ShardingSpec]],
origin_spec_dict: Dict[int, ShardingSpec], comm_actions_dict: Dict[int, Dict[str, CommAction]]):
'''
"""
def __init__(
self,
module: ColoGraphModule,
sharding_spec_dict: Dict[int, List[ShardingSpec]],
origin_spec_dict: Dict[int, ShardingSpec],
comm_actions_dict: Dict[int, Dict[str, CommAction]],
):
"""
Args:
module: the original module
sharding_spec_dict: The sharding_spec_dict is used to record the target sharding specs of each tensor required in user node.
origin_spec_dict: The origin_spec_dict is used to record the original sharding spec of each tensor.
comm_actions_dict: The comm_actions_dict is used to record the communication actions of each tensor.
'''
"""
super(ModuleWrapper, self).__init__()
self.module = module
self.sharding_spec_dict = sharding_spec_dict
......@@ -42,67 +46,68 @@ class ModuleWrapper(nn.Module):
self.comm_actions_dict = comm_actions_dict
def forward(self, *args, **kwargs):
    """
    Run the wrapped module with the recorded sharding metadata attached.

    All positional and keyword arguments are passed through unchanged. The
    three dictionaries stored on this wrapper are supplied to the wrapped
    module as the ``sharding_spec_convert_dict``,
    ``origin_node_sharding_spec_dict`` and ``comm_actions_dict`` keyword
    arguments.

    Returns:
        Whatever the wrapped module returns.
    """
    return self.module(
        *args,
        sharding_spec_convert_dict=self.sharding_spec_dict,
        origin_node_sharding_spec_dict=self.origin_spec_dict,
        comm_actions_dict=self.comm_actions_dict,
        **kwargs,
    )
def extract_meta_args_from_dataloader(data_loader: torch.utils.data.DataLoader, data_process_func: callable):
    """
    Extract the meta_args from the dataloader under the instruction of the data_process_func.

    NOTE: this is still an unimplemented placeholder. Both arguments are
    currently ignored and ``None`` is returned.

    Args:
        data_loader: the dataloader the meta_args are meant to be extracted from.
        data_process_func: the callable meant to describe how a batch is turned into meta_args.

    Returns:
        None, until the function is implemented.
    """
    # TODO: implement this function
    return None
def extract_alpha_beta_for_device_mesh(alpha_beta_dict: Dict[Tuple[int], Tuple[float]], logical_mesh_shape: Tuple[int]):
    """
    Extract the mesh_alpha and mesh_beta for the given logical_mesh_shape from the alpha_beta_dict.

    These two values are intended to be used to estimate the communication cost.

    NOTE: this is still an unimplemented placeholder. Both arguments are
    currently ignored and ``None`` is returned, so a caller that unpacks the
    result into ``mesh_alpha, mesh_beta`` will fail until this is implemented.

    Args:
        alpha_beta_dict: mapping from a tuple of ints to a tuple of floats (alpha/beta values).
        logical_mesh_shape: the logical mesh shape to extract the values for.

    Returns:
        None, until the function is implemented.
    """
    # TODO: implement this function
    return None
def build_strategy_constructor(graph: Graph, device_mesh: DeviceMesh, solver_preference: str, dataloader_option: str,
shard_option: str):
'''
def build_strategy_constructor(
graph: Graph, device_mesh: DeviceMesh, solver_preference: str, dataloader_option: str, shard_option: str
):
"""
This method is used to build the strategy_constructor for the given graph.
After this method, each node in the graph will have a strategies_vector which
is constructed by the related node handler.
'''
if solver_preference == 'standard':
"""
if solver_preference == "standard":
solver_preference = SolverPerference.STANDARD
elif solver_preference == 'tp':
elif solver_preference == "tp":
solver_preference = SolverPerference.TP
elif solver_preference == 'dp':
elif solver_preference == "dp":
solver_preference = SolverPerference.DP
else:
raise ValueError(f'Invalid solver_preference: {solver_preference}')
raise ValueError(f"Invalid solver_preference: {solver_preference}")
if dataloader_option == 'replicated':
if dataloader_option == "replicated":
dataloader_option = DataloaderOption.REPLICATED
elif dataloader_option == 'distributed':
elif dataloader_option == "distributed":
dataloader_option = DataloaderOption.DISTRIBUTED
else:
raise ValueError(f'Invalid dataloader_option: {dataloader_option}')
raise ValueError(f"Invalid dataloader_option: {dataloader_option}")
if shard_option == 'standard':
if shard_option == "standard":
shard_option = ShardOption.STANDARD
elif shard_option == 'shard':
elif shard_option == "shard":
shard_option = ShardOption.SHARD
elif shard_option == 'shard_last_axis':
elif shard_option == "shard_last_axis":
shard_option = ShardOption.SHARD_LAST_AXIS
elif shard_option == 'full_shard':
elif shard_option == "full_shard":
shard_option = ShardOption.FULL_SHARD
else:
raise ValueError(f'Invalid shard_option: {shard_option}')
raise ValueError(f"Invalid shard_option: {shard_option}")
solver_options = SolverOptions(solver_perference=solver_preference,
dataloader_option=dataloader_option,
shard_option=shard_option)
solver_options = SolverOptions(
solver_perference=solver_preference, dataloader_option=dataloader_option, shard_option=shard_option
)
strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
strategies_constructor.build_strategies_and_cost()
......@@ -110,10 +115,10 @@ def build_strategy_constructor(graph: Graph, device_mesh: DeviceMesh, solver_pre
def solve_solution(gm: ColoGraphModule, strategy_constructor: StrategiesConstructor, memory_budget: float = -1.0):
'''
"""
This method is used to solve the best solution for the given graph.
The solution is a list of integers, each integer represents the best strategy index of the corresponding node.
'''
"""
# temporarily we use all nodes as liveness list, we count the backward memory cost together with
# forward memory cost into the node memory cost, and no activation checkpoint is used in this phase.
# graph_analyser = GraphAnalyser(gm)
......@@ -127,23 +132,23 @@ def solve_solution(gm: ColoGraphModule, strategy_constructor: StrategiesConstruc
return solution
def transform_to_sharded_model(gm: ColoGraphModule,
meta_args: Dict,
solution: List[int],
device_mesh: DeviceMesh,
strategies_constructor: StrategiesConstructor,
overlap: bool = False):
'''
def transform_to_sharded_model(
gm: ColoGraphModule,
meta_args: Dict,
solution: List[int],
device_mesh: DeviceMesh,
strategies_constructor: StrategiesConstructor,
overlap: bool = False,
):
"""
This method is used to transform the original graph to the sharded graph.
The model parameters will be sharded according to the solution and the grad hooks
will be added to the sharded graph using the runtime_preparation_pass.
The communication node will be added into the graph using the runtime_apply_pass.
'''
gm, sharding_spec_dict, origin_spec_dict, comm_actions_dict = runtime_preparation_pass(gm,
solution,
device_mesh,
strategies_constructor,
overlap=overlap)
"""
gm, sharding_spec_dict, origin_spec_dict, comm_actions_dict = runtime_preparation_pass(
gm, solution, device_mesh, strategies_constructor, overlap=overlap
)
gm = runtime_apply_pass(gm)
shape_prop_pass(gm, *meta_args.values(), sharding_spec_dict, origin_spec_dict, comm_actions_dict)
gm.recompile()
......@@ -152,12 +157,14 @@ def transform_to_sharded_model(gm: ColoGraphModule,
return gm, sharding_spec_dicts
def initialize_device_mesh(world_size: int = -1,
physical_devices: List[int] = None,
alpha_beta_dict: Dict[Tuple[int], Tuple[float]] = None,
logical_mesh_shape: Tuple[int] = None,
logical_mesh_id: torch.Tensor = None):
'''
def initialize_device_mesh(
world_size: int = -1,
physical_devices: List[int] = None,
alpha_beta_dict: Dict[Tuple[int], Tuple[float]] = None,
logical_mesh_shape: Tuple[int] = None,
logical_mesh_id: torch.Tensor = None,
):
"""
This method is used to initialize the device mesh.
Args:
......@@ -170,7 +177,7 @@ def initialize_device_mesh(world_size: int = -1,
logical_mesh_shape(optional): the logical_mesh_shape is used to specify the logical
mesh shape.
logical_mesh_id(optional): the logical_mesh_id is used to specify the logical mesh id.
'''
"""
# if world_size is not set, use the world size from torch.distributed
if world_size == -1:
world_size = dist.get_world_size()
......@@ -201,27 +208,31 @@ def initialize_device_mesh(world_size: int = -1,
# extract alpha and beta values for the chosen logical mesh shape
mesh_alpha, mesh_beta = extract_alpha_beta_for_device_mesh(alpha_beta_dict, logical_mesh_id)
device_mesh = DeviceMesh(physical_mesh_id=physical_mesh,
logical_mesh_id=logical_mesh_id,
mesh_alpha=mesh_alpha,
mesh_beta=mesh_beta,
init_process_group=True)
device_mesh = DeviceMesh(
physical_mesh_id=physical_mesh,
logical_mesh_id=logical_mesh_id,
mesh_alpha=mesh_alpha,
mesh_beta=mesh_beta,
init_process_group=True,
)
return device_mesh
def initialize_model(model: nn.Module,
meta_args: Dict[str, torch.Tensor],
device_mesh: DeviceMesh,
memory_budget: float = -1.0,
overlap: bool = False,
solver_preference: str = 'standard',
dataloader_option: str = 'replicated',
shard_option: str = 'standard',
save_solver_solution: bool = False,
load_solver_solution: bool = False,
solution_path: str = None,
return_solution: bool = False):
'''
def initialize_model(
model: nn.Module,
meta_args: Dict[str, torch.Tensor],
device_mesh: DeviceMesh,
memory_budget: float = -1.0,
overlap: bool = False,
solver_preference: str = "standard",
dataloader_option: str = "replicated",
shard_option: str = "standard",
save_solver_solution: bool = False,
load_solver_solution: bool = False,
solution_path: str = None,
return_solution: bool = False,
):
"""
This method is used to initialize the sharded model which could be used as normal pytorch model.
Args:
......@@ -246,7 +257,7 @@ def initialize_model(model: nn.Module,
return_solution(optional): if the return_solution is True, the solution will be returned. The returned
solution will be used to debug or help to analyze the sharding result. Therefore, we will not just
return a series of integers, but return the best strategies.
'''
"""
tracer = ColoTracer(trace_act_ckpt=True, bias_addition_split=True)
graph = tracer.trace(root=model, meta_args=meta_args)
......@@ -256,11 +267,13 @@ def initialize_model(model: nn.Module,
shape_prop_pass(gm, *meta_args.values())
gm.recompile()
strategies_constructor = build_strategy_constructor(graph,
device_mesh,
solver_preference=solver_preference,
dataloader_option=dataloader_option,
shard_option=shard_option)
strategies_constructor = build_strategy_constructor(
graph,
device_mesh,
solver_preference=solver_preference,
dataloader_option=dataloader_option,
shard_option=shard_option,
)
if load_solver_solution:
solution = torch.load(solution_path)
else:
......@@ -268,8 +281,9 @@ def initialize_model(model: nn.Module,
if save_solver_solution:
torch.save(solution, solution_path)
gm, sharding_spec_dicts = transform_to_sharded_model(gm, meta_args, solution, device_mesh, strategies_constructor,
overlap)
gm, sharding_spec_dicts = transform_to_sharded_model(
gm, meta_args, solution, device_mesh, strategies_constructor, overlap
)
model_to_return = ModuleWrapper(gm, *sharding_spec_dicts)
......@@ -277,28 +291,30 @@ def initialize_model(model: nn.Module,
solution_to_return = []
nodes = [strategies_vector.node for strategies_vector in strategies_constructor.leaf_strategies]
for index, node in enumerate(nodes):
solution_to_return.append(f'{node.name} {node.strategies_vector[solution[index]].name}')
solution_to_return.append(f"{node.name} {node.strategies_vector[solution[index]].name}")
return model_to_return, solution_to_return
else:
return model_to_return
def autoparallelize(model: nn.Module,
meta_args: Dict[str, torch.Tensor] = None,
data_loader: torch.utils.data.DataLoader = None,
data_process_func: callable = None,
alpha_beta_dict: Dict[Tuple[int], Tuple[float]] = None,
logical_mesh_shape: Tuple[int] = None,
logical_mesh_id: torch.Tensor = None,
solver_preference: str = 'standard',
dataloader_option: str = 'replicated',
shard_option: str = 'standard',
save_solver_solution: bool = False,
load_solver_solution: bool = False,
solver_solution_path: str = None,
return_solution: bool = False,
memory_budget: float = -1.0):
'''
def autoparallelize(
model: nn.Module,
meta_args: Dict[str, torch.Tensor] = None,
data_loader: torch.utils.data.DataLoader = None,
data_process_func: callable = None,
alpha_beta_dict: Dict[Tuple[int], Tuple[float]] = None,
logical_mesh_shape: Tuple[int] = None,
logical_mesh_id: torch.Tensor = None,
solver_preference: str = "standard",
dataloader_option: str = "replicated",
shard_option: str = "standard",
save_solver_solution: bool = False,
load_solver_solution: bool = False,
solver_solution_path: str = None,
return_solution: bool = False,
memory_budget: float = -1.0,
):
"""
This method is used to initialize the device mesh, extract the meta_args, and
use them to create a sharded model.
......@@ -329,24 +345,26 @@ def autoparallelize(model: nn.Module,
return_solution(optional): if the return_solution is True, the solution will be returned.
memory_budget(optional): the max cuda memory could be used. If the memory budget is -1.0,
the memory budget will be infinity.
'''
device_mesh = initialize_device_mesh(alpha_beta_dict=alpha_beta_dict,
logical_mesh_shape=logical_mesh_shape,
logical_mesh_id=logical_mesh_id)
"""
device_mesh = initialize_device_mesh(
alpha_beta_dict=alpha_beta_dict, logical_mesh_shape=logical_mesh_shape, logical_mesh_id=logical_mesh_id
)
if meta_args is None:
meta_args = extract_meta_args_from_dataloader(data_loader, data_process_func)
rst_to_unpack = initialize_model(model,
meta_args,
device_mesh,
solver_preference=solver_preference,
dataloader_option=dataloader_option,
shard_option=shard_option,
save_solver_solution=save_solver_solution,
load_solver_solution=load_solver_solution,
solution_path=solver_solution_path,
return_solution=return_solution,
memory_budget=memory_budget)
rst_to_unpack = initialize_model(
model,
meta_args,
device_mesh,
solver_preference=solver_preference,
dataloader_option=dataloader_option,
shard_option=shard_option,
save_solver_solution=save_solver_solution,
load_solver_solution=load_solver_solution,
solution_path=solver_solution_path,
return_solution=return_solution,
memory_budget=memory_budget,
)
if return_solution:
model, solution = rst_to_unpack
......
......@@ -25,11 +25,33 @@ from .view_handler import ViewHandler
from .where_handler import WhereHandler
__all__ = [
'LinearFunctionHandler', 'LinearModuleHandler', 'BMMFunctionHandler', 'AddBMMFunctionHandler',
'LayerNormModuleHandler', 'BatchNormModuleHandler', 'ConvModuleHandler', 'ConvFunctionHandler',
'UnaryElementwiseHandler', 'DefaultReshapeHandler', 'PlaceholderHandler', 'OutputHandler', 'WhereHandler',
'NormPoolingHandler', 'BinaryElementwiseHandler', 'MatMulHandler', 'operator_registry', 'ADDMMFunctionHandler',
'GetItemHandler', 'GetattrHandler', 'ViewHandler', 'PermuteHandler', 'TensorConstructorHandler',
'EmbeddingModuleHandler', 'EmbeddingFunctionHandler', 'SumHandler', 'SoftmaxHandler', 'TransposeHandler',
'SplitHandler'
"LinearFunctionHandler",
"LinearModuleHandler",
"BMMFunctionHandler",
"AddBMMFunctionHandler",
"LayerNormModuleHandler",
"BatchNormModuleHandler",
"ConvModuleHandler",
"ConvFunctionHandler",
"UnaryElementwiseHandler",
"DefaultReshapeHandler",
"PlaceholderHandler",
"OutputHandler",
"WhereHandler",
"NormPoolingHandler",
"BinaryElementwiseHandler",
"MatMulHandler",
"operator_registry",
"ADDMMFunctionHandler",
"GetItemHandler",
"GetattrHandler",
"ViewHandler",
"PermuteHandler",
"TensorConstructorHandler",
"EmbeddingModuleHandler",
"EmbeddingFunctionHandler",
"SumHandler",
"SoftmaxHandler",
"TransposeHandler",
"SplitHandler",
]
......@@ -2,15 +2,13 @@ from typing import Dict, List, Union
import torch
from colossalai.tensor.shape_consistency import CollectiveCommPattern, CommSpec, ShapeConsistencyManager
from ..sharding_strategy import CommAction, CommType, OperationData, OperationDataType, ShardingStrategy
from ..sharding_strategy import OperationData, OperationDataType, ShardingStrategy
from ..utils import comm_actions_for_oprands, recover_sharding_spec_for_broadcast_shape
from .node_handler import NodeHandler
from .registry import operator_registry
from .strategy import LinearProjectionStrategyGenerator, StrategyGenerator
__all__ = ['ADDMMFunctionHandler']
__all__ = ["ADDMMFunctionHandler"]
@operator_registry.register(torch.addmm)
......@@ -30,25 +28,26 @@ class ADDMMFunctionHandler(NodeHandler):
return data_type
def get_operation_data_mapping(self) -> Dict[str, OperationData]:
# input operand
input_data = self.node.args[1]._meta_data
physical_input_operand = OperationData(name=str(self.node.args[1]),
type=self._infer_op_data_type(input_data),
data=input_data)
physical_input_operand = OperationData(
name=str(self.node.args[1]), type=self._infer_op_data_type(input_data), data=input_data
)
# other operand
other_data = self.node.args[2]._meta_data
physical_other_operand = OperationData(name=str(self.node.args[2]),
type=self._infer_op_data_type(other_data),
data=other_data)
physical_other_operand = OperationData(
name=str(self.node.args[2]), type=self._infer_op_data_type(other_data), data=other_data
)
# bias physical shape
bias_logical_shape = self.node._meta_data.shape
bias_data = self.node.args[0]._meta_data
physical_bias_operand = OperationData(name=str(self.node.args[0]),
type=self._infer_op_data_type(bias_data),
data=bias_data,
logical_shape=bias_logical_shape)
physical_bias_operand = OperationData(
name=str(self.node.args[0]),
type=self._infer_op_data_type(bias_data),
data=bias_data,
logical_shape=bias_logical_shape,
)
# output
physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data)
......@@ -57,7 +56,7 @@ class ADDMMFunctionHandler(NodeHandler):
"input": physical_input_operand,
"other": physical_other_operand,
"output": physical_output,
'bias': physical_bias_operand
"bias": physical_bias_operand,
}
return mapping
......@@ -66,26 +65,27 @@ class ADDMMFunctionHandler(NodeHandler):
op_data_mapping = self.get_operation_data_mapping()
generators = []
generators.append(
LinearProjectionStrategyGenerator(op_data_mapping, self.device_mesh, linear_projection_type='addmm'))
LinearProjectionStrategyGenerator(op_data_mapping, self.device_mesh, linear_projection_type="addmm")
)
return generators
def post_process(self, strategy: ShardingStrategy) -> Union[ShardingStrategy, List[ShardingStrategy]]:
    """
    Convert the bias sharding spec of ``strategy`` from its logical shape to its physical shape.

    The bias operand is registered with the output shape as its logical
    shape, so its sharding spec has to be recovered for the physical bias
    shape before the strategy is used. The strategy is modified in place.

    Args:
        strategy: the sharding strategy whose bias entry should be converted.

    Returns:
        The same ``strategy`` object, with ``sharding_specs`` (and, when
        dimensions were removed, ``communication_actions``) updated for the
        bias operand.
    """
    # convert bias from its logical sharding spec to its physical sharding spec
    op_data_mapping = self.get_operation_data_mapping()

    bias_op_data = op_data_mapping["bias"]
    bias_physical_shape = bias_op_data.data.shape
    bias_logical_shape = bias_op_data.logical_shape
    bias_sharding_spec = strategy.get_sharding_spec_by_name(bias_op_data.name)
    bias_sharding_spec, removed_dims = recover_sharding_spec_for_broadcast_shape(
        bias_sharding_spec, bias_logical_shape, bias_physical_shape
    )
    strategy.sharding_specs[bias_op_data] = bias_sharding_spec

    # dimensions dropped during the recovery need a matching communication action
    if len(removed_dims) > 0:
        comm_action = comm_actions_for_oprands(
            node=self.node, removed_dims=removed_dims, op_data=bias_op_data, sharding_spec=bias_sharding_spec
        )
        strategy.communication_actions[bias_op_data] = comm_action

    return strategy
......@@ -2,12 +2,12 @@ from typing import Dict, List
import torch
from ..sharding_strategy import OperationData, OperationDataType, StrategiesVector
from .node_handler import MetaInfoModuleHandler, ModuleHandler
from ..sharding_strategy import OperationData, OperationDataType
from .node_handler import MetaInfoModuleHandler
from .registry import operator_registry
from .strategy import BatchNormStrategyGenerator, StrategyGenerator
__all__ = ['BatchNormModuleHandler']
__all__ = ["BatchNormModuleHandler"]
@operator_registry.register(torch.nn.BatchNorm1d)
......@@ -27,30 +27,37 @@ class BatchNormModuleHandler(MetaInfoModuleHandler):
def get_operation_data_mapping(self) -> Dict[str, OperationData]:
# use transposed shape for strategies
# the strategies will be transformed back to its original shape in self.post_process
physical_input_operand = OperationData(name=str(self.node.args[0]),
type=OperationDataType.ARG,
data=self.node.args[0]._meta_data)
physical_other_operand = OperationData(name="weight",
type=OperationDataType.PARAM,
data=self.named_parameters['weight'],
logical_shape=self.named_parameters['weight'].shape)
physical_input_operand = OperationData(
name=str(self.node.args[0]), type=OperationDataType.ARG, data=self.node.args[0]._meta_data
)
physical_other_operand = OperationData(
name="weight",
type=OperationDataType.PARAM,
data=self.named_parameters["weight"],
logical_shape=self.named_parameters["weight"].shape,
)
physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data)
physical_running_mean_operand = OperationData(name="running_mean",
type=OperationDataType.BUFFER,
data=self.named_buffers['running_mean'],
logical_shape=self.named_buffers['running_mean'].shape)
physical_running_mean_operand = OperationData(
name="running_mean",
type=OperationDataType.BUFFER,
data=self.named_buffers["running_mean"],
logical_shape=self.named_buffers["running_mean"].shape,
)
physical_running_var_operand = OperationData(name="running_var",
type=OperationDataType.BUFFER,
data=self.named_buffers['running_var'],
logical_shape=self.named_buffers['running_var'].shape)
physical_running_var_operand = OperationData(
name="running_var",
type=OperationDataType.BUFFER,
data=self.named_buffers["running_var"],
logical_shape=self.named_buffers["running_var"].shape,
)
physical_num_batches_tracked_operand = OperationData(
name="num_batches_tracked",
type=OperationDataType.BUFFER,
data=self.named_buffers['num_batches_tracked'],
logical_shape=self.named_buffers['num_batches_tracked'].shape)
data=self.named_buffers["num_batches_tracked"],
logical_shape=self.named_buffers["num_batches_tracked"].shape,
)
mapping = {
"input": physical_input_operand,
......@@ -58,12 +65,12 @@ class BatchNormModuleHandler(MetaInfoModuleHandler):
"output": physical_output,
"running_mean": physical_running_mean_operand,
"running_var": physical_running_var_operand,
"num_batches_tracked": physical_num_batches_tracked_operand
"num_batches_tracked": physical_num_batches_tracked_operand,
}
if self.named_parameters['bias'] is not None:
physical_bias_operand = OperationData(name="bias",
type=OperationDataType.PARAM,
data=self.named_parameters['bias'])
mapping['bias'] = physical_bias_operand
if self.named_parameters["bias"] is not None:
physical_bias_operand = OperationData(
name="bias", type=OperationDataType.PARAM, data=self.named_parameters["bias"]
)
mapping["bias"] = physical_bias_operand
return mapping
......@@ -4,15 +4,14 @@ import torch
from torch.fx.node import Node
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, ShardingStrategy
from colossalai.tensor.shape_consistency import CollectiveCommPattern, CommSpec, ShapeConsistencyManager
from ..constants import BCAST_FUNC_OP
from ..utils import comm_actions_for_oprands, recover_sharding_spec_for_broadcast_shape
from .node_handler import MetaInfoNodeHandler, NodeHandler
from .node_handler import MetaInfoNodeHandler
from .registry import operator_registry
from .strategy import BinaryElementwiseStrategyGenerator, StrategyGenerator
__all__ = ['BinaryElementwiseHandler']
__all__ = ["BinaryElementwiseHandler"]
@operator_registry.register(BCAST_FUNC_OP)
......@@ -38,7 +37,7 @@ class BinaryElementwiseHandler(MetaInfoNodeHandler):
# The meta_data of node type argument could also possibly be a non-tensor object.
if not isinstance(meta_data, torch.Tensor):
assert isinstance(meta_data, (int, float))
meta_data = torch.Tensor([meta_data]).to('meta')
meta_data = torch.Tensor([meta_data]).to("meta")
non_tensor = True
else:
......@@ -46,7 +45,7 @@ class BinaryElementwiseHandler(MetaInfoNodeHandler):
# but we can deem it as meta data
# as it won't affect the strategy generation
assert isinstance(self.node.args[idx], (int, float))
meta_data = torch.Tensor([self.node.args[idx]]).to('meta')
meta_data = torch.Tensor([self.node.args[idx]]).to("meta")
non_tensor = True
return meta_data, non_tensor
......@@ -58,24 +57,27 @@ class BinaryElementwiseHandler(MetaInfoNodeHandler):
# and filter the non-tensor op_data in post_process.
self.non_tensor_list = []
# assert False
input_op_data = OperationData(name=str(self.node.args[0]),
type=_get_op_data_type(input_meta_data),
data=input_meta_data,
logical_shape=bcast_shape)
other_op_data = OperationData(name=str(self.node.args[1]),
type=_get_op_data_type(other_meta_data),
data=other_meta_data,
logical_shape=bcast_shape)
output_op_data = OperationData(name=str(self.node),
type=OperationDataType.OUTPUT,
data=output_meta_data,
logical_shape=bcast_shape)
input_op_data = OperationData(
name=str(self.node.args[0]),
type=_get_op_data_type(input_meta_data),
data=input_meta_data,
logical_shape=bcast_shape,
)
other_op_data = OperationData(
name=str(self.node.args[1]),
type=_get_op_data_type(other_meta_data),
data=other_meta_data,
logical_shape=bcast_shape,
)
output_op_data = OperationData(
name=str(self.node), type=OperationDataType.OUTPUT, data=output_meta_data, logical_shape=bcast_shape
)
if non_tensor_input:
self.non_tensor_list.append(input_op_data)
if non_tensor_other:
self.non_tensor_list.append(other_op_data)
mapping = {'input': input_op_data, 'other': other_op_data, 'output': output_op_data}
mapping = {"input": input_op_data, "other": other_op_data, "output": output_op_data}
return mapping
def get_strategy_generator(self) -> List[StrategyGenerator]:
......@@ -100,14 +102,14 @@ class BinaryElementwiseHandler(MetaInfoNodeHandler):
logical_shape = op_data.logical_shape
sharding_spec = strategy.get_sharding_spec_by_name(op_data.name)
sharding_spec, removed_dims = recover_sharding_spec_for_broadcast_shape(
sharding_spec, logical_shape, physical_shape)
sharding_spec, logical_shape, physical_shape
)
strategy.sharding_specs[op_data] = sharding_spec
if len(removed_dims) > 0:
comm_action = comm_actions_for_oprands(node=self.node,
removed_dims=removed_dims,
op_data=op_data,
sharding_spec=sharding_spec)
comm_action = comm_actions_for_oprands(
node=self.node, removed_dims=removed_dims, op_data=op_data, sharding_spec=sharding_spec
)
strategy.communication_actions[op_data] = comm_action
return strategy
......@@ -2,15 +2,13 @@ from typing import Dict, List, Union
import torch
from colossalai.tensor.shape_consistency import CollectiveCommPattern, CommSpec, ShapeConsistencyManager
from ..sharding_strategy import CommAction, CommType, OperationData, OperationDataType, ShardingStrategy
from ..sharding_strategy import OperationData, OperationDataType, ShardingStrategy
from ..utils import comm_actions_for_oprands, recover_sharding_spec_for_broadcast_shape
from .node_handler import NodeHandler
from .registry import operator_registry
from .strategy import BatchedMatMulStrategyGenerator, StrategyGenerator
__all__ = ['BMMFunctionHandler', 'AddBMMFunctionHandler']
__all__ = ["BMMFunctionHandler", "AddBMMFunctionHandler"]
def _get_data_mapping_for_bmm_op(node, input_idx, other_idx, bias_idx=None):
......@@ -19,14 +17,14 @@ def _get_data_mapping_for_bmm_op(node, input_idx, other_idx, bias_idx=None):
node handler to reduce code redundancy.
"""
# input operand
physical_input_operand = OperationData(name=str(node.args[input_idx]),
type=OperationDataType.ARG,
data=node.args[input_idx]._meta_data)
physical_input_operand = OperationData(
name=str(node.args[input_idx]), type=OperationDataType.ARG, data=node.args[input_idx]._meta_data
)
# other operand
physical_other_operand = OperationData(name=str(node.args[other_idx]),
type=OperationDataType.ARG,
data=node.args[other_idx]._meta_data)
physical_other_operand = OperationData(
name=str(node.args[other_idx]), type=OperationDataType.ARG, data=node.args[other_idx]._meta_data
)
# output
physical_output = OperationData(name=str(node), type=OperationDataType.OUTPUT, data=node._meta_data)
......@@ -35,11 +33,13 @@ def _get_data_mapping_for_bmm_op(node, input_idx, other_idx, bias_idx=None):
if bias_idx is not None:
# bias physical shape
bias_logical_shape = node._meta_data.shape
physical_bias_operand = OperationData(name=str(node.args[bias_idx]),
type=OperationDataType.ARG,
data=node.args[bias_idx]._meta_data,
logical_shape=bias_logical_shape)
mapping['bias'] = physical_bias_operand
physical_bias_operand = OperationData(
name=str(node.args[bias_idx]),
type=OperationDataType.ARG,
data=node.args[bias_idx]._meta_data,
logical_shape=bias_logical_shape,
)
mapping["bias"] = physical_bias_operand
return mapping
......@@ -91,20 +91,20 @@ class AddBMMFunctionHandler(NodeHandler):
# convert bias from its logical sharding spec to its physical sharding spec
op_data_mapping = self.get_operation_data_mapping()
if 'bias' in op_data_mapping:
bias_op_data = op_data_mapping['bias']
if "bias" in op_data_mapping:
bias_op_data = op_data_mapping["bias"]
bias_physical_shape = bias_op_data.data.shape
bias_logical_shape = bias_op_data.logical_shape
bias_sharding_spec = strategy.get_sharding_spec_by_name(bias_op_data.name)
bias_sharding_spec, removed_dims = recover_sharding_spec_for_broadcast_shape(
bias_sharding_spec, bias_logical_shape, bias_physical_shape)
bias_sharding_spec, bias_logical_shape, bias_physical_shape
)
strategy.sharding_specs[bias_op_data] = bias_sharding_spec
if len(removed_dims) > 0:
comm_action = comm_actions_for_oprands(node=self.node,
removed_dims=removed_dims,
op_data=bias_op_data,
sharding_spec=bias_sharding_spec)
comm_action = comm_actions_for_oprands(
node=self.node, removed_dims=removed_dims, op_data=bias_op_data, sharding_spec=bias_sharding_spec
)
strategy.communication_actions[bias_op_data] = comm_action
return strategy
......@@ -3,13 +3,13 @@ from typing import Dict, List
import torch
import torch.nn.functional as F
from ..sharding_strategy import OperationData, OperationDataType, ShardingStrategy, StrategiesVector
from ..sharding_strategy import OperationData, OperationDataType, ShardingStrategy
from ..utils import transpose_partition_dim
from .node_handler import MetaInfoModuleHandler, MetaInfoNodeHandler, ModuleHandler, NodeHandler
from .node_handler import MetaInfoModuleHandler, MetaInfoNodeHandler
from .registry import operator_registry
from .strategy import ConvStrategyGenerator, StrategyGenerator
__all__ = ['ConvModuleHandler', 'ConvFunctionHandler']
__all__ = ["ConvModuleHandler", "ConvFunctionHandler"]
@operator_registry.register(torch.nn.Conv1d)
......@@ -29,25 +29,29 @@ class ConvModuleHandler(MetaInfoModuleHandler):
def get_operation_data_mapping(self) -> Dict[str, OperationData]:
    """
    Build the mapping from operand role to ``OperationData`` for this convolution module node.

    The returned dict always holds the keys ``"input"``, ``"other"`` (the weight) and
    ``"output"``; a ``"bias"`` entry is added only when the module has a ``bias`` parameter.

    The weight keeps its real tensor as ``data`` but is given a logical shape in which
    dimensions 0 and 1 are exchanged. The swap must happen exactly once: performing it
    twice would silently restore the physical layout.

    Returns:
        Dict[str, OperationData]: operand role -> operation data.
    """
    # use transposed shape for strategies
    # the strategies will be transformed back to its original shape in self.post_process
    physical_input_operand = OperationData(
        name=str(self.node.args[0]), type=OperationDataType.ARG, data=self.node.args[0]._meta_data
    )

    # logical weight shape = physical weight shape with the first two dims exchanged
    logical_shape_for_weight = list(self.named_parameters["weight"].shape)
    logical_shape_for_weight[0], logical_shape_for_weight[1] = (
        logical_shape_for_weight[1],
        logical_shape_for_weight[0],
    )
    physical_other_operand = OperationData(
        name="weight",
        type=OperationDataType.PARAM,
        data=self.named_parameters["weight"],
        logical_shape=torch.Size(logical_shape_for_weight),
    )
    physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data)

    mapping = {"input": physical_input_operand, "other": physical_other_operand, "output": physical_output}

    # the bias operand is optional: only register it when the module actually owns one
    if "bias" in self.named_parameters:
        physical_bias_operand = OperationData(
            name="bias", type=OperationDataType.PARAM, data=self.named_parameters["bias"]
        )
        mapping["bias"] = physical_bias_operand
    return mapping
def post_process(self, strategy: ShardingStrategy):
......@@ -77,9 +81,9 @@ class ConvFunctionHandler(MetaInfoNodeHandler):
def get_operation_data_mapping(self) -> Dict[str, OperationData]:
# use transposed shape for strategies
# the strategies will be transformed back to its original shape in self.post_process
physical_input_operand = OperationData(name=str(self.node.args[0]),
type=OperationDataType.ARG,
data=self.node.args[0]._meta_data)
physical_input_operand = OperationData(
name=str(self.node.args[0]), type=OperationDataType.ARG, data=self.node.args[0]._meta_data
)
# check if the other operand is a parameter
if isinstance(self.node.args[1]._meta_data, torch.nn.parameter.Parameter):
......@@ -88,26 +92,30 @@ class ConvFunctionHandler(MetaInfoNodeHandler):
data_type = OperationDataType.ARG
logical_shape_for_weight = list(self.node.args[1]._meta_data.shape)
logical_shape_for_weight[0], logical_shape_for_weight[1] = logical_shape_for_weight[
1], logical_shape_for_weight[0]
physical_other_operand = OperationData(name=str(self.node.args[1]),
type=data_type,
data=self.node.args[1]._meta_data,
logical_shape=torch.Size(logical_shape_for_weight))
logical_shape_for_weight[0], logical_shape_for_weight[1] = (
logical_shape_for_weight[1],
logical_shape_for_weight[0],
)
physical_other_operand = OperationData(
name=str(self.node.args[1]),
type=data_type,
data=self.node.args[1]._meta_data,
logical_shape=torch.Size(logical_shape_for_weight),
)
physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data)
mapping = {"input": physical_input_operand, "other": physical_other_operand, "output": physical_output}
if "bias" in self.node.kwargs and self.node.kwargs['bias'] is not None:
if "bias" in self.node.kwargs and self.node.kwargs["bias"] is not None:
# check if the other operand is a parameter
if isinstance(self.node.kwargs["bias"]._meta_data, torch.nn.parameter.Parameter):
data_type = OperationDataType.PARAM
else:
data_type = OperationDataType.ARG
physical_bias_operand = OperationData(name=str(self.node.kwargs["bias"]),
type=data_type,
data=self.node.kwargs["bias"]._meta_data)
mapping['bias'] = physical_bias_operand
physical_bias_operand = OperationData(
name=str(self.node.kwargs["bias"]), type=data_type, data=self.node.kwargs["bias"]._meta_data
)
mapping["bias"] = physical_bias_operand
return mapping
def post_process(self, strategy: ShardingStrategy):
......
......@@ -3,11 +3,11 @@ from typing import Dict, List
import torch
from ..sharding_strategy import OperationData, OperationDataType
from .node_handler import MetaInfoNodeHandler, NodeHandler
from .node_handler import MetaInfoNodeHandler
from .registry import operator_registry
from .strategy import DefaultReshapeGenerator, StrategyGenerator
__all__ = ['DefaultReshapeHandler']
__all__ = ["DefaultReshapeHandler"]
@operator_registry.register(torch.flatten)
......@@ -54,17 +54,15 @@ class DefaultReshapeHandler(MetaInfoNodeHandler):
input_data = self.node.args[0]._meta_data
input_logical_shape = self.infer_logical_shape(input_data)
physical_input_operand = OperationData(name=str(self.node.args[0]),
type=data_type,
data=input_data,
logical_shape=input_logical_shape)
physical_input_operand = OperationData(
name=str(self.node.args[0]), type=data_type, data=input_data, logical_shape=input_logical_shape
)
output_data = self.node._meta_data
output_logical_shape = self.infer_logical_shape(output_data)
physical_output = OperationData(name=str(self.node),
type=OperationDataType.OUTPUT,
data=output_data,
logical_shape=output_logical_shape)
physical_output = OperationData(
name=str(self.node), type=OperationDataType.OUTPUT, data=output_data, logical_shape=output_logical_shape
)
mapping = {"input": physical_input_operand, "output": physical_output}
......
......@@ -12,11 +12,12 @@ from .node_handler import ModuleHandler, NodeHandler
from .registry import operator_registry
from .strategy import EmbeddingStrategyGenerator, StrategyGenerator
__all__ = ['EmbeddingModuleHandler', 'EmbeddingFunctionHandler']
__all__ = ["EmbeddingModuleHandler", "EmbeddingFunctionHandler"]
def _convert_logical_sharding_to_physical_sharding_spec_for_embedding(strategy: ShardingStrategy, input_name: str,
output_name: str) -> List[ShardingStrategy]:
def _convert_logical_sharding_to_physical_sharding_spec_for_embedding(
strategy: ShardingStrategy, input_name: str, output_name: str
) -> List[ShardingStrategy]:
"""
This function converts the logical sharding spec to the physical sharding spec for both the input and output
of the embedding operation.
......@@ -56,27 +57,31 @@ def _convert_logical_sharding_to_physical_sharding_spec_for_embedding(strategy:
output_sharding_spec = strategy_copy.get_sharding_spec_by_name(output_op_data.name)
try:
# replace the 0th dimension in the logical sharding with ith dimension in the physical sharding
update_partition_dim(sharding_spec=input_sharding_spec,
dim_mapping={0: i},
physical_shape=input_op_data.data.shape,
inplace=True)
update_partition_dim(
sharding_spec=input_sharding_spec,
dim_mapping={0: i},
physical_shape=input_op_data.data.shape,
inplace=True,
)
if last_logical_output_dims in output_sharding_spec.dim_partition_dict:
dim_mapping = {0: i, last_logical_output_dims: last_physical_output_dims}
else:
dim_mapping = {0: i}
update_partition_dim(sharding_spec=output_sharding_spec,
dim_mapping=dim_mapping,
physical_shape=output_op_data.data.shape,
inplace=True)
update_partition_dim(
sharding_spec=output_sharding_spec,
dim_mapping=dim_mapping,
physical_shape=output_op_data.data.shape,
inplace=True,
)
strategy_copy.name = f'{strategy.name}_{i}'
strategy_copy.name = f"{strategy.name}_{i}"
sharding_strategies.append(strategy_copy)
except ShardingNotDivisibleError as e:
logger.debug(
f'Errored occurred when converting the logical sharding spec to the physical one. Error details: {e}'
f"Errored occurred when converting the logical sharding spec to the physical one. Error details: {e}"
)
else:
# the generated sharding strategy does not shard the non-matrix dimension,
......@@ -87,20 +92,21 @@ def _convert_logical_sharding_to_physical_sharding_spec_for_embedding(strategy:
output_sharding_spec = strategy_copy.get_sharding_spec_by_name(output_op_data.name)
# after updating, the logical shape will be replaced by the physical shape
update_partition_dim(sharding_spec=input_sharding_spec,
dim_mapping={},
physical_shape=input_op_data.data.shape,
inplace=True)
update_partition_dim(
sharding_spec=input_sharding_spec, dim_mapping={}, physical_shape=input_op_data.data.shape, inplace=True
)
if last_logical_output_dims in output_sharding_spec.dim_partition_dict:
dim_mapping = {last_logical_output_dims: last_physical_output_dims}
else:
dim_mapping = {}
update_partition_dim(sharding_spec=output_sharding_spec,
dim_mapping=dim_mapping,
physical_shape=output_op_data.data.shape,
inplace=True)
update_partition_dim(
sharding_spec=output_sharding_spec,
dim_mapping=dim_mapping,
physical_shape=output_op_data.data.shape,
inplace=True,
)
sharding_strategies.append(strategy_copy)
return sharding_strategies
......@@ -125,14 +131,16 @@ class EmbeddingModuleHandler(ModuleHandler):
# Finally, the input will be transformed back to its original shape in self.post_process
input_meta_data = self.node.args[0]._meta_data
input_logical_shape = input_meta_data.view(-1).shape
physical_input_operand = OperationData(name=str(self.node.args[0]),
type=OperationDataType.ARG,
data=input_meta_data,
logical_shape=input_logical_shape)
physical_input_operand = OperationData(
name=str(self.node.args[0]),
type=OperationDataType.ARG,
data=input_meta_data,
logical_shape=input_logical_shape,
)
physical_other_operand = OperationData(name="weight",
type=OperationDataType.PARAM,
data=self.named_parameters['weight'])
physical_other_operand = OperationData(
name="weight", type=OperationDataType.PARAM, data=self.named_parameters["weight"]
)
# Same as input, in nn.Embedding operation, all the dimensions of output will be treated as
# (batch dimension, embedding dimension), and then the sharding spec will be generated based
......@@ -141,10 +149,12 @@ class EmbeddingModuleHandler(ModuleHandler):
# Finally, the output will be transformed back to its original shape in self.post_process
output_meta_data = self.node._meta_data
output_logical_shape = output_meta_data.view(-1, output_meta_data.shape[-1]).shape
physical_output = OperationData(name=str(self.node),
type=OperationDataType.OUTPUT,
data=output_meta_data,
logical_shape=output_logical_shape)
physical_output = OperationData(
name=str(self.node),
type=OperationDataType.OUTPUT,
data=output_meta_data,
logical_shape=output_logical_shape,
)
mapping = {"input": physical_input_operand, "other": physical_other_operand, "output": physical_output}
......@@ -157,10 +167,9 @@ class EmbeddingModuleHandler(ModuleHandler):
# create multiple sharding strategies for the inputs
# as input can be multi-dimensional and the partition dim is only 2D,
# we need to map the partition at logical dim 0 to one of the first few dimensions of the input and output
strategies = _convert_logical_sharding_to_physical_sharding_spec_for_embedding(strategy=strategy,
input_name=str(
self.node.args[0]),
output_name=str(self.node))
strategies = _convert_logical_sharding_to_physical_sharding_spec_for_embedding(
strategy=strategy, input_name=str(self.node.args[0]), output_name=str(self.node)
)
return strategies
......@@ -183,10 +192,12 @@ class EmbeddingFunctionHandler(NodeHandler):
# Finally, the input will be transformed back to its original shape in self.post_process
input_meta_data = self.node.args[0]._meta_data
input_logical_shape = input_meta_data.view(-1).shape
physical_input_operand = OperationData(name=str(self.node.args[0]),
type=OperationDataType.ARG,
data=self.node.args[0]._meta_data,
logical_shape=input_logical_shape)
physical_input_operand = OperationData(
name=str(self.node.args[0]),
type=OperationDataType.ARG,
data=self.node.args[0]._meta_data,
logical_shape=input_logical_shape,
)
# check if the other operand is a parameter
if isinstance(self.node.args[1]._meta_data, torch.nn.parameter.Parameter):
......@@ -194,9 +205,9 @@ class EmbeddingFunctionHandler(NodeHandler):
else:
data_type = OperationDataType.ARG
physical_other_operand = OperationData(name=str(self.node.args[1]),
type=data_type,
data=self.node.args[1]._meta_data)
physical_other_operand = OperationData(
name=str(self.node.args[1]), type=data_type, data=self.node.args[1]._meta_data
)
# Same as input, in F.embedding operation, all the dimensions of output will be treated as
# (batch dimension, embedding dimension), and then the sharding spec will be generated based
......@@ -223,8 +234,7 @@ class EmbeddingFunctionHandler(NodeHandler):
# create multiple sharding strategies for the inputs
# as input can be multi-dimensional and the partition dim is only 2D,
# we need to map the partition at logical dim 0 to one of the first few dimensions of the input and output
strategies = _convert_logical_sharding_to_physical_sharding_spec_for_embedding(strategy=strategy,
input_name=str(
self.node.args[0]),
output_name=str(self.node))
strategies = _convert_logical_sharding_to_physical_sharding_spec_for_embedding(
strategy=strategy, input_name=str(self.node.args[0]), output_name=str(self.node)
)
return strategies
......@@ -4,7 +4,7 @@ from ..sharding_strategy import OperationData, OperationDataType
from .node_handler import NodeHandler
from .strategy import GetattrGenerator, StrategyGenerator
__all__ = ['GetattrHandler']
__all__ = ["GetattrHandler"]
class GetattrHandler(NodeHandler):
......
......@@ -8,7 +8,7 @@ from .node_handler import NodeHandler
from .registry import operator_registry
from .strategy import StrategyGenerator, TensorStrategyGenerator, TensorTupleStrategyGenerator
__all__ = ['GetItemHandler']
__all__ = ["GetItemHandler"]
@operator_registry.register(operator.getitem)
......@@ -30,9 +30,9 @@ class GetItemHandler(NodeHandler):
def get_operation_data_mapping(self) -> Dict[str, OperationData]:
# use transposed shape for strategies
# the strategies will be transformed back to its original shape in self.post_process
physical_input_operand = OperationData(name=str(self.node.args[0]),
type=OperationDataType.ARG,
data=self.node.args[0]._meta_data)
physical_input_operand = OperationData(
name=str(self.node.args[0]), type=OperationDataType.ARG, data=self.node.args[0]._meta_data
)
physical_other_operand = OperationData(name="index", type=OperationDataType.ARG, data=self.node.args[1])
physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data)
......
......@@ -3,11 +3,11 @@ from typing import Dict, List
import torch
from ..sharding_strategy import OperationData, OperationDataType
from .node_handler import MetaInfoModuleHandler, ModuleHandler
from .node_handler import MetaInfoModuleHandler
from .registry import operator_registry
from .strategy import LayerNormGenerator, StrategyGenerator
__all__ = ['LayerNormModuleHandler']
__all__ = ["LayerNormModuleHandler"]
@operator_registry.register(torch.nn.LayerNorm)
......@@ -25,20 +25,22 @@ class LayerNormModuleHandler(MetaInfoModuleHandler):
def get_operation_data_mapping(self) -> Dict[str, OperationData]:
# use transposed shape for strategies
# the strategies will be transformed back to its original shape in self.post_process
physical_input_operand = OperationData(name=str(self.node.args[0]),
type=OperationDataType.ARG,
data=self.node.args[0]._meta_data)
physical_other_operand = OperationData(name="weight",
type=OperationDataType.PARAM,
data=self.named_parameters['weight'],
logical_shape=self.named_parameters['weight'].shape)
physical_input_operand = OperationData(
name=str(self.node.args[0]), type=OperationDataType.ARG, data=self.node.args[0]._meta_data
)
physical_other_operand = OperationData(
name="weight",
type=OperationDataType.PARAM,
data=self.named_parameters["weight"],
logical_shape=self.named_parameters["weight"].shape,
)
physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data)
mapping = {"input": physical_input_operand, "other": physical_other_operand, "output": physical_output}
if self.named_parameters['bias'] is not None:
physical_bias_operand = OperationData(name="bias",
type=OperationDataType.PARAM,
data=self.named_parameters['bias'])
mapping['bias'] = physical_bias_operand
if self.named_parameters["bias"] is not None:
physical_bias_operand = OperationData(
name="bias", type=OperationDataType.PARAM, data=self.named_parameters["bias"]
)
mapping["bias"] = physical_bias_operand
return mapping
......@@ -3,24 +3,21 @@ from typing import Dict, List, Union
import torch
import torch.nn.functional as F
from colossalai.auto_parallel.tensor_shard.utils import (
check_sharding_spec_validity,
transpose_partition_dim,
update_partition_dim,
)
from colossalai.auto_parallel.tensor_shard.utils import transpose_partition_dim, update_partition_dim
from colossalai.logging import get_dist_logger
from colossalai.tensor.sharding_spec import ShardingNotDivisibleError
from ..sharding_strategy import OperationData, OperationDataType, ShardingStrategy, StrategiesVector
from .node_handler import MetaInfoModuleHandler, MetaInfoNodeHandler, ModuleHandler, NodeHandler
from ..sharding_strategy import OperationData, OperationDataType, ShardingStrategy
from .node_handler import MetaInfoModuleHandler, MetaInfoNodeHandler
from .registry import operator_registry
from .strategy import LinearProjectionStrategyGenerator, StrategyGenerator
__all__ = ['LinearModuleHandler', 'LinearFunctionHandler']
__all__ = ["LinearModuleHandler", "LinearFunctionHandler"]
def _update_sharding_spec_for_transposed_weight_for_linear(strategy: ShardingStrategy,
weight_name: str) -> ShardingStrategy:
def _update_sharding_spec_for_transposed_weight_for_linear(
strategy: ShardingStrategy, weight_name: str
) -> ShardingStrategy:
"""
This function is a helper function used by both module node handler and function node handler. This function will
convert the sharding spec for the transposed weight to the correct partition spec.
......@@ -32,16 +29,17 @@ def _update_sharding_spec_for_transposed_weight_for_linear(strategy: ShardingStr
# switch the dimensions of the transposed weight
sharding_spec = strategy.get_sharding_spec_by_name(weight_name)
op_data = strategy.get_op_data_by_name(weight_name)
assert op_data.logical_shape[0] == op_data.data.shape[1] and \
op_data.logical_shape[1] == op_data.data.shape[0], \
"Expected the logical shape of the linear operator's weight is equal to transposed physical shape"
assert (
op_data.logical_shape[0] == op_data.data.shape[1] and op_data.logical_shape[1] == op_data.data.shape[0]
), "Expected the logical shape of the linear operator's weight is equal to transposed physical shape"
dim_size = len(op_data.logical_shape)
transpose_partition_dim(sharding_spec, 0, dim_size - 1)
return strategy
def _convert_logical_sharding_to_physical_sharding_spec_for_linear(strategy: ShardingStrategy, input_name: str,
output_name: str) -> List[ShardingStrategy]:
def _convert_logical_sharding_to_physical_sharding_spec_for_linear(
strategy: ShardingStrategy, input_name: str, output_name: str
) -> List[ShardingStrategy]:
"""
This function converts the logical sharding spec to the physical sharding spec for both the input and output of the linear operation. The input and output
should have the same sharding spec.
......@@ -99,22 +97,26 @@ def _convert_logical_sharding_to_physical_sharding_spec_for_linear(strategy: Sha
input_dim_mapping = {0: i}
input_dim_mapping.update(input_last_dim_mapping)
update_partition_dim(sharding_spec=input_sharding_spec,
dim_mapping=input_dim_mapping,
physical_shape=input_op_data.data.shape,
inplace=True)
update_partition_dim(
sharding_spec=input_sharding_spec,
dim_mapping=input_dim_mapping,
physical_shape=input_op_data.data.shape,
inplace=True,
)
output_dim_mapping = {0: i}
output_dim_mapping.update(output_last_dim_mapping)
update_partition_dim(sharding_spec=output_sharding_spec,
dim_mapping=output_dim_mapping,
physical_shape=output_op_data.data.shape,
inplace=True)
strategy_copy.name = f'{strategy.name}_{i}'
update_partition_dim(
sharding_spec=output_sharding_spec,
dim_mapping=output_dim_mapping,
physical_shape=output_op_data.data.shape,
inplace=True,
)
strategy_copy.name = f"{strategy.name}_{i}"
sharding_strategies.append(strategy_copy)
except ShardingNotDivisibleError as e:
logger.debug(
f'Errored occurred when converting the logical sharding spec to the physical one. Error details: {e}'
f"Errored occurred when converting the logical sharding spec to the physical one. Error details: {e}"
)
else:
# the generated sharding strategy does not shard the non-matrix dimension,
......@@ -127,17 +129,21 @@ def _convert_logical_sharding_to_physical_sharding_spec_for_linear(strategy: Sha
# after updating, the logical shape will be replaced by the physical shape
input_dim_mapping = {}
input_dim_mapping.update(input_last_dim_mapping)
update_partition_dim(sharding_spec=input_sharding_spec,
dim_mapping=input_dim_mapping,
physical_shape=input_op_data.data.shape,
inplace=True)
update_partition_dim(
sharding_spec=input_sharding_spec,
dim_mapping=input_dim_mapping,
physical_shape=input_op_data.data.shape,
inplace=True,
)
output_dim_mapping = {}
output_dim_mapping.update(output_last_dim_mapping)
update_partition_dim(sharding_spec=output_sharding_spec,
dim_mapping=output_dim_mapping,
physical_shape=output_op_data.data.shape,
inplace=True)
update_partition_dim(
sharding_spec=output_sharding_spec,
dim_mapping=output_dim_mapping,
physical_shape=output_op_data.data.shape,
inplace=True,
)
sharding_strategies.append(strategy_copy)
return sharding_strategies
......@@ -152,10 +158,13 @@ class LinearModuleHandler(MetaInfoModuleHandler):
op_data_mapping = self.get_operation_data_mapping()
generators = []
generators.append(
LinearProjectionStrategyGenerator(op_data_mapping,
self.device_mesh,
linear_projection_type='linear',
solver_perference=self.solver_perference))
LinearProjectionStrategyGenerator(
op_data_mapping,
self.device_mesh,
linear_projection_type="linear",
solver_perference=self.solver_perference,
)
)
return generators
def get_operation_data_mapping(self) -> Dict[str, OperationData]:
    """
    Build the mapping from operand role to ``OperationData`` for this linear module node.

    * ``"input"`` / ``"output"``: the physical tensors are kept as ``data``, while the logical
      shape is the 2D shape obtained from ``view(-1, last_dim)``.
    * ``"other"``: the ``weight`` parameter, with a logical shape equal to its physical
      shape reversed.
    * ``"bias"``: added only when the module has a ``bias`` parameter.

    Returns:
        Dict[str, OperationData]: operand role -> operation data.
    """
    # the strategies will be transformed back to its original shape in self.post_process
    input_meta_data = self.node.args[0]._meta_data
    input_logical_shape = input_meta_data.view(-1, input_meta_data.shape[-1]).shape
    physical_input_operand = OperationData(
        name=str(self.node.args[0]),
        type=OperationDataType.ARG,
        data=input_meta_data,
        logical_shape=input_logical_shape,
    )
    physical_other_operand = OperationData(
        name="weight",
        type=OperationDataType.PARAM,
        data=self.named_parameters["weight"],
        logical_shape=self.named_parameters["weight"].shape[::-1],
    )
    output_meta_data = self.node._meta_data
    output_logical_shape = output_meta_data.view(-1, output_meta_data.shape[-1]).shape
    physical_output = OperationData(
        name=str(self.node),
        type=OperationDataType.OUTPUT,
        data=output_meta_data,
        logical_shape=output_logical_shape,
    )

    mapping = {"input": physical_input_operand, "other": physical_other_operand, "output": physical_output}

    # Plain membership test. The previous `"bias" in X is not None` was a chained
    # comparison, i.e. `("bias" in X) and (X is not None)`, whose second half can
    # never change the outcome once the first half has been evaluated.
    if "bias" in self.named_parameters:
        physical_bias_operand = OperationData(
            name="bias", type=OperationDataType.PARAM, data=self.named_parameters["bias"]
        )
        mapping["bias"] = physical_bias_operand
    return mapping
def post_process(self, strategy: ShardingStrategy) -> Union[ShardingStrategy, List[ShardingStrategy]]:
......@@ -194,14 +209,14 @@ class LinearModuleHandler(MetaInfoModuleHandler):
2. the input and output sharding specs are updated to physical shape.
"""
# switch the dimensions of the transposed weight
strategy = _update_sharding_spec_for_transposed_weight_for_linear(strategy=strategy, weight_name='weight')
strategy = _update_sharding_spec_for_transposed_weight_for_linear(strategy=strategy, weight_name="weight")
# create multiple sharding strategies for the inputs
# as input can be multi-dimensional and the partition dim is only 2D,
# we need to map the partition at dim 0 to one of the first few dimensions of the input
strategies = _convert_logical_sharding_to_physical_sharding_spec_for_linear(strategy=strategy,
input_name=str(self.node.args[0]),
output_name=str(self.node))
strategies = _convert_logical_sharding_to_physical_sharding_spec_for_linear(
strategy=strategy, input_name=str(self.node.args[0]), output_name=str(self.node)
)
return strategies
......@@ -215,7 +230,8 @@ class LinearFunctionHandler(MetaInfoNodeHandler):
op_data_mapping = self.get_operation_data_mapping()
generators = []
generators.append(
LinearProjectionStrategyGenerator(op_data_mapping, self.device_mesh, linear_projection_type='linear'))
LinearProjectionStrategyGenerator(op_data_mapping, self.device_mesh, linear_projection_type="linear")
)
return generators
def get_operation_data_mapping(self) -> Dict[str, OperationData]:
......@@ -223,10 +239,12 @@ class LinearFunctionHandler(MetaInfoNodeHandler):
# the strategies will be transformed back to its original shape in self.post_process
input_meta_data = self.node.args[0]._meta_data
input_logical_shape = input_meta_data.view(-1, input_meta_data.shape[-1]).shape
physical_input_operand = OperationData(name=str(self.node.args[0]),
type=OperationDataType.ARG,
data=self.node.args[0]._meta_data,
logical_shape=input_logical_shape)
physical_input_operand = OperationData(
name=str(self.node.args[0]),
type=OperationDataType.ARG,
data=self.node.args[0]._meta_data,
logical_shape=input_logical_shape,
)
# check if the other operand is a parameter
if isinstance(self.node.args[1]._meta_data, torch.nn.parameter.Parameter):
......@@ -234,10 +252,12 @@ class LinearFunctionHandler(MetaInfoNodeHandler):
else:
data_type = OperationDataType.ARG
physical_other_operand = OperationData(name=str(self.node.args[1]),
type=data_type,
data=self.node.args[1]._meta_data,
logical_shape=self.node.args[1]._meta_data.shape[::-1])
physical_other_operand = OperationData(
name=str(self.node.args[1]),
type=data_type,
data=self.node.args[1]._meta_data,
logical_shape=self.node.args[1]._meta_data.shape[::-1],
)
output_meta_data = self.node._meta_data
output_logical_shape = output_meta_data.view(-1, output_meta_data.shape[-1]).shape
physical_output = OperationData(
......@@ -249,27 +269,28 @@ class LinearFunctionHandler(MetaInfoNodeHandler):
mapping = {"input": physical_input_operand, "other": physical_other_operand, "output": physical_output}
if 'bias' in self.node.kwargs and self.node.kwargs['bias'] is not None:
if "bias" in self.node.kwargs and self.node.kwargs["bias"] is not None:
# check if the other operand is a parameter
if isinstance(self.node.kwargs["bias"]._meta_data, torch.nn.parameter.Parameter):
data_type = OperationDataType.PARAM
else:
data_type = OperationDataType.ARG
physical_bias_operand = OperationData(name=str(self.node.kwargs["bias"]),
type=data_type,
data=self.node.kwargs["bias"]._meta_data)
mapping['bias'] = physical_bias_operand
physical_bias_operand = OperationData(
name=str(self.node.kwargs["bias"]), type=data_type, data=self.node.kwargs["bias"]._meta_data
)
mapping["bias"] = physical_bias_operand
return mapping
def post_process(self, strategy: ShardingStrategy):
    """
    Convert a strategy generated on logical shapes into strategies on physical shapes.

    Two steps are applied, each exactly once (repeating the first one would apply the
    weight transposition a second time):

    1. the sharding spec of the weight operand (``self.node.args[1]``) is updated for the
       transposed weight;
    2. the input/output sharding specs are converted from logical to physical form, which
       may yield several strategies.

    Args:
        strategy: the sharding strategy produced by the strategy generator.

    Returns:
        The strategies returned by the logical-to-physical conversion helper.
    """
    # switch the dimensions of the transposed weight
    strategy = _update_sharding_spec_for_transposed_weight_for_linear(
        strategy=strategy, weight_name=str(self.node.args[1])
    )

    # create multiple sharding strategies for the inputs
    # as input can be multi-dimensional and the partition dim is only 2D,
    # we need to map the partition at dim 0 to one of the first few dimensions of the input
    strategies = _convert_logical_sharding_to_physical_sharding_spec_for_linear(
        strategy=strategy, input_name=str(self.node.args[0]), output_name=str(self.node)
    )
    return strategies
......@@ -16,7 +16,7 @@ from colossalai.tensor.sharding_spec import ShardingSpecException
from ..sharding_strategy import OperationData, OperationDataType, ShardingStrategy
from ..utils import recover_sharding_spec_for_broadcast_shape
from .node_handler import MetaInfoNodeHandler, NodeHandler
from .node_handler import MetaInfoNodeHandler
from .registry import operator_registry
from .strategy import (
BatchedMatMulStrategyGenerator,
......@@ -37,6 +37,7 @@ class MatMulType(Enum):
MV: matrix-vector product: the 1st tensor is 2D and the 2nd tensor is 1D
BMM: batched matrix-matrix multiplication, one tensor is at least 1D and the other is at least 3D
"""
DOT = 0
MM = 1
MV = 2
......@@ -92,26 +93,26 @@ class Padder(BmmTransform):
def apply(self, shape_mapping: Dict[str, List[int]]):
    """
    Pad a 1D ``input`` or ``other`` shape to 2D and remember which dimension was added.

    The incoming mapping is deep-copied, so the caller's lists are never modified.
    At most one operand is padded: ``input`` is checked first, ``other`` only if
    ``input`` is not 1D.

    Args:
        shape_mapping: dict with at least the keys ``"input"`` and ``"other"``, each a
            list of ints.

    Returns:
        A new mapping with the padded shapes. ``self.padded_dim_mapping`` is updated
        with the padded dim index for the padded operand and for ``"output"``.
    """
    mapping_copy = deepcopy(shape_mapping)
    input_shape = mapping_copy["input"]
    other_shape = mapping_copy["other"]

    if len(input_shape) == 1:
        # if the input is a 1D tensor, 1 is prepended to its shape
        # and it will be removed afterwards
        input_shape.insert(0, 1)
        self.padded_dim_mapping["input"] = -2
        self.padded_dim_mapping["output"] = -2
    elif len(other_shape) == 1:
        # if the other is a 1D tensor, 1 is appended to its shape
        # and it will be removed afterwards
        # NOTE: list.append mutates in place and returns None, so its result must not
        # be assigned back to ``other_shape``.
        other_shape.append(1)
        self.padded_dim_mapping["other"] = -1
        self.padded_dim_mapping["output"] = -1
    return mapping_copy
def recover(self, op_data_mapping: Dict[str, OperationData], strategy: ShardingStrategy):
input_op_data = op_data_mapping['input']
other_op_data = op_data_mapping['other']
op_data_mapping["input"]
op_data_mapping["other"]
def _remove_padded_dim(key, strategy):
op_data = op_data_mapping[key]
......@@ -131,7 +132,7 @@ class Padder(BmmTransform):
# compute unpadded tensor shape
tensor_shape.pop(padded_dim)
assert tensor_shape == list(op_data.data.shape), f'{tensor_shape} vs {list(op_data.data.shape)}'
assert tensor_shape == list(op_data.data.shape), f"{tensor_shape} vs {list(op_data.data.shape)}"
# update sharding spec
sharding_spec.__init__(sharding_spec.device_mesh, tensor_shape, unpadded_dim_partition_list)
......@@ -142,15 +143,15 @@ class Padder(BmmTransform):
strategy_copy = strategy.clone()
# only one of input and other will be padded
if 'input' in self.padded_dim_mapping:
_remove_padded_dim('input', strategy_copy)
_remove_padded_dim('output', strategy_copy)
elif 'other' in self.padded_dim_mapping:
_remove_padded_dim('other', strategy_copy)
_remove_padded_dim('output', strategy_copy)
if "input" in self.padded_dim_mapping:
_remove_padded_dim("input", strategy_copy)
_remove_padded_dim("output", strategy_copy)
elif "other" in self.padded_dim_mapping:
_remove_padded_dim("other", strategy_copy)
_remove_padded_dim("output", strategy_copy)
strategies.append(strategy_copy)
except ShardingSpecException as e:
except ShardingSpecException:
pass
return strategies
......@@ -167,8 +168,8 @@ class Broadcaster(BmmTransform):
mapping_copy = shape_mapping.copy()
# get shapes
input_shape = mapping_copy['input']
other_shape = mapping_copy['other']
input_shape = mapping_copy["input"]
other_shape = mapping_copy["other"]
# sanity check
assert len(input_shape) > 1 and len(other_shape) > 1
......@@ -179,16 +180,16 @@ class Broadcaster(BmmTransform):
# store the broadcast dim info
input_broadcast_dim_info = get_broadcast_dim_info(bcast_non_matrix_dims, input_shape[:-2])
other_broadcast_dim_info = get_broadcast_dim_info(bcast_non_matrix_dims, other_shape[:-2])
self.broadcast_dim_info['input'] = input_broadcast_dim_info
self.broadcast_dim_info['other'] = other_broadcast_dim_info
self.broadcast_dim_info["input"] = input_broadcast_dim_info
self.broadcast_dim_info["other"] = other_broadcast_dim_info
# create the full logical shape
input_shape = bcast_non_matrix_dims + input_shape[-2:]
other_shape = bcast_non_matrix_dims + other_shape[-2:]
assert len(input_shape) == len(other_shape)
mapping_copy['input'] = input_shape
mapping_copy['other'] = other_shape
mapping_copy["input"] = input_shape
mapping_copy["other"] = other_shape
return mapping_copy
......@@ -216,17 +217,18 @@ class Broadcaster(BmmTransform):
physical_sharding_spec, removed_dims = recover_sharding_spec_for_broadcast_shape(
logical_sharding_spec=sharding_spec,
logical_shape=sharding_spec.entire_shape,
physical_shape=tensor_shape_before_broadcast)
physical_shape=tensor_shape_before_broadcast,
)
strategy.sharding_specs[op_data] = physical_sharding_spec
# enumerate all sharding strategies
strategies = []
try:
strategy_copy = strategy.clone()
_remove_sharding_on_broadcast_dim('input', strategy_copy)
_remove_sharding_on_broadcast_dim('other', strategy_copy)
_remove_sharding_on_broadcast_dim("input", strategy_copy)
_remove_sharding_on_broadcast_dim("other", strategy_copy)
strategies.append(strategy_copy)
except ShardingSpecException as e:
except ShardingSpecException:
pass
return strategies
......@@ -241,20 +243,20 @@ class Viewer(BmmTransform):
def apply(self, shape_mapping: Dict[str, List[int]]):
    """
    View both matmul operands as 3d tensors by folding all leading (non-matrix)
    dimensions into a single leading dimension.

    The leading dimensions of "input" seen before the view are recorded on
    ``self.batch_dims_before_view``.

    Args:
        shape_mapping: maps "input" and "other" to their shapes. Both shapes must
            have at least 3 dimensions. Any sequence of ints is accepted.

    Returns:
        A shallow copy of ``shape_mapping`` in which "input", "other" and "output"
        hold the 3d shapes as lists. ``shape_mapping`` itself is left untouched.

    Raises:
        AssertionError: if either shape has fewer than 3 dimensions.
    """
    mapping_copy = shape_mapping.copy()

    # get shapes
    # normalise to lists so that tuple-like shapes can be concatenated below
    input_shape = list(shape_mapping["input"])
    other_shape = list(shape_mapping["other"])

    # recorded before the sanity check, so it is set even when the check fails
    self.batch_dims_before_view = input_shape[:-2]

    # view to 3d tensor
    assert len(input_shape) >= 3 and len(other_shape) >= 3
    input_shape = [reduce(operator.mul, input_shape[:-2])] + input_shape[-2:]
    other_shape = [reduce(operator.mul, other_shape[:-2])] + other_shape[-2:]

    # output takes the folded leading dim and rows of input, and the last dim of other
    output_shape = input_shape[:2] + other_shape[2:]

    mapping_copy["input"] = input_shape
    mapping_copy["other"] = other_shape
    mapping_copy["output"] = output_shape
    return mapping_copy
def recover(self, op_data_mapping: Dict[str, OperationData], strategy: ShardingStrategy):
......@@ -291,11 +293,11 @@ class Viewer(BmmTransform):
# create a new strategy
strategy_copy = strategy.clone()
try:
_update_sharding_spec('input', strategy_copy, i)
_update_sharding_spec('other', strategy_copy, i)
_update_sharding_spec('output', strategy_copy, i)
_update_sharding_spec("input", strategy_copy, i)
_update_sharding_spec("other", strategy_copy, i)
_update_sharding_spec("output", strategy_copy, i)
strategies.append(strategy_copy)
except ShardingSpecException as e:
except ShardingSpecException:
continue
return strategies
......@@ -312,14 +314,14 @@ def _get_bmm_logical_shape(input_shape, other_shape, transforms):
3. reshape to 3 dimensions
"""
shape_mapping = {'input': input_shape, 'other': other_shape}
shape_mapping = {"input": input_shape, "other": other_shape}
for transform in transforms:
shape_mapping = transform.apply(shape_mapping)
input_shape = shape_mapping.get('input', None)
other_shape = shape_mapping.get('other', None)
output_shape = shape_mapping.get('output', None)
input_shape = shape_mapping.get("input", None)
other_shape = shape_mapping.get("other", None)
output_shape = shape_mapping.get("output", None)
return input_shape, other_shape, output_shape
......@@ -364,7 +366,8 @@ class MatMulHandler(MetaInfoNodeHandler):
generators.append(MatVecStrategyGenerator(op_data_mapping, self.device_mesh))
elif self.matmul_type == MatMulType.MM:
generators.append(
LinearProjectionStrategyGenerator(op_data_mapping, self.device_mesh, linear_projection_type='linear'))
LinearProjectionStrategyGenerator(op_data_mapping, self.device_mesh, linear_projection_type="linear")
)
return generators
def get_operation_data_mapping(self) -> Dict[str, OperationData]:
......@@ -372,7 +375,7 @@ class MatMulHandler(MetaInfoNodeHandler):
MatMulType.DOT: self._get_logical_shape_for_dot,
MatMulType.MM: self._get_logical_shape_for_mm,
MatMulType.MV: self._get_logical_shape_for_mv,
MatMulType.BMM: self._get_logical_shape_for_bmm
MatMulType.BMM: self._get_logical_shape_for_bmm,
}
logical_shapes = logical_shape_func[self.matmul_type]()
op_data_mapping = self._get_op_data_mapping(*logical_shapes)
......@@ -390,20 +393,26 @@ class MatMulHandler(MetaInfoNodeHandler):
output_logical_shape = torch.Size(output_logical_shape)
# create op data
input_op_data = OperationData(name=str(self.node.args[0]),
type=OperationDataType.ARG,
data=self.input_meta_data,
logical_shape=input_logical_shape)
other_op_data = OperationData(name=str(self.node.args[1]),
type=OperationDataType.ARG,
data=self.other_meta_data,
logical_shape=other_logical_shape)
output_op_data = OperationData(name=str(self.node),
type=OperationDataType.OUTPUT,
data=self.output_meta_data,
logical_shape=output_logical_shape)
mapping = {'input': input_op_data, 'other': other_op_data, 'output': output_op_data}
input_op_data = OperationData(
name=str(self.node.args[0]),
type=OperationDataType.ARG,
data=self.input_meta_data,
logical_shape=input_logical_shape,
)
other_op_data = OperationData(
name=str(self.node.args[1]),
type=OperationDataType.ARG,
data=self.other_meta_data,
logical_shape=other_logical_shape,
)
output_op_data = OperationData(
name=str(self.node),
type=OperationDataType.OUTPUT,
data=self.output_meta_data,
logical_shape=output_logical_shape,
)
mapping = {"input": input_op_data, "other": other_op_data, "output": output_op_data}
return mapping
def _get_logical_shape_for_dot(self):
......@@ -460,9 +469,11 @@ class MatMulHandler(MetaInfoNodeHandler):
dim_partition_dict[0] = shard
# re-init the sharding spec
input_sharding_spec.__init__(input_sharding_spec.device_mesh,
entire_shape=input_physical_shape,
dim_partition_dict=dim_partition_dict)
input_sharding_spec.__init__(
input_sharding_spec.device_mesh,
entire_shape=input_physical_shape,
dim_partition_dict=dim_partition_dict,
)
return strategy
else:
return strategy
......@@ -481,7 +492,8 @@ class MatMulHandler(MetaInfoNodeHandler):
recovered_stragies.extend(output)
else:
raise TypeError(
f"Found unexpected output type {type(output)} from the recover method of BmmTransform")
f"Found unexpected output type {type(output)} from the recover method of BmmTransform"
)
strategies = recovered_stragies
for index, strategies in enumerate(strategies):
strategies.name = f"{strategies.name}_{index}"
......
......@@ -8,7 +8,6 @@ from colossalai.auto_parallel.meta_profiler.shard_metainfo import ShardMetaInfo,
from colossalai.auto_parallel.tensor_shard.options import ShardOption, SolverPerference
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
OperationData,
OperationDataType,
ShardingSpec,
ShardingStrategy,
StrategiesVector,
......@@ -23,21 +22,23 @@ from .strategy import StrategyGenerator
class NodeHandler(ABC):
'''
"""
The NodeHandler is an abstract class used to generate every possible strategies for an operator node.
Args:
node (Node): the input node in node argument list.
device_mesh (DeviceMesh): A logical view of a physical mesh.
strategies_vector (StrategiesVector): all the strategies generated in this handler will be recorded into the strategies_vector.
'''
def __init__(self,
node: Node,
device_mesh: DeviceMesh,
strategies_vector: StrategiesVector,
shard_option: ShardOption = ShardOption.STANDARD,
solver_perference: SolverPerference = SolverPerference.STANDARD) -> None:
"""
def __init__(
self,
node: Node,
device_mesh: DeviceMesh,
strategies_vector: StrategiesVector,
shard_option: ShardOption = ShardOption.STANDARD,
solver_perference: SolverPerference = SolverPerference.STANDARD,
) -> None:
self.node = node
self.predecessor_node = list(node._input_nodes.keys())
self.successor_node = list(node.users.keys())
......@@ -68,8 +69,9 @@ class NodeHandler(ABC):
current_sharding_spec = strategy.sharding_specs[op_data]
# get the sharding specs for this node generated
# in its own node handler
assert hasattr(node, 'strategies_vector'), \
f'The predecessor node {node_name} has no strategy vector to compute the resharding cost.'
assert hasattr(
node, "strategies_vector"
), f"The predecessor node {node_name} has no strategy vector to compute the resharding cost."
prev_strategy_vector = node.strategies_vector
prev_sharding_specs = [
prev_strategy.get_sharding_spec_by_name(node_name) for prev_strategy in prev_strategy_vector
......@@ -80,10 +82,10 @@ class NodeHandler(ABC):
resharding_costs[node] = []
def _compute_resharding_cost(
prev_sharding_spec: Union[ShardingSpec,
List[ShardingSpec]], current_sharding_spec: Union[ShardingSpec,
List[ShardingSpec]],
data: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]]) -> TrainCycleItem:
prev_sharding_spec: Union[ShardingSpec, List[ShardingSpec]],
current_sharding_spec: Union[ShardingSpec, List[ShardingSpec]],
data: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]],
) -> TrainCycleItem:
"""
This is a helper function to compute the resharding cost for a specific strategy of a node.
"""
......@@ -94,30 +96,35 @@ class NodeHandler(ABC):
dtype = data.dtype
size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
_, _, consistency_cost = shape_consistency_manager.shape_consistency(
prev_sharding_spec, current_sharding_spec)
resharding_cost = TrainCycleItem(fwd=consistency_cost["forward"] * size_per_elem_bytes,
bwd=consistency_cost["backward"] * size_per_elem_bytes,
total=consistency_cost["total"] * size_per_elem_bytes)
prev_sharding_spec, current_sharding_spec
)
resharding_cost = TrainCycleItem(
fwd=consistency_cost["forward"] * size_per_elem_bytes,
bwd=consistency_cost["backward"] * size_per_elem_bytes,
total=consistency_cost["total"] * size_per_elem_bytes,
)
return resharding_cost
else:
# This raise is used to check if we have missed any type of data.
# It could be merged into Parameter branch, which means we won't handle
# non-tensor arguments.
raise ValueError(f'Unsupported data type {type(data)}')
raise ValueError(f"Unsupported data type {type(data)}")
else:
assert isinstance(prev_sharding_spec, (tuple, list)), \
f'prev_sharding_spec should be in type of ShardingSpec, List[ShardingSpec], \
or Tuple[ShardingSpec], but got {type(prev_sharding_spec)}'
assert isinstance(
prev_sharding_spec, (tuple, list)
), f"prev_sharding_spec should be in type of ShardingSpec, List[ShardingSpec], \
or Tuple[ShardingSpec], but got {type(prev_sharding_spec)}"
fwd_cost = 0
bwd_cost = 0
total_cost = 0
for index, (prev_sharding_spec_item,
current_sharding_spec_item) in enumerate(zip(prev_sharding_spec,
current_sharding_spec)):
item_cost = _compute_resharding_cost(prev_sharding_spec_item, current_sharding_spec_item,
data[index])
for index, (prev_sharding_spec_item, current_sharding_spec_item) in enumerate(
zip(prev_sharding_spec, current_sharding_spec)
):
item_cost = _compute_resharding_cost(
prev_sharding_spec_item, current_sharding_spec_item, data[index]
)
fwd_cost += item_cost.fwd
bwd_cost += item_cost.bwd
total_cost += item_cost.total
......@@ -138,17 +145,17 @@ class NodeHandler(ABC):
This function is used to get the target function for the node handler.
The target function is used to analyze the costs of strategies.
"""
if self.node.op in ('placeholder', 'get_attr', 'output'):
if self.node.op in ("placeholder", "get_attr", "output"):
return None
if self.node.op == 'call_module':
if self.node.op == "call_module":
target = self.node.graph.owning_module.get_submodule(self.node.target)
elif self.node.op == 'call_function':
elif self.node.op == "call_function":
target = self.node.target
elif self.node.op == 'call_method':
elif self.node.op == "call_method":
target = getattr(self.node.args[0]._meta_data.__class__, self.node.target)
else:
raise ValueError(f'Unsupported node type: {self.node.op}')
raise ValueError(f"Unsupported node type: {self.node.op}")
return target
......@@ -221,7 +228,6 @@ class NodeHandler(ABC):
"""
Define which generators should be used by this NodeHandler object.
"""
pass
@abstractmethod
def get_operation_data_mapping(self) -> Dict[str, OperationData]:
......@@ -244,7 +250,6 @@ class NodeHandler(ABC):
"output": Operand(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data),
}
"""
pass
class MetaInfoNodeHandler(NodeHandler):
......@@ -278,19 +283,19 @@ class MetaInfoNodeHandler(NodeHandler):
else:
logger = get_dist_logger()
logger.warning(f'The target function {target} is not patched yet, ')
logger.warning(f"The target function {target} is not patched yet, ")
return self.strategies_vector
class ModuleHandler(NodeHandler):
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
# set attributes to access module parameters for convenience
assert self.node.graph.owning_module is not None, \
f'The graph is not associated with a module, please make sure it can be used to instantiate a GraphModule object.'
assert (
self.node.graph.owning_module is not None
), f"The graph is not associated with a module, please make sure it can be used to instantiate a GraphModule object."
module = self.node.graph.owning_module.get_submodule(self.node.target)
named_parameters = list(module.named_parameters(recurse=False))
named_buffers = list(module.named_buffers(recurse=False))
......@@ -333,6 +338,6 @@ class MetaInfoModuleHandler(ModuleHandler):
else:
logger = get_dist_logger()
logger.warning(f'The target function {target} is not patched yet')
logger.warning(f"The target function {target} is not patched yet")
return self.strategies_vector
......@@ -3,11 +3,11 @@ from typing import Dict, List
import torch
from ..sharding_strategy import OperationData, OperationDataType
from .node_handler import MetaInfoModuleHandler, ModuleHandler
from .node_handler import MetaInfoModuleHandler
from .registry import operator_registry
from .strategy import NormalPoolStrategyGenerator, StrategyGenerator
__all__ = ['NormPoolingHandler']
__all__ = ["NormPoolingHandler"]
@operator_registry.register(torch.nn.MaxPool1d)
......@@ -30,9 +30,9 @@ class NormPoolingHandler(MetaInfoModuleHandler):
def get_operation_data_mapping(self) -> Dict[str, OperationData]:
# use transposed shape for strategies
# the strategies will be transformed back to its original shape in self.post_process
physical_input_operand = OperationData(name=str(self.node.args[0]),
type=OperationDataType.ARG,
data=self.node.args[0]._meta_data)
physical_input_operand = OperationData(
name=str(self.node.args[0]), type=OperationDataType.ARG, data=self.node.args[0]._meta_data
)
physical_weight_operand = OperationData(name="kernel", type=OperationDataType.ARG, data=self.module.kernel_size)
physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data)
......
......@@ -8,7 +8,7 @@ from ..sharding_strategy import OperationData, OperationDataType, StrategiesVect
from .node_handler import NodeHandler
from .strategy import OutputGenerator, StrategyGenerator
__all__ = ['OutputHandler']
__all__ = ["OutputHandler"]
class OutputHandler(NodeHandler):
......@@ -16,8 +16,9 @@ class OutputHandler(NodeHandler):
A OutputHandler which deals with the sharding strategies for Output Node.
"""
def __init__(self, node: torch.fx.node.Node, device_mesh: DeviceMesh, strategies_vector: StrategiesVector,
output_option: str) -> None:
def __init__(
self, node: torch.fx.node.Node, device_mesh: DeviceMesh, strategies_vector: StrategiesVector, output_option: str
) -> None:
super().__init__(node, device_mesh, strategies_vector)
self.output_option = output_option
......@@ -35,11 +36,11 @@ class OutputHandler(NodeHandler):
for index, input_node in enumerate(self.predecessor_node):
input_meta_data = input_node._meta_data
physical_inputs = OperationData(name=str(input_node), type=OperationDataType.ARG, data=input_meta_data)
name_key = f'input_{index}'
name_key = f"input_{index}"
mapping[name_key] = physical_inputs
output_meta_data.append(input_meta_data)
assert len(output_meta_data) > 0, f'Output node {self.node} has no input node.'
assert len(output_meta_data) > 0, f"Output node {self.node} has no input node."
if len(output_meta_data) == 1:
output_meta_data = output_meta_data[0]
else:
......
......@@ -7,7 +7,7 @@ from .node_handler import NodeHandler
from .registry import operator_registry
from .strategy import PermuteGenerator, StrategyGenerator
__all__ = ['PermuteHandler']
__all__ = ["PermuteHandler"]
@operator_registry.register(torch.Tensor.permute)
......@@ -34,14 +34,14 @@ class PermuteHandler(NodeHandler):
physical_input_operand = OperationData(name=str(self.node.args[0]), type=data_type, data=input_data)
permute_dims = []
if self.node.op == 'call_method':
if self.node.op == "call_method":
# torch.Tensor.permute (input, *dims)
for arg in self.node.args:
if isinstance(arg, torch.fx.Node):
if isinstance(arg._meta_data, int):
permute_dims.append(arg._meta_data)
else:
assert isinstance(arg, int), 'The argument in permute node should be either type of Node or int.'
assert isinstance(arg, int), "The argument in permute node should be either type of Node or int."
permute_dims.append(arg)
else:
# torch.permute (input, dims)
......@@ -51,8 +51,8 @@ class PermuteHandler(NodeHandler):
permute_dims.extend(arg._meta_data)
else:
assert isinstance(
arg,
(tuple, list)), 'The argument in permute node should be type of Node, Tuple[int] or List[int].'
arg, (tuple, list)
), "The argument in permute node should be type of Node, Tuple[int] or List[int]."
permute_dims.extend(arg)
num_dims = self.node._meta_data.dim()
......@@ -61,7 +61,7 @@ class PermuteHandler(NodeHandler):
if permute_dims[i] < 0:
permute_dims[i] += num_dims
physical_shape_operand = OperationData(name='permute_dims', type=OperationDataType.ARG, data=list(permute_dims))
physical_shape_operand = OperationData(name="permute_dims", type=OperationDataType.ARG, data=list(permute_dims))
output_data = self.node._meta_data
physical_output_operand = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=output_data)
......@@ -69,7 +69,7 @@ class PermuteHandler(NodeHandler):
mapping = {
"input": physical_input_operand,
"permute_dims": physical_shape_operand,
"output": physical_output_operand
"output": physical_output_operand,
}
return mapping
......@@ -8,7 +8,7 @@ from ..sharding_strategy import OperationData, OperationDataType, StrategiesVect
from .node_handler import NodeHandler
from .strategy import PlaceholderGenerator, StrategyGenerator
__all__ = ['PlaceholderHandler']
__all__ = ["PlaceholderHandler"]
class PlaceholderHandler(NodeHandler):
......@@ -16,8 +16,9 @@ class PlaceholderHandler(NodeHandler):
A PlaceholderHandler which deals with the sharding strategies for Placeholder Node.
"""
def __init__(self, node: Node, device_mesh: DeviceMesh, strategies_vector: StrategiesVector,
placeholder_option: str) -> None:
def __init__(
self, node: Node, device_mesh: DeviceMesh, strategies_vector: StrategiesVector, placeholder_option: str
) -> None:
super().__init__(node, device_mesh, strategies_vector)
self.placeholder_option = placeholder_option
......@@ -25,7 +26,8 @@ class PlaceholderHandler(NodeHandler):
op_data_mapping = self.get_operation_data_mapping()
generators = []
generators.append(
PlaceholderGenerator(op_data_mapping, self.device_mesh, placeholder_option=self.placeholder_option))
PlaceholderGenerator(op_data_mapping, self.device_mesh, placeholder_option=self.placeholder_option)
)
return generators
def get_operation_data_mapping(self) -> Dict[str, OperationData]:
......
class Registry:
def __init__(self, name):
    """
    Create an empty registry.

    Args:
        name: label of this registry; it is interpolated into the error message
            raised when a lookup fails.
    """
    self.name = name
    # lookup table consulted by `get`; starts out empty
    self.store = {}
def register(self, source):
def wrapper(func):
if isinstance(source, (list, tuple)):
# support register a list of items for this func
......@@ -18,7 +16,7 @@ class Registry:
return wrapper
def get(self, source):
    """
    Look up the target stored for ``source``.

    Args:
        source: the key to look up in ``self.store``.

    Returns:
        The object stored under ``source``.

    Raises:
        AssertionError: if ``source`` is not present in ``self.store``.
    """
    # kept as an assert so that callers keep seeing the same exception type
    assert source in self.store, f"{source} not found in the {self.name} registry"
    return self.store[source]
......@@ -26,4 +24,4 @@ class Registry:
return source in self.store
operator_registry = Registry('operator')
operator_registry = Registry("operator")
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment