Commit e992359c authored by Hang Zhang's avatar Hang Zhang Committed by Facebook GitHub Bot
Browse files

Integrating model profiler into d2go NAS task

Reviewed By: larryliu0820

Differential Revision: D30390706

fbshipit-source-id: 49f83f884f497df227448f7e59903bd1bd6e5484
parent 81328bf2
......@@ -88,21 +88,12 @@ class PredictorExportConfig(NamedTuple):
run_func_info: FuncInfo = FuncInfo.gen_func_info(NaiveRunFunc, params={})
def convert_and_export_predictor(
def convert_predictor(
cfg,
pytorch_model,
predictor_type,
output_dir,
data_loader,
):
"""
Entry point for convert and export model. This involves two steps:
- convert: converting the given `pytorch_model` to another format, currently
mainly for quantizing the model.
- export: exporting the converted `pytorch_model` to predictor. This step
should not alter the behaviour of model.
"""
if "int8" in predictor_type:
if not cfg.QUANTIZATION.QAT.ENABLED:
logger.info(
......@@ -131,7 +122,24 @@ def convert_and_export_predictor(
logger.info("Fused Model:\n{}".format(pytorch_model))
if fuse_utils.count_bn_exist(pytorch_model) > 0:
logger.warning("BN existed in pytorch model after fusing.")
return pytorch_model
def convert_and_export_predictor(
cfg,
pytorch_model,
predictor_type,
output_dir,
data_loader,
):
"""
Entry point for convert and export model. This involves two steps:
- convert: converting the given `pytorch_model` to another format, currently
mainly for quantizing the model.
- export: exporting the converted `pytorch_model` to predictor. This step
should not alter the behaviour of model.
"""
pytorch_model = convert_predictor(cfg, pytorch_model, predictor_type, data_loader)
return export_predictor(cfg, pytorch_model, predictor_type, output_dir, data_loader)
......
......@@ -140,6 +140,19 @@ def tracing_adapter_wrap_export(old_f):
return new_f
class TracingAdapterModelWrapper(nn.Module):
def __init__(self, traced_model, inputs_schema, outputs_schema):
super().__init__()
self.traced_model = traced_model
self.inputs_schema = inputs_schema
self.outputs_schema = outputs_schema
def forward(self, *input_args):
flattened_inputs, _ = flatten_to_tuple(input_args)
flattened_outputs = self.traced_model(*flattened_inputs)
return self.outputs_schema(flattened_outputs)
def tracing_adapter_wrap_load(old_f):
def new_f(cls, save_path, **load_kwargs):
assert "inputs_schema" in load_kwargs, load_kwargs.keys()
......@@ -148,18 +161,6 @@ def tracing_adapter_wrap_load(old_f):
outputs_schema = instantiate(load_kwargs.pop("outputs_schema"))
traced_model = old_f(cls, save_path, **load_kwargs)
class TracingAdapterModelWrapper(nn.Module):
def __init__(self, traced_model, inputs_schema, outputs_schema):
super().__init__()
self.traced_model = traced_model
self.inputs_schema = inputs_schema
self.outputs_schema = outputs_schema
def forward(self, *input_args):
flattened_inputs, _ = flatten_to_tuple(input_args)
flattened_outputs = self.traced_model(*flattened_inputs)
return self.outputs_schema(flattened_outputs)
return TracingAdapterModelWrapper(traced_model, inputs_schema, outputs_schema)
return new_f
......
......@@ -62,11 +62,16 @@ class GeneralizedRCNNPatch:
def default_rcnn_prepare_for_export(self, cfg, inputs, predictor_type):
if "torchscript" in predictor_type and "@tracing" in predictor_type:
preprocess_info = FuncInfo.gen_func_info(
D2RCNNTracingWrapper.Preprocess, params={}
)
preprocess_func = preprocess_info.instantiate()
return PredictorExportConfig(
model=D2RCNNTracingWrapper(self),
data_generator=D2RCNNTracingWrapper.generator_trace_inputs,
run_func_info=FuncInfo.gen_func_info(
D2RCNNTracingWrapper.RunFunc, params={}
data_generator=lambda x: (preprocess_func(x),),
preprocess_info=preprocess_info,
postprocess_info=FuncInfo.gen_func_info(
D2RCNNTracingWrapper.Postprocess, params={}
),
)
......
......@@ -100,15 +100,18 @@ class D2RCNNTracingWrapper(nn.Module):
return self.model.inference(inputs, do_postprocess=False)[0]
@staticmethod
def generator_trace_inputs(batch):
class Preprocess(object):
"""
This function describes how to covert orginal input (from the data loader)
to the inputs used during the tracing (i.e. the inputs of forward function).
"""
return (batch[0]["image"],)
class RunFunc(object):
def __call__(self, tracing_adapter_wrapper, batch):
def __call__(self, batch):
assert len(batch) == 1, "only support single batch"
return batch[0]["image"]
class Postprocess(object):
def __call__(self, batch, inputs, outputs):
"""
This function describes how to run the predictor using exported model. Note
that `tracing_adapter_wrapper` runs the traced model under the hood and
......@@ -116,7 +119,5 @@ class D2RCNNTracingWrapper(nn.Module):
"""
assert len(batch) == 1, "only support single batch"
width, height = batch[0]["width"], batch[0]["height"]
inputs = D2RCNNTracingWrapper.generator_trace_inputs(batch)
results_per_image = tracing_adapter_wrapper(inputs)
r = detector_postprocess(results_per_image, height, width)
r = detector_postprocess(outputs, height, width)
return [{"instances": r}]
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment