Generate the sharding spec of the tensor based on the given dim_partition_dict.
Args:
input_ (Union[Node, torch.Tensor]): the input can be a Node object or a PyTorch tensor. If a node is used, it will look for its meta data associated with this node.
device_mesh (DeviceMesh): a DeviceMesh object which contains the meta information about the cluster.
dim_partition_dict (Dict[int, List[int]]): a dictionary to specify the sharding specs, the key is the tensor dimension and the value is the mesh dimension for sharding.
"""
ifisinstance(input_,Node):
asserthasattr(input_,'_meta_data'),f'The given node has no attribte _meta_data'
meta_tensor=input_._meta_data
assertmeta_tensorisnotNone,"The given node's _meta_data attribute is None"
shape=meta_tensor.shape
elifisinstance(input_,torch.Tensor):
shape=input_.shape
else:
raiseTypeError(
f'We cannot generate sharding spec for {type(input_)} type, only torch.fx.Node or torch.Tensor is expected.'
name=f'{output_sharding_spec.sharding_sequence} = {sharding_spec_for_condition.sharding_sequence} x {sharding_spec_for_x.sharding_sequence} x {sharding_spec_for_y.sharding_sequence}'
Solver class will integrate information provided by the components and use ILP solver to find a possible optimal strategies combination for target computing graph.
Argument:
graph: The computing graph to be optimized.
strategies_constructor: It will provide all the possible strategies for each node in the computing graph.
cost_graph: A graph data structure to simplify the edge cost graph.
graph_analyser: graph_analyser will analyse the graph to obtain the variable liveness information, which will be used to generate memory constraints.
memory_budget: Memory constraint for the solution.
solution_numbers: If solution_numbers is larger than one, solver will us a serious of solutions based on different memory budget.
memory_increasing_coefficient: If solution_numbers is larger than one, we will use this coefficient to generate new memory budget.