Unverified Commit be7eee0c authored by Ningxin Zheng, committed by GitHub

enable customizing the replace function (#4826)

parent 98c1a77f
@@ -42,10 +42,23 @@ class ModelSpeedup:
         the index of batch dimension in the dummy_input
     confidence: the confidence coefficient of the sparsity inference. This value is
         actually used as the batchsize of the dummy_input.
+    customized_replace_func: None/Dict
+        If `customized_replace_func` is not None, then we will use the given functions to replace the
+        corresponding modules. The `key` of the dict is the operator type and the `value`
+        is the replace function for that operator type. The replace function should take
+        two input parameters: the first is the original module, the second is a tuple of
+        the input mask, output mask and weight mask. The replace function should prune the module
+        accordingly. Here is an example of a replace function (more examples can be found in compress_modules.py)::
+
+            def example_replace(ori_module, masks):
+                in_mask, out_mask, weight_mask = masks
+                # prune the ori_module to a new smaller module according to the masks
+                return new_small_module
     """

     def __init__(self, model, dummy_input, masks_file, map_location=None,
-                 batch_dim=0, confidence=8):
+                 batch_dim=0, confidence=8, customized_replace_func=None):
         assert confidence > 1
         # The auto inference will change the values of the parameters in the model
         # so we need make a copy before the mask inference
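For readers who want something more concrete than the schematic example_replace in the docstring above, the following is a minimal sketch of what a customized replace function for torch.nn.Linear might look like. It is not taken from this commit or from compress_modules.py; it assumes the masks tuple layout described in the new docstring (input mask, output mask, weight mask) and additionally assumes the input and output masks are 1-D 0/1 feature masks. The name replace_linear is hypothetical.

import torch
import torch.nn as nn

def replace_linear(ori_module, masks):
    # masks layout assumed from the docstring: (input mask, output mask, weight mask)
    in_mask, out_mask, weight_mask = masks
    # assumption: in_mask and out_mask are 1-D 0/1 tensors over the input/output features
    in_idx = torch.nonzero(in_mask, as_tuple=True)[0]
    out_idx = torch.nonzero(out_mask, as_tuple=True)[0]
    new_module = nn.Linear(in_idx.numel(), out_idx.numel(),
                           bias=ori_module.bias is not None)
    # copy the surviving rows/columns of the weight (and the surviving bias entries)
    new_module.weight.data = ori_module.weight.data[out_idx][:, in_idx]
    if ori_module.bias is not None:
        new_module.bias.data = ori_module.bias.data[out_idx]
    return new_module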
@@ -53,7 +66,8 @@ class ModelSpeedup:
         self.bound_model = model
         self.inferred_masks = dict()  # key: module_name, value: ModuleMasks
         self.batch_dim = batch_dim
-        self.dummy_input, self.device = self._random_model_input(dummy_input, confidence, batch_dim)
+        self.dummy_input, self.device = self._random_model_input(
+            dummy_input, confidence, batch_dim)
         self.torch_graph = build_module_graph(model, self.dummy_input)
         # dict object to save the auto inferences objects of the submodules
         self.auto_inferences = {}
@@ -75,6 +89,7 @@ class ModelSpeedup:
         self.constant = {}
         # self.internal_result save the internal output of the submodules
         self.internal_result = {}
+        self.customized_replace_func = customized_replace_func if customized_replace_func is not None else {}

     def _random_model_input(self, dummy_input, confidence, batch_dim):
         """
@@ -284,7 +299,8 @@ class ModelSpeedup:
                     else:
                         last_output.grad = tin.grad
             else:
-                _logger.warning('Note: %s does not have corresponding mask inference object', node.name)
+                _logger.warning(
+                    'Note: %s does not have corresponding mask inference object', node.name)

     def _vnode_to_value(self, c_node):
         """
@@ -408,6 +424,7 @@ class ModelSpeedup:
     method is shutdown, in the future, we will merge these two methods into a graph
     pass which is used to resolve the mask conflict.
     """
+
     def __init__(self, ori_module, reindex_dim, reindex):
         super(ReindexModule, self).__init__()
         self.ori_module = ori_module
@@ -441,12 +458,15 @@ class ModelSpeedup:
         super_module, leaf_module = get_module_by_name(
             self.bound_model, g_node.name)
         m_type = g_node.op_type
-        if not m_type in replace_module:
+        if (not m_type in replace_module) and (m_type not in self.customized_replace_func):
             raise RuntimeError(
                 "Has not supported replacing the module: `{}`".format(m_type))
         _logger.info("replace module (name: %s, op_type: %s)",
                      g_node.name, m_type)
-        compressed_module = replace_module[m_type](
+        replace_function = replace_module[m_type]
+        if m_type in self.customized_replace_func:
+            replace_function = self.customized_replace_func[m_type]
+        compressed_module = replace_function(
             leaf_module, auto_infer.get_masks())
         new_submodule = compressed_module
         if reindex_dim is None:
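To show how the new parameter is meant to be wired in, here is a hedged usage sketch (not taken from this commit): customized_replace_func maps an op-type key to the user's replace function, and everything else follows the existing ModelSpeedup workflow. The tiny Sequential model, the 'masks.pth' path, the 'Linear' key, and the replace_linear function from the sketch above are placeholders for illustration.

import torch
import torch.nn as nn
from nni.compression.pytorch import ModelSpeedup

# placeholder model and mask file; in practice these come from a previously pruned model
model = nn.Sequential(nn.Linear(16, 8), nn.ReLU(), nn.Linear(8, 4))
dummy_input = torch.rand(8, 16)

# assumption: 'Linear' is the op-type key under which Linear modules are looked up
m_speedup = ModelSpeedup(model, dummy_input, 'masks.pth',
                         customized_replace_func={'Linear': replace_linear})
m_speedup.speedup_model()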