"src/git@developer.sourcefind.cn:OpenDAS/nni.git" did not exist on "0c17e2d9d47cd05a2d42bb9ef9540f492dd3bbfc"
Unverified Commit 5e04d56c authored by Yuge Zhang's avatar Yuge Zhang Committed by GitHub
Browse files

Replace example_inputs with dummy_input (#3983)

parent bb397480
...@@ -32,10 +32,10 @@ To support latency-aware NAS, you first need a `Strategy` that supports filterin ...@@ -32,10 +32,10 @@ To support latency-aware NAS, you first need a `Strategy` that supports filterin
``LatencyFilter`` will predict the models\' latency by using nn-Meter and filter out the models whose latency is larger than the threshold (i.e., ``100`` in this example). ``LatencyFilter`` will predict the models\' latency by using nn-Meter and filter out the models whose latency is larger than the threshold (i.e., ``100`` in this example).
You can also build your own strategies and filters to support more flexible NAS such as sorting the models according to latency. You can also build your own strategies and filters to support more flexible NAS such as sorting the models according to latency.
Then, pass this strategy to ``RetiariiExperiment`` along with some additional arguments: ``parse_shape=True, example_inputs=example_inputs``: Then, pass this strategy to ``RetiariiExperiment`` along with some additional arguments: ``parse_shape=True, dummy_input=dummy_input``:
.. code-block:: python .. code-block:: python
RetiariiExperiment(base_model, trainer, [], simple_strategy, True, example_inputs) RetiariiExperiment(base_model, trainer, [], simple_strategy, True, dummy_input)
Here, ``parse_shape=True`` means extracting shape info from the torch model as it is required by nn-Meter to predict latency. ``example_inputs`` is required for tracing shape info. Here, ``parse_shape=True`` means extracting shape info from the torch model as it is required by nn-Meter to predict latency. ``dummy_input`` is required for tracing shape info.
...@@ -185,7 +185,7 @@ def _main(port): ...@@ -185,7 +185,7 @@ def _main(port):
exp_config.trial_gpu_number = 1 exp_config.trial_gpu_number = 1
exp_config.training_service.use_active_gpu = False exp_config.training_service.use_active_gpu = False
exp_config.execution_engine = 'base' exp_config.execution_engine = 'base'
exp_config.example_inputs = [1, 3, 32, 32] exp_config.dummy_input = [1, 3, 32, 32]
exp.run(exp_config, port) exp.run(exp_config, port)
......
...@@ -683,13 +683,13 @@ class GraphConverterWithShape(GraphConverter): ...@@ -683,13 +683,13 @@ class GraphConverterWithShape(GraphConverter):
If the forward path of candidates depends on input data, then the wrong path will be traced. If the forward path of candidates depends on input data, then the wrong path will be traced.
This will result in incomplete shape info. This will result in incomplete shape info.
""" """
def convert_module(self, script_module, module, module_name, ir_model, example_inputs): def convert_module(self, script_module, module, module_name, ir_model, dummy_input):
module.eval() module.eval()
ir_graph, attrs = self._convert_module(script_module, module, module_name, ir_model) ir_graph, attrs = self._convert_module(script_module, module, module_name, ir_model)
self.remove_dummy_nodes(ir_model) self.remove_dummy_nodes(ir_model)
self._initialize_parameters(ir_model) self._initialize_parameters(ir_model)
self._trace_module(module, module_name, ir_model, example_inputs) self._trace_module(module, module_name, ir_model, dummy_input)
return ir_graph, attrs return ir_graph, attrs
def _initialize_parameters(self, ir_model: 'Model'): def _initialize_parameters(self, ir_model: 'Model'):
...@@ -699,9 +699,9 @@ class GraphConverterWithShape(GraphConverter): ...@@ -699,9 +699,9 @@ class GraphConverterWithShape(GraphConverter):
ir_node.operation.parameters.setdefault('input_shape', []) ir_node.operation.parameters.setdefault('input_shape', [])
ir_node.operation.parameters.setdefault('output_shape', []) ir_node.operation.parameters.setdefault('output_shape', [])
def _trace_module(self, module, module_name, ir_model: 'Model', example_inputs): def _trace_module(self, module, module_name, ir_model: 'Model', dummy_input):
# First, trace the whole graph # First, trace the whole graph
tm_graph = self._trace(module, example_inputs) tm_graph = self._trace(module, dummy_input)
for node in tm_graph.nodes(): for node in tm_graph.nodes():
parameters = _extract_info_from_trace_node(node) parameters = _extract_info_from_trace_node(node)
...@@ -832,8 +832,8 @@ class GraphConverterWithShape(GraphConverter): ...@@ -832,8 +832,8 @@ class GraphConverterWithShape(GraphConverter):
# remove subgraphs # remove subgraphs
ir_model.graphs = {ir_model._root_graph_name: ir_model.root_graph} ir_model.graphs = {ir_model._root_graph_name: ir_model.root_graph}
def _trace(self, module, example_inputs): def _trace(self, module, dummy_input):
traced_module = torch.jit.trace(module, example_inputs) traced_module = torch.jit.trace(module, dummy_input)
torch._C._jit_pass_inline(traced_module.graph) torch._C._jit_pass_inline(traced_module.graph)
return traced_module.graph return traced_module.graph
......
...@@ -60,7 +60,7 @@ class RetiariiExeConfig(ConfigBase): ...@@ -60,7 +60,7 @@ class RetiariiExeConfig(ConfigBase):
execution_engine: str = 'py' execution_engine: str = 'py'
# input used in GraphConverterWithShape. Currently supports shape tuple only. # input used in GraphConverterWithShape. Currently supports shape tuple only.
example_inputs: Optional[List[int]] = None dummy_input: Optional[List[int]] = None
def __init__(self, training_service_platform: Optional[str] = None, **kwargs): def __init__(self, training_service_platform: Optional[str] = None, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
...@@ -110,7 +110,7 @@ _validation_rules = { ...@@ -110,7 +110,7 @@ _validation_rules = {
'training_service': lambda value: (type(value) is not TrainingServiceConfig, 'cannot be abstract base class') 'training_service': lambda value: (type(value) is not TrainingServiceConfig, 'cannot be abstract base class')
} }
def preprocess_model(base_model, trainer, applied_mutators, full_ir=True, example_inputs=None): def preprocess_model(base_model, trainer, applied_mutators, full_ir=True, dummy_input=None):
# TODO: this logic might need to be refactored into execution engine # TODO: this logic might need to be refactored into execution engine
if full_ir: if full_ir:
try: try:
...@@ -118,11 +118,11 @@ def preprocess_model(base_model, trainer, applied_mutators, full_ir=True, exampl ...@@ -118,11 +118,11 @@ def preprocess_model(base_model, trainer, applied_mutators, full_ir=True, exampl
except Exception as e: except Exception as e:
_logger.error('Your base model cannot be parsed by torch.jit.script, please fix the following error:') _logger.error('Your base model cannot be parsed by torch.jit.script, please fix the following error:')
raise e raise e
if example_inputs is not None: if dummy_input is not None:
# FIXME: this is a workaround as full tensor is not supported in configs # FIXME: this is a workaround as full tensor is not supported in configs
example_inputs = torch.randn(*example_inputs) dummy_input = torch.randn(*dummy_input)
converter = GraphConverterWithShape() converter = GraphConverterWithShape()
base_model_ir = convert_to_graph(script_module, base_model, converter, example_inputs=example_inputs) base_model_ir = convert_to_graph(script_module, base_model, converter, dummy_input=dummy_input)
else: else:
base_model_ir = convert_to_graph(script_module, base_model) base_model_ir = convert_to_graph(script_module, base_model)
# handle inline mutations # handle inline mutations
...@@ -182,7 +182,7 @@ class RetiariiExperiment(Experiment): ...@@ -182,7 +182,7 @@ class RetiariiExperiment(Experiment):
def _start_strategy(self): def _start_strategy(self):
base_model_ir, self.applied_mutators = preprocess_model( base_model_ir, self.applied_mutators = preprocess_model(
self.base_model, self.trainer, self.applied_mutators, full_ir=self.config.execution_engine != 'py', self.base_model, self.trainer, self.applied_mutators, full_ir=self.config.execution_engine != 'py',
example_inputs=self.config.example_inputs) dummy_input=self.config.dummy_input)
_logger.info('Start strategy...') _logger.info('Start strategy...')
self.strategy.run(base_model_ir, self.applied_mutators) self.strategy.run(base_model_ir, self.applied_mutators)
......
...@@ -15,5 +15,5 @@ class ConvertWithShapeMixin: ...@@ -15,5 +15,5 @@ class ConvertWithShapeMixin:
@staticmethod @staticmethod
def _convert_model(model, input): def _convert_model(model, input):
script_module = torch.jit.script(model) script_module = torch.jit.script(model)
model_ir = convert_to_graph(script_module, model, converter=GraphConverterWithShape(), example_inputs=input) model_ir = convert_to_graph(script_module, model, converter=GraphConverterWithShape(), dummy_input=input)
return model_ir return model_ir
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment