Commit e992359c authored by Hang Zhang, committed by Facebook GitHub Bot
Browse files

Integrating model profiler into d2go NAS task

Reviewed By: larryliu0820

Differential Revision: D30390706

fbshipit-source-id: 49f83f884f497df227448f7e59903bd1bd6e5484
parent 81328bf2
...@@ -88,21 +88,12 @@ class PredictorExportConfig(NamedTuple): ...@@ -88,21 +88,12 @@ class PredictorExportConfig(NamedTuple):
run_func_info: FuncInfo = FuncInfo.gen_func_info(NaiveRunFunc, params={}) run_func_info: FuncInfo = FuncInfo.gen_func_info(NaiveRunFunc, params={})
def convert_and_export_predictor( def convert_predictor(
cfg, cfg,
pytorch_model, pytorch_model,
predictor_type, predictor_type,
output_dir,
data_loader, data_loader,
): ):
"""
Entry point for convert and export model. This involves two steps:
- convert: converting the given `pytorch_model` to another format, currently
mainly for quantizing the model.
- export: exporting the converted `pytorch_model` to predictor. This step
should not alter the behaviour of model.
"""
if "int8" in predictor_type: if "int8" in predictor_type:
if not cfg.QUANTIZATION.QAT.ENABLED: if not cfg.QUANTIZATION.QAT.ENABLED:
logger.info( logger.info(
...@@ -131,7 +122,24 @@ def convert_and_export_predictor( ...@@ -131,7 +122,24 @@ def convert_and_export_predictor(
logger.info("Fused Model:\n{}".format(pytorch_model)) logger.info("Fused Model:\n{}".format(pytorch_model))
if fuse_utils.count_bn_exist(pytorch_model) > 0: if fuse_utils.count_bn_exist(pytorch_model) > 0:
logger.warning("BN existed in pytorch model after fusing.") logger.warning("BN existed in pytorch model after fusing.")
return pytorch_model
def convert_and_export_predictor(
    cfg,
    pytorch_model,
    predictor_type,
    output_dir,
    data_loader,
):
    """
    Entry point for turning a pytorch model into an exported predictor.

    Runs two stages in order:
      1. convert: transform the given `pytorch_model` into another format
         (currently used mainly for quantizing the model).
      2. export: write the converted model out as a predictor under
         `output_dir`; this stage must not alter the model's behaviour.
    """
    # Conversion may hand back a different module (e.g. an int8-quantized
    # one), so rebind before exporting.
    converted_model = convert_predictor(cfg, pytorch_model, predictor_type, data_loader)
    return export_predictor(cfg, converted_model, predictor_type, output_dir, data_loader)
......
...@@ -140,6 +140,19 @@ def tracing_adapter_wrap_export(old_f): ...@@ -140,6 +140,19 @@ def tracing_adapter_wrap_export(old_f):
return new_f return new_f
class TracingAdapterModelWrapper(nn.Module):
    """
    Thin nn.Module that re-attaches I/O schemas to a traced model.

    The wrapped `traced_model` consumes a flattened tuple of tensors; this
    wrapper flattens the caller's positional arguments before the call and
    uses `outputs_schema` to rebuild structured outputs afterwards.
    """

    def __init__(self, traced_model, inputs_schema, outputs_schema):
        super().__init__()
        self.traced_model = traced_model
        self.inputs_schema = inputs_schema
        self.outputs_schema = outputs_schema

    def forward(self, *input_args):
        # NOTE(review): `self.inputs_schema` is stored but never applied here;
        # inputs are flattened generically via flatten_to_tuple — confirm
        # that is intentional.
        flat_args, _ = flatten_to_tuple(input_args)
        raw_outputs = self.traced_model(*flat_args)
        return self.outputs_schema(raw_outputs)
def tracing_adapter_wrap_load(old_f): def tracing_adapter_wrap_load(old_f):
def new_f(cls, save_path, **load_kwargs): def new_f(cls, save_path, **load_kwargs):
assert "inputs_schema" in load_kwargs, load_kwargs.keys() assert "inputs_schema" in load_kwargs, load_kwargs.keys()
...@@ -148,18 +161,6 @@ def tracing_adapter_wrap_load(old_f): ...@@ -148,18 +161,6 @@ def tracing_adapter_wrap_load(old_f):
outputs_schema = instantiate(load_kwargs.pop("outputs_schema")) outputs_schema = instantiate(load_kwargs.pop("outputs_schema"))
traced_model = old_f(cls, save_path, **load_kwargs) traced_model = old_f(cls, save_path, **load_kwargs)
class TracingAdapterModelWrapper(nn.Module):
def __init__(self, traced_model, inputs_schema, outputs_schema):
super().__init__()
self.traced_model = traced_model
self.inputs_schema = inputs_schema
self.outputs_schema = outputs_schema
def forward(self, *input_args):
flattened_inputs, _ = flatten_to_tuple(input_args)
flattened_outputs = self.traced_model(*flattened_inputs)
return self.outputs_schema(flattened_outputs)
return TracingAdapterModelWrapper(traced_model, inputs_schema, outputs_schema) return TracingAdapterModelWrapper(traced_model, inputs_schema, outputs_schema)
return new_f return new_f
......
...@@ -62,11 +62,16 @@ class GeneralizedRCNNPatch: ...@@ -62,11 +62,16 @@ class GeneralizedRCNNPatch:
def default_rcnn_prepare_for_export(self, cfg, inputs, predictor_type): def default_rcnn_prepare_for_export(self, cfg, inputs, predictor_type):
if "torchscript" in predictor_type and "@tracing" in predictor_type: if "torchscript" in predictor_type and "@tracing" in predictor_type:
preprocess_info = FuncInfo.gen_func_info(
D2RCNNTracingWrapper.Preprocess, params={}
)
preprocess_func = preprocess_info.instantiate()
return PredictorExportConfig( return PredictorExportConfig(
model=D2RCNNTracingWrapper(self), model=D2RCNNTracingWrapper(self),
data_generator=D2RCNNTracingWrapper.generator_trace_inputs, data_generator=lambda x: (preprocess_func(x),),
run_func_info=FuncInfo.gen_func_info( preprocess_info=preprocess_info,
D2RCNNTracingWrapper.RunFunc, params={} postprocess_info=FuncInfo.gen_func_info(
D2RCNNTracingWrapper.Postprocess, params={}
), ),
) )
......
...@@ -100,15 +100,18 @@ class D2RCNNTracingWrapper(nn.Module): ...@@ -100,15 +100,18 @@ class D2RCNNTracingWrapper(nn.Module):
return self.model.inference(inputs, do_postprocess=False)[0] return self.model.inference(inputs, do_postprocess=False)[0]
class Preprocess(object):
    """
    Convert one original input (as produced by the data loader) into the
    input used during tracing, i.e. the argument of the traced forward.
    """

    def __call__(self, batch):
        # Tracing fixes a single example, so reject multi-sample batches.
        assert len(batch) == 1, "only support single batch"
        return batch[0]["image"]
class Postprocess(object):
    """
    Turn raw traced-model outputs back into the predictor's result format.

    Rescales detections to the original image size recorded in the batch
    and wraps them in the standard `[{"instances": ...}]` structure.
    """

    def __call__(self, batch, inputs, outputs):
        assert len(batch) == 1, "only support single batch"
        width, height = batch[0]["width"], batch[0]["height"]
        # detector_postprocess rescales boxes/masks from the model's working
        # resolution back to the original (height, width) of the image.
        rescaled = detector_postprocess(outputs, height, width)
        return [{"instances": rescaled}]
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment