"src/vscode:/vscode.git/clone" did not exist on "2af3f22ec31537eb3310b98518b334b34e6cb08b"
Unverified Commit 5e04d56c authored by Yuge Zhang's avatar Yuge Zhang Committed by GitHub
Browse files

Replace example_inputs with dummy_input (#3983)

parent bb397480
......@@ -32,10 +32,10 @@ To support latency-aware NAS, you first need a `Strategy` that supports filterin
``LatencyFilter`` will predict the models' latency by using nn-Meter and filter out the models whose latency is larger than the threshold (i.e., ``100`` in this example).
You can also build your own strategies and filters to support more flexible NAS such as sorting the models according to latency.
Then, pass this strategy to ``RetiariiExperiment`` along with some additional arguments: ``parse_shape=True, example_inputs=example_inputs``:
Then, pass this strategy to ``RetiariiExperiment`` along with some additional arguments: ``parse_shape=True, dummy_input=dummy_input``:
.. code-block:: python
RetiariiExperiment(base_model, trainer, [], simple_strategy, True, example_inputs)
RetiariiExperiment(base_model, trainer, [], simple_strategy, True, dummy_input)
Here, ``parse_shape=True`` means extracting shape info from the torch model as it is required by nn-Meter to predict latency. ``example_inputs`` is required for tracing shape info.
Here, ``parse_shape=True`` means extracting shape info from the torch model as it is required by nn-Meter to predict latency. ``dummy_input`` is required for tracing shape info.
......@@ -185,7 +185,7 @@ def _main(port):
exp_config.trial_gpu_number = 1
exp_config.training_service.use_active_gpu = False
exp_config.execution_engine = 'base'
exp_config.example_inputs = [1, 3, 32, 32]
exp_config.dummy_input = [1, 3, 32, 32]
exp.run(exp_config, port)
......
......@@ -683,13 +683,13 @@ class GraphConverterWithShape(GraphConverter):
If the forward path of candidates depends on input data, then the wrong path will be traced.
This will result in incomplete shape info.
"""
def convert_module(self, script_module, module, module_name, ir_model, example_inputs):
def convert_module(self, script_module, module, module_name, ir_model, dummy_input):
module.eval()
ir_graph, attrs = self._convert_module(script_module, module, module_name, ir_model)
self.remove_dummy_nodes(ir_model)
self._initialize_parameters(ir_model)
self._trace_module(module, module_name, ir_model, example_inputs)
self._trace_module(module, module_name, ir_model, dummy_input)
return ir_graph, attrs
def _initialize_parameters(self, ir_model: 'Model'):
......@@ -699,9 +699,9 @@ class GraphConverterWithShape(GraphConverter):
ir_node.operation.parameters.setdefault('input_shape', [])
ir_node.operation.parameters.setdefault('output_shape', [])
def _trace_module(self, module, module_name, ir_model: 'Model', example_inputs):
def _trace_module(self, module, module_name, ir_model: 'Model', dummy_input):
# First, trace the whole graph
tm_graph = self._trace(module, example_inputs)
tm_graph = self._trace(module, dummy_input)
for node in tm_graph.nodes():
parameters = _extract_info_from_trace_node(node)
......@@ -832,8 +832,8 @@ class GraphConverterWithShape(GraphConverter):
# remove subgraphs
ir_model.graphs = {ir_model._root_graph_name: ir_model.root_graph}
def _trace(self, module, example_inputs):
traced_module = torch.jit.trace(module, example_inputs)
def _trace(self, module, dummy_input):
traced_module = torch.jit.trace(module, dummy_input)
torch._C._jit_pass_inline(traced_module.graph)
return traced_module.graph
......
......@@ -60,7 +60,7 @@ class RetiariiExeConfig(ConfigBase):
execution_engine: str = 'py'
# input used in GraphConverterWithShape. Currently supports shape tuple only.
example_inputs: Optional[List[int]] = None
dummy_input: Optional[List[int]] = None
def __init__(self, training_service_platform: Optional[str] = None, **kwargs):
super().__init__(**kwargs)
......@@ -110,7 +110,7 @@ _validation_rules = {
'training_service': lambda value: (type(value) is not TrainingServiceConfig, 'cannot be abstract base class')
}
def preprocess_model(base_model, trainer, applied_mutators, full_ir=True, example_inputs=None):
def preprocess_model(base_model, trainer, applied_mutators, full_ir=True, dummy_input=None):
# TODO: this logic might need to be refactored into execution engine
if full_ir:
try:
......@@ -118,11 +118,11 @@ def preprocess_model(base_model, trainer, applied_mutators, full_ir=True, exampl
except Exception as e:
_logger.error('Your base model cannot be parsed by torch.jit.script, please fix the following error:')
raise e
if example_inputs is not None:
if dummy_input is not None:
# FIXME: this is a workaround as full tensor is not supported in configs
example_inputs = torch.randn(*example_inputs)
dummy_input = torch.randn(*dummy_input)
converter = GraphConverterWithShape()
base_model_ir = convert_to_graph(script_module, base_model, converter, example_inputs=example_inputs)
base_model_ir = convert_to_graph(script_module, base_model, converter, dummy_input=dummy_input)
else:
base_model_ir = convert_to_graph(script_module, base_model)
# handle inline mutations
......@@ -182,7 +182,7 @@ class RetiariiExperiment(Experiment):
def _start_strategy(self):
base_model_ir, self.applied_mutators = preprocess_model(
self.base_model, self.trainer, self.applied_mutators, full_ir=self.config.execution_engine != 'py',
example_inputs=self.config.example_inputs)
dummy_input=self.config.dummy_input)
_logger.info('Start strategy...')
self.strategy.run(base_model_ir, self.applied_mutators)
......
......@@ -15,5 +15,5 @@ class ConvertWithShapeMixin:
@staticmethod
def _convert_model(model, input):
script_module = torch.jit.script(model)
model_ir = convert_to_graph(script_module, model, converter=GraphConverterWithShape(), example_inputs=input)
model_ir = convert_to_graph(script_module, model, converter=GraphConverterWithShape(), dummy_input=input)
return model_ir
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment