Unverified Commit 1a359941 authored by YuliangLiu0306's avatar YuliangLiu0306 Committed by GitHub
Browse files

[autoparallel] support function in operator handler (#1529)

parent 44c866a3
...@@ -9,7 +9,7 @@ __all__ = ['ConvHandler'] ...@@ -9,7 +9,7 @@ __all__ = ['ConvHandler']
class ConvHandler(OperatorHandler): class ConvHandler(OperatorHandler):
""" """
A OperatorHandler which deals with the sharding strategies of linear matrix multiplication. An OperatorHandler which deals with the sharding strategies of linear matrix multiplication.
""" """
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
...@@ -67,7 +67,7 @@ class ConvHandler(OperatorHandler): ...@@ -67,7 +67,7 @@ class ConvHandler(OperatorHandler):
sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output) sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
# generate resharding cost for this strategy # generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input]) resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
# compute the computation cost of this strategy # compute the computation cost of this strategy
bs = self.input_data.shape[0] // self.device_mesh.shape[mesh_dim_0] bs = self.input_data.shape[0] // self.device_mesh.shape[mesh_dim_0]
...@@ -106,7 +106,7 @@ class ConvHandler(OperatorHandler): ...@@ -106,7 +106,7 @@ class ConvHandler(OperatorHandler):
sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output) sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
# generate resharding cost for this strategy # generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input]) resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
# compute the computation cost of this strategy # compute the computation cost of this strategy
bs = self.input_data.shape[0] // self.device_mesh.shape[mesh_dim_0] bs = self.input_data.shape[0] // self.device_mesh.shape[mesh_dim_0]
...@@ -145,7 +145,7 @@ class ConvHandler(OperatorHandler): ...@@ -145,7 +145,7 @@ class ConvHandler(OperatorHandler):
sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_input) sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_input)
# generate resharding cost for this strategy # generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input]) resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
# compute the computation cost of this strategy # compute the computation cost of this strategy
bs = self.input_data.shape[0] // self.device_mesh.shape[mesh_dim_0] bs = self.input_data.shape[0] // self.device_mesh.shape[mesh_dim_0]
...@@ -184,7 +184,7 @@ class ConvHandler(OperatorHandler): ...@@ -184,7 +184,7 @@ class ConvHandler(OperatorHandler):
sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_input) sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_input)
# generate resharding cost for this strategy # generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input]) resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
# compute the computation cost of this strategy # compute the computation cost of this strategy
bs = self.input_data.shape[0] bs = self.input_data.shape[0]
...@@ -223,7 +223,7 @@ class ConvHandler(OperatorHandler): ...@@ -223,7 +223,7 @@ class ConvHandler(OperatorHandler):
sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_input) sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_input)
# generate resharding cost for this strategy # generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input]) resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
# compute the computation cost of this strategy # compute the computation cost of this strategy
bs = self.input_data.shape[0] bs = self.input_data.shape[0]
...@@ -261,7 +261,7 @@ class ConvHandler(OperatorHandler): ...@@ -261,7 +261,7 @@ class ConvHandler(OperatorHandler):
sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_input) sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_input)
# generate resharding cost for this strategy # generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input]) resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
# compute the computation cost of this strategy # compute the computation cost of this strategy
bs = self.input_data.shape[0] bs = self.input_data.shape[0]
...@@ -301,7 +301,7 @@ class ConvHandler(OperatorHandler): ...@@ -301,7 +301,7 @@ class ConvHandler(OperatorHandler):
sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_input) sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_input)
# generate resharding cost for this strategy # generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input]) resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
# compute the computation cost of this strategy # compute the computation cost of this strategy
bs = self.input_data.shape[0] bs = self.input_data.shape[0]
...@@ -340,7 +340,7 @@ class ConvHandler(OperatorHandler): ...@@ -340,7 +340,7 @@ class ConvHandler(OperatorHandler):
sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_input) sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_input)
# generate resharding cost for this strategy # generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input]) resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
# compute the computation cost of this strategy # compute the computation cost of this strategy
bs = self.input_data.shape[0] // (self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]) bs = self.input_data.shape[0] // (self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1])
...@@ -380,7 +380,7 @@ class ConvHandler(OperatorHandler): ...@@ -380,7 +380,7 @@ class ConvHandler(OperatorHandler):
sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_input) sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_input)
# generate resharding cost for this strategy # generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input]) resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
# compute the computation cost of this strategy # compute the computation cost of this strategy
bs = self.input_data.shape[0] bs = self.input_data.shape[0]
......
...@@ -15,7 +15,7 @@ __all__ = ['OperatorHandler'] ...@@ -15,7 +15,7 @@ __all__ = ['OperatorHandler']
class OperatorHandler(ABC): class OperatorHandler(ABC):
''' '''
The OperatorHandler is an abstract class used to generate every possible strategies for a operator node. The OperatorHandler is an abstract class used to generate every possible strategies for an operator node.
Argument: Argument:
input_node(Node): the input node in node argument list. input_node(Node): the input node in node argument list.
...@@ -43,6 +43,10 @@ class OperatorHandler(ABC): ...@@ -43,6 +43,10 @@ class OperatorHandler(ABC):
named_parameters = list(module.named_parameters(recurse=False)) named_parameters = list(module.named_parameters(recurse=False))
# convert named parameters from list to dict # convert named parameters from list to dict
named_parameters = {k: v for k, v in named_parameters} named_parameters = {k: v for k, v in named_parameters}
elif self.node.op == 'call_function':
module = None
parameters = list(self.node.args)[1]
named_parameters = {'weight': parameters._meta_data}
else: else:
module = None module = None
named_parameters = None named_parameters = None
......
...@@ -27,7 +27,13 @@ class StrategiesConstructor: ...@@ -27,7 +27,13 @@ class StrategiesConstructor:
Generate the sharding spec of the tensor based on the given dim_partition_dict Generate the sharding spec of the tensor based on the given dim_partition_dict
where the key is the tensor dimension and the value is the mesh dimension for sharding. where the key is the tensor dimension and the value is the mesh dimension for sharding.
""" """
meta_tensor = node._meta_data if hasattr(node, '_meta_data'):
meta_tensor = node._meta_data
elif isinstance(node, torch.Tensor):
meta_tensor = node
else:
raise RuntimeError(f'We cannot generate sharding spec for {type(node)} type.')
sharding_spec = ShardingSpec(device_mesh=self.device_mesh, sharding_spec = ShardingSpec(device_mesh=self.device_mesh,
entire_shape=meta_tensor.shape, entire_shape=meta_tensor.shape,
dim_partition_dict=dim_partition_dict) dim_partition_dict=dim_partition_dict)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment