Commit 8adb146e authored by Yanghan Wang, committed by Facebook GitHub Bot

don't register @legacy as part of export method name

Summary: Pull Request resolved: https://github.com/facebookresearch/d2go/pull/119

Reviewed By: zhanghang1989

Differential Revision: D31181216

fbshipit-source-id: 428116f4f4144e20410222825a9a00f75253ef4a
parent 1ce9e124
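
In effect, "@legacy" becomes a tag on the predictor type rather than part of the export method name: the registry below drops its dedicated "*@legacy" entries, and the caller strips the tag before choosing the export method. A standalone sketch of the renaming idiom the change relies on (the predictor type values are illustrative):

    # Only the "@legacy" tag is stripped; the base export method name is kept.
    for predictor_type in ["torchscript@legacy", "torchscript_mobile_int8@legacy", "torchscript"]:
        print(predictor_type, "->", predictor_type.replace("@legacy", "", 1))
    # torchscript@legacy -> torchscript
    # torchscript_mobile_int8@legacy -> torchscript_mobile_int8
    # torchscript -> torchscript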
@@ -136,7 +136,15 @@ def _is_data_flattened_tensors(data):
 def tracing_adapter_wrap_export(old_f):
-    def new_f(cls, model, input_args, *args, **kwargs):
+    def new_f(cls, model, input_args, save_path, export_method, **export_kwargs):
+        force_disable_tracing_adapter = export_kwargs.pop(
+            "force_disable_tracing_adapter", False
+        )
+        if force_disable_tracing_adapter:
+            logger.info("Not trace mode, export normally")
+            return old_f(
+                cls, model, input_args, save_path, export_method, **export_kwargs
+            )
         if _is_data_flattened_tensors(input_args):
             # TODO: only dry-run for traceing
@@ -146,7 +154,9 @@ def tracing_adapter_wrap_export(old_f):
                 logger.info(
                     "Both inputs and outputs are flattened tensors, export the model as is."
                 )
-                load_kwargs = old_f(cls, model, input_args, *args, **kwargs)
+                load_kwargs = old_f(
+                    cls, model, input_args, save_path, export_method, **export_kwargs
+                )
                 assert "tracing_adapted" not in load_kwargs
                 load_kwargs.update({"tracing_adapted": False})
                 return load_kwargs
@@ -162,7 +172,14 @@ def tracing_adapter_wrap_export(old_f):
             " please be aware that the exported model will have different input/output data structure."
         )
         adapter = TracingAdapter(model, input_args)
-        load_kwargs = old_f(cls, adapter, adapter.flattened_inputs, *args, **kwargs)
+        load_kwargs = old_f(
+            cls,
+            adapter,
+            adapter.flattened_inputs,
+            save_path,
+            export_method,
+            **export_kwargs
+        )
         inputs_schema = dump_dataclass(adapter.inputs_schema)
         outputs_schema = dump_dataclass(adapter.outputs_schema)
         assert "tracing_adapted" not in load_kwargs
@@ -214,10 +231,6 @@ def tracing_adapter_wrap_load(old_f):
     return new_f
-@ModelExportMethodRegistry.register("torchscript@legacy")
-@ModelExportMethodRegistry.register("torchscript_int8@legacy")
-@ModelExportMethodRegistry.register("torchscript_mobile@legacy")
-@ModelExportMethodRegistry.register("torchscript_mobile_int8@legacy")
 class DefaultTorchscriptExport(ModelExportMethod):
     @classmethod
     def export(
@@ -80,12 +80,17 @@ def default_rcnn_prepare_for_export(self, cfg, inputs, predictor_type):
     )
     preprocess_func = preprocess_info.instantiate()
+    model_export_kwargs = {}
+    if "torchscript" in predictor_type:
+        model_export_kwargs["force_disable_tracing_adapter"] = True
     return PredictorExportConfig(
         model=c2_compatible_model,
         # Caffe2MetaArch takes a single tuple as input (which is the return of
         # preprocess_func), data_generator requires all positional args as a tuple.
         data_generator=lambda x: (preprocess_func(x),),
+        model_export_method=predictor_type.replace("@legacy", "", 1),
+        model_export_kwargs=model_export_kwargs,
         preprocess_info=preprocess_info,
         postprocess_info=postprocess_info,
     )
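
On the caller side, this is what now reaches PredictorExportConfig for a legacy Caffe2-compatible predictor: the "@legacy" tag is stripped from the export method name, and TracingAdapter wrapping is disabled for torchscript types, presumably because Caffe2MetaArch already consumes the single tuple produced by preprocess_func. An illustrative computation for a hypothetical predictor_type:

    predictor_type = "torchscript_int8@legacy"  # hypothetical input
    model_export_kwargs = {}
    if "torchscript" in predictor_type:
        model_export_kwargs["force_disable_tracing_adapter"] = True
    model_export_method = predictor_type.replace("@legacy", "", 1)
    print(model_export_method, model_export_kwargs)
    # torchscript_int8 {'force_disable_tracing_adapter': True}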