Commit 3ce16d73 authored by Yanghan Wang, committed by Facebook GitHub Bot

only wrap model with TracingAdapter when necessary

Summary:
D31134064 changes the default ExportMethod from `DefaultTorchscriptExport` to `D2TorchscriptTracingExport` (https://github.com/facebookresearch/d2go/commit/7992f91324aee6ae59795063a007c6837e60cdb8) for all models. Without this change, all models would be wrapped with `TracingAdapter`, which might cause unexpected effects (e.g. it's not scripting friendly).

This diff adds a check on the input/output data structure and only wraps the model when necessary.
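
For illustration only (not part of this diff): a minimal sketch of what `TracingAdapter` wrapping entails, assuming detectron2's `TracingAdapter`; the model and inputs here are hypothetical:

import torch
from detectron2.export.flatten import TracingAdapter

class ReturnsDict(torch.nn.Module):
    def forward(self, x):
        # Dict outputs are not flattened tensors, so plain torch.jit.trace
        # would not preserve this output structure.
        return {"out": x * 2}

model = ReturnsDict()
inputs = (torch.rand(2, 3),)

# TracingAdapter flattens inputs/outputs into tuples of tensors; the original
# structure is recorded in inputs_schema / outputs_schema.
adapter = TracingAdapter(model, inputs)
traced = torch.jit.trace(adapter, adapter.flattened_inputs)

flat_out = traced(*adapter.flattened_inputs)
structured_out = adapter.outputs_schema(flat_out)  # back to {"out": ...}

This is why the exported artifact ends up with a different input/output data structure, and why this diff skips the wrapping when both sides are already flattened tensors.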

Reviewed By: zhanghang1989

Differential Revision: D31136261

fbshipit-source-id: 4a8ffc986a5c5d61c493dd4ba0eb185aa0d54f38
parent dac9a358
@@ -124,16 +124,56 @@ def load_torchscript(model_path):
    return TorchscriptWrapper(ts)


def _is_data_flattened_tensors(data):
    if isinstance(data, torch.Tensor):
        return True

    if isinstance(data, (tuple, list)):
        if all(isinstance(x, torch.Tensor) for x in data):
            return True

    return False
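
# Illustration, not part of the diff: what the check above accepts.
#   _is_data_flattened_tensors(torch.rand(2))                   -> True
#   _is_data_flattened_tensors([torch.rand(2), torch.rand(3)])  -> True
#   _is_data_flattened_tensors({"im": torch.rand(2)})           -> False (dict)
#   _is_data_flattened_tensors((torch.rand(2), "label"))        -> False (mixed types)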
def tracing_adapter_wrap_export(old_f):
    def new_f(cls, model, input_args, *args, **kwargs):
        if _is_data_flattened_tensors(input_args):
            # TODO: only dry-run for tracing
            logger.info("Dry run the model to check if TracingAdapter is needed ...")
            outputs = model(*input_args)
            if _is_data_flattened_tensors(outputs):
                logger.info(
                    "Both inputs and outputs are flattened tensors, export the model as is."
                )
                load_kwargs = old_f(cls, model, input_args, *args, **kwargs)
                assert "tracing_adapted" not in load_kwargs
                load_kwargs.update({"tracing_adapted": False})
                return load_kwargs
            else:
                logger.info(
                    "The outputs are not flattened tensors, can't trace normally."
                )
        else:
            logger.info("The inputs are not flattened tensors, can't trace normally.")

        logger.warning(
            "Wrap the model with TracingAdapter to handle non-flattened inputs/outputs,"
            " please be aware that the exported model will have different input/output data structure."
        )
        adapter = TracingAdapter(model, input_args)
        load_kwargs = old_f(cls, adapter, adapter.flattened_inputs, *args, **kwargs)
        inputs_schema = dump_dataclass(adapter.inputs_schema)
        outputs_schema = dump_dataclass(adapter.outputs_schema)
        assert "tracing_adapted" not in load_kwargs
        assert "inputs_schema" not in load_kwargs
        assert "outputs_schema" not in load_kwargs
        load_kwargs.update(
            {
                "tracing_adapted": True,
                "inputs_schema": inputs_schema,
                "outputs_schema": outputs_schema,
            }
        )
        return load_kwargs
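
A hypothetical usage sketch (class and argument names assumed; the decorator's `return new_f` line and the real decoration site are outside the shown hunk) of how an export method could be decorated:

class MyTracingExport:
    @classmethod
    @tracing_adapter_wrap_export
    def export(cls, model, input_args, save_path, **kwargs):
        # By the time this runs, `model` is either the original module or a
        # TracingAdapter, and `input_args` is traceable flattened tensors.
        ts = torch.jit.trace(model, input_args)
        ts.save(save_path)
        return {}  # extra load_kwargs; the wrapper adds the schema keys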
@@ -155,6 +195,14 @@ class TracingAdapterModelWrapper(nn.Module):


def tracing_adapter_wrap_load(old_f):
    def new_f(cls, save_path, **load_kwargs):
        tracing_adapted = load_kwargs.pop("tracing_adapted", False)
        if not tracing_adapted:
            logger.info("The model is not tracing adapted, load it normally.")
            return old_f(cls, save_path, **load_kwargs)

        logger.info(
            "The model is tracing adapted, load the schema and wrap the model for inference."
        )
        assert "inputs_schema" in load_kwargs, load_kwargs.keys()
        assert "outputs_schema" in load_kwargs, load_kwargs.keys()
        inputs_schema = instantiate(load_kwargs.pop("inputs_schema"))
......
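
The hunk is truncated above; presumably the remainder instantiates `outputs_schema` and wraps the loaded TorchScript module in `TracingAdapterModelWrapper`. A minimal sketch of that idea (an assumed stand-in, not the actual class body), using detectron2's `flatten_to_tuple`:

import torch.nn as nn
from detectron2.export.flatten import flatten_to_tuple

class SchemaModelWrapper(nn.Module):
    # Hypothetical stand-in for TracingAdapterModelWrapper: flatten structured
    # inputs, run the traced module on tensors, rebuild structured outputs.
    def __init__(self, traced_model, inputs_schema, outputs_schema):
        super().__init__()
        self.traced_model = traced_model
        self.inputs_schema = inputs_schema
        self.outputs_schema = outputs_schema

    def forward(self, *input_args):
        flattened_inputs, _ = flatten_to_tuple(input_args)
        flattened_outputs = self.traced_model(*flattened_inputs)
        return self.outputs_schema(flattened_outputs)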