Commit c6c089a6 authored by Yanghan Wang, committed by Facebook GitHub Bot

support exporting multiple sub models

Summary: the model to export is either a single `nn.Module` or a dict of `nn.Module`s keyed by sub-model name.
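For illustration, a config using the new dict form might look like the sketch below; the sub-model names "trunk"/"head", the modules, and the "torchscript" method name are assumptions for the example, not part of this commit:

```python
# Hedged sketch: trunk_module / head_module are hypothetical nn.Module instances.
export_config = PredictorExportConfig(
    model={"trunk": trunk_module, "head": head_module},
    # when `model` is a dict, the generated inputs are looked up per sub-model name
    data_generator=lambda inputs: {"trunk": (inputs,), "head": (inputs,)},
    model_export_method={"trunk": "torchscript", "head": "torchscript"},
    model_export_kwargs={"trunk": {}, "head": {}},
)
```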

Reviewed By: zhanghang1989

Differential Revision: D27835097

fbshipit-source-id: 869446b36d3e8cc30d6d947f1fc8970cc9ba6c12
parent a3f4276c
@@ -27,7 +27,7 @@ NOTE:
 import json
 import logging
 import os
-from typing import Any, Callable, Dict, NamedTuple, Optional, Union
+from typing import Callable, Dict, NamedTuple, Optional, Union

 import torch
 import torch.nn as nn
@@ -69,12 +69,10 @@ class PredictorExportConfig(NamedTuple):
         run_func_info (FuncInfo): info for predictor's run_func
     """

-    model: Union[nn.Module, Any]
-    # Shall we save data_generator in the predictor? This might be necessary when data
-    # is needed, eg. running benchmark for sub models
+    model: Union[nn.Module, Dict[str, nn.Module]]
     data_generator: Optional[Callable] = None
-    model_export_method: Optional[str] = None
-    model_export_kwargs: Optional[Union[Dict, Any]] = None
+    model_export_method: Optional[Union[str, Dict[str, str]]] = None
+    model_export_kwargs: Optional[Union[Dict, Dict[str, Dict]]] = None
     preprocess_info: FuncInfo = FuncInfo.gen_func_info(IdentityPreprocess, params={})
     postprocess_info: FuncInfo = FuncInfo.gen_func_info(IdentityPostprocess, params={})
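When `model` is a dict, `model_export_method` and `model_export_kwargs` are expected to be dicts keyed by the same sub-model names; both remain optional, falling back to the predictor type and `{}` respectively. A condensed sketch of that per-name resolution, mirroring the branch in `default_export_predictor` further down this diff (variable names are illustrative):

```python
# Resolve export options for one sub-model named `name` (hedged sketch).
method = (
    predictor_type
    if export_config.model_export_method is None
    else export_config.model_export_method[name]
)
kwargs = (
    {}
    if export_config.model_export_kwargs is None
    else export_config.model_export_kwargs[name]
)
```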
@@ -152,6 +150,27 @@ def export_predictor(cfg, pytorch_model, predictor_type, output_dir, data_loader
     )


+def _export_single_model(
+    predictor_path,
+    model,
+    input_args,
+    save_path,
+    model_export_method,
+    model_export_kwargs,
+    predictor_type,  # TODO: remove this after refactoring ModelInfo
+):
+    assert isinstance(model, nn.Module), model
+    load_kwargs = ModelExportMethodRegistry.get(model_export_method).export(
+        model=model,
+        input_args=input_args,
+        save_path=save_path,
+        **model_export_kwargs,
+    )
+    assert isinstance(load_kwargs, dict)  # TODO: save this in predictor_info
+    model_rel_path = os.path.relpath(save_path, predictor_path)
+    return ModelInfo(path=model_rel_path, type=predictor_type)
+
+
 def default_export_predictor(
     cfg, pytorch_model, predictor_type, output_dir, data_loader
 ):
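To make the helper's contract concrete, here is a hedged walk-through with assumed values (paths, modules, and the registered method name are illustrative, not from this diff):

```python
# The returned ModelInfo stores the model path relative to the predictor directory.
info = _export_single_model(
    predictor_path="/out/torchscript",
    model=trunk_module,                 # hypothetical nn.Module
    input_args=(example_input,),        # hypothetical example inputs for tracing
    save_path="/out/torchscript/trunk",
    model_export_method="torchscript",  # assumes such a method is registered
    model_export_kwargs={},
    predictor_type="torchscript",
)
assert info.path == "trunk"   # os.path.relpath(save_path, predictor_path)
assert info.type == "torchscript"
```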
@@ -163,40 +182,59 @@ def default_export_predictor
     export_config = pytorch_model.prepare_for_export(
         cfg, inputs, export_scheme=predictor_type
     )
-    predictor_path = os.path.join(output_dir, predictor_type)
-    PathManager.mkdirs(predictor_path)
-
-    # TODO: also support multiple models from nested dict in the default implementation
-    assert isinstance(export_config.model, nn.Module), "Currently support single model"
-    model = export_config.model
-    input_args = (
+    model_inputs = (
         export_config.data_generator(inputs)
         if export_config.data_generator is not None
         else None
     )
-    model_export_method = export_config.model_export_method or predictor_type
-    model_export_kwargs = export_config.model_export_kwargs or {}
-
-    # the default implementation assumes model type is the same as the predictor type
-    model_type = predictor_type
-    model_path = predictor_path  # might be sub dir for multiple models
-    load_kwargs = ModelExportMethodRegistry.get(model_export_method).export(
-        model=model,
-        input_args=input_args,
-        save_path=model_path,
-        **model_export_kwargs,
-    )
-    assert isinstance(load_kwargs, dict)  # TODO: save this in predictor_info
-    model_rel_path = os.path.relpath(model_path, predictor_path)
+    predictor_path = os.path.join(output_dir, predictor_type)
+    PathManager.mkdirs(predictor_path)
+
+    predictor_init_kwargs = {
+        "preprocess_info": export_config.preprocess_info,
+        "postprocess_info": export_config.postprocess_info,
+        "run_func_info": export_config.run_func_info,
+    }
+
+    if isinstance(export_config.model, dict):
+        models_info = {}
+        for name, model in export_config.model.items():
+            save_path = os.path.join(predictor_path, name)
+            model_info = _export_single_model(
+                predictor_path=predictor_path,
+                model=model,
+                input_args=model_inputs[name] if model_inputs is not None else None,
+                save_path=save_path,
+                model_export_method=(
+                    predictor_type
+                    if export_config.model_export_method is None
+                    else export_config.model_export_method[name]
+                ),
+                model_export_kwargs=(
+                    {}
+                    if export_config.model_export_kwargs is None
+                    else export_config.model_export_kwargs[name]
+                ),
+                predictor_type=predictor_type,
+            )
+            models_info[name] = model_info
+        predictor_init_kwargs["models"] = models_info
+    else:
+        save_path = predictor_path  # for single model exported files are put under `predictor_path` together with predictor_info.json
+        model_info = _export_single_model(
+            predictor_path=predictor_path,
+            model=export_config.model,
+            input_args=model_inputs,
+            save_path=save_path,
+            model_export_method=export_config.model_export_method or predictor_type,
+            model_export_kwargs=export_config.model_export_kwargs or {},
+            predictor_type=predictor_type,
+        )
+        predictor_init_kwargs["model"] = model_info

     # assemble predictor
-    predictor_info = PredictorInfo(
-        model=ModelInfo(path=model_rel_path, type=model_type),
-        preprocess_info=export_config.preprocess_info,
-        postprocess_info=export_config.postprocess_info,
-        run_func_info=export_config.run_func_info,
-    )
+    predictor_info = PredictorInfo(**predictor_init_kwargs)
     with PathManager.open(
         os.path.join(predictor_path, "predictor_info.json"), "w"
     ) as f:
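With this change, a multi-model export writes one sub-directory per named sub-model, and `predictor_info.json` records them under a `models` mapping instead of the single `model` entry (single-model exports keep the previous flat layout). An illustrative layout, assuming sub-models "trunk" and "head" and `predictor_type="torchscript"`:

```
output_dir/
└── torchscript/              # predictor_path
    ├── predictor_info.json   # PredictorInfo with "models": {"trunk": ..., "head": ...}
    ├── trunk/                # save_path for sub-model "trunk"
    └── head/                 # save_path for sub-model "head"
```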