Generate the sharding spec of the tensor based on the given dim_partition_dict.
Args:
input_ (Union[Node, torch.Tensor]): the input can be a Node object or a PyTorch tensor. If a node is used, it will look for its meta data associated with this node.
device_mesh (DeviceMesh): a DeviceMesh object which contains the meta information about the cluster.
dim_partition_dict (Dict[int, List[int]]): a dictionary to specify the sharding specs, the key is the tensor dimension and the value is the mesh dimension for sharding.
"""
ifisinstance(input_,Node):
asserthasattr(input_,'_meta_data'),f'The given node has no attribte _meta_data'
meta_tensor=input_._meta_data
assertmeta_tensorisnotNone,"The given node's _meta_data attribute is None"
shape=meta_tensor.shape
elifisinstance(input_,torch.Tensor):
shape=input_.shape
else:
raiseTypeError(
f'We cannot generate sharding spec for {type(input_)} type, only torch.fx.Node or torch.Tensor is expected.'