Commit 8adb146e authored by Yanghan Wang's avatar Yanghan Wang Committed by Facebook GitHub Bot
Browse files

don't register @legacy as part of export method name

Summary: Pull Request resolved: https://github.com/facebookresearch/d2go/pull/119

Reviewed By: zhanghang1989

Differential Revision: D31181216

fbshipit-source-id: 428116f4f4144e20410222825a9a00f75253ef4a
parent 1ce9e124
...@@ -136,7 +136,15 @@ def _is_data_flattened_tensors(data): ...@@ -136,7 +136,15 @@ def _is_data_flattened_tensors(data):
def tracing_adapter_wrap_export(old_f): def tracing_adapter_wrap_export(old_f):
def new_f(cls, model, input_args, *args, **kwargs): def new_f(cls, model, input_args, save_path, export_method, **export_kwargs):
force_disable_tracing_adapter = export_kwargs.pop(
"force_disable_tracing_adapter", False
)
if force_disable_tracing_adapter:
logger.info("Not trace mode, export normally")
return old_f(
cls, model, input_args, save_path, export_method, **export_kwargs
)
if _is_data_flattened_tensors(input_args): if _is_data_flattened_tensors(input_args):
# TODO: only dry-run for traceing # TODO: only dry-run for traceing
...@@ -146,7 +154,9 @@ def tracing_adapter_wrap_export(old_f): ...@@ -146,7 +154,9 @@ def tracing_adapter_wrap_export(old_f):
logger.info( logger.info(
"Both inputs and outputs are flattened tensors, export the model as is." "Both inputs and outputs are flattened tensors, export the model as is."
) )
load_kwargs = old_f(cls, model, input_args, *args, **kwargs) load_kwargs = old_f(
cls, model, input_args, save_path, export_method, **export_kwargs
)
assert "tracing_adapted" not in load_kwargs assert "tracing_adapted" not in load_kwargs
load_kwargs.update({"tracing_adapted": False}) load_kwargs.update({"tracing_adapted": False})
return load_kwargs return load_kwargs
...@@ -162,7 +172,14 @@ def tracing_adapter_wrap_export(old_f): ...@@ -162,7 +172,14 @@ def tracing_adapter_wrap_export(old_f):
" please be aware that the exported model will have different input/output data structure." " please be aware that the exported model will have different input/output data structure."
) )
adapter = TracingAdapter(model, input_args) adapter = TracingAdapter(model, input_args)
load_kwargs = old_f(cls, adapter, adapter.flattened_inputs, *args, **kwargs) load_kwargs = old_f(
cls,
adapter,
adapter.flattened_inputs,
save_path,
export_method,
**export_kwargs
)
inputs_schema = dump_dataclass(adapter.inputs_schema) inputs_schema = dump_dataclass(adapter.inputs_schema)
outputs_schema = dump_dataclass(adapter.outputs_schema) outputs_schema = dump_dataclass(adapter.outputs_schema)
assert "tracing_adapted" not in load_kwargs assert "tracing_adapted" not in load_kwargs
...@@ -214,10 +231,6 @@ def tracing_adapter_wrap_load(old_f): ...@@ -214,10 +231,6 @@ def tracing_adapter_wrap_load(old_f):
return new_f return new_f
@ModelExportMethodRegistry.register("torchscript@legacy")
@ModelExportMethodRegistry.register("torchscript_int8@legacy")
@ModelExportMethodRegistry.register("torchscript_mobile@legacy")
@ModelExportMethodRegistry.register("torchscript_mobile_int8@legacy")
class DefaultTorchscriptExport(ModelExportMethod): class DefaultTorchscriptExport(ModelExportMethod):
@classmethod @classmethod
def export( def export(
......
...@@ -80,12 +80,17 @@ def default_rcnn_prepare_for_export(self, cfg, inputs, predictor_type): ...@@ -80,12 +80,17 @@ def default_rcnn_prepare_for_export(self, cfg, inputs, predictor_type):
) )
preprocess_func = preprocess_info.instantiate() preprocess_func = preprocess_info.instantiate()
model_export_kwargs = {}
if "torchscript" in predictor_type:
model_export_kwargs["force_disable_tracing_adapter"] = True
return PredictorExportConfig( return PredictorExportConfig(
model=c2_compatible_model, model=c2_compatible_model,
# Caffe2MetaArch takes a single tuple as input (which is the return of # Caffe2MetaArch takes a single tuple as input (which is the return of
# preprocess_func), data_generator requires all positional args as a tuple. # preprocess_func), data_generator requires all positional args as a tuple.
data_generator=lambda x: (preprocess_func(x),), data_generator=lambda x: (preprocess_func(x),),
model_export_method=predictor_type.replace("@legacy", "", 1),
model_export_kwargs=model_export_kwargs,
preprocess_info=preprocess_info, preprocess_info=preprocess_info,
postprocess_info=postprocess_info, postprocess_info=postprocess_info,
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment