Commit c6c089a6 authored by Yanghan Wang's avatar Yanghan Wang Committed by Facebook GitHub Bot
Browse files

support exporting multiple sub models

Summary: model to export is either `nn.Module` and or dict of `nn.Module`

Reviewed By: zhanghang1989

Differential Revision: D27835097

fbshipit-source-id: 869446b36d3e8cc30d6d947f1fc8970cc9ba6c12
parent a3f4276c
...@@ -27,7 +27,7 @@ NOTE: ...@@ -27,7 +27,7 @@ NOTE:
import json import json
import logging import logging
import os import os
from typing import Any, Callable, Dict, NamedTuple, Optional, Union from typing import Callable, Dict, NamedTuple, Optional, Union
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -69,12 +69,10 @@ class PredictorExportConfig(NamedTuple): ...@@ -69,12 +69,10 @@ class PredictorExportConfig(NamedTuple):
run_func_info (FuncInfo): info for predictor's run_fun run_func_info (FuncInfo): info for predictor's run_fun
""" """
model: Union[nn.Module, Any] model: Union[nn.Module, Dict[str, nn.Module]]
# Shall we save data_generator in the predictor? This might be necessary when data
# is needed, eg. running benchmark for sub models
data_generator: Optional[Callable] = None data_generator: Optional[Callable] = None
model_export_method: Optional[str] = None model_export_method: Optional[Union[str, Dict[str, str]]] = None
model_export_kwargs: Optional[Union[Dict, Any]] = None model_export_kwargs: Optional[Union[Dict, Dict[str, Dict]]] = None
preprocess_info: FuncInfo = FuncInfo.gen_func_info(IdentityPreprocess, params={}) preprocess_info: FuncInfo = FuncInfo.gen_func_info(IdentityPreprocess, params={})
postprocess_info: FuncInfo = FuncInfo.gen_func_info(IdentityPostprocess, params={}) postprocess_info: FuncInfo = FuncInfo.gen_func_info(IdentityPostprocess, params={})
...@@ -152,6 +150,27 @@ def export_predictor(cfg, pytorch_model, predictor_type, output_dir, data_loader ...@@ -152,6 +150,27 @@ def export_predictor(cfg, pytorch_model, predictor_type, output_dir, data_loader
) )
def _export_single_model(
predictor_path,
model,
input_args,
save_path,
model_export_method,
model_export_kwargs,
predictor_type, # TODO: remove this after refactoring ModelInfo
):
assert isinstance(model, nn.Module), model
load_kwargs = ModelExportMethodRegistry.get(model_export_method).export(
model=model,
input_args=input_args,
save_path=save_path,
**model_export_kwargs,
)
assert isinstance(load_kwargs, dict) # TODO: save this in predictor_info
model_rel_path = os.path.relpath(save_path, predictor_path)
return ModelInfo(path=model_rel_path, type=predictor_type)
def default_export_predictor( def default_export_predictor(
cfg, pytorch_model, predictor_type, output_dir, data_loader cfg, pytorch_model, predictor_type, output_dir, data_loader
): ):
...@@ -163,40 +182,59 @@ def default_export_predictor( ...@@ -163,40 +182,59 @@ def default_export_predictor(
export_config = pytorch_model.prepare_for_export( export_config = pytorch_model.prepare_for_export(
cfg, inputs, export_scheme=predictor_type cfg, inputs, export_scheme=predictor_type
) )
model_inputs = (
predictor_path = os.path.join(output_dir, predictor_type)
PathManager.mkdirs(predictor_path)
# TODO: also support multiple models from nested dict in the default implementation
assert isinstance(export_config.model, nn.Module), "Currently support single model"
model = export_config.model
input_args = (
export_config.data_generator(inputs) export_config.data_generator(inputs)
if export_config.data_generator is not None if export_config.data_generator is not None
else None else None
) )
model_export_method = export_config.model_export_method or predictor_type
model_export_kwargs = export_config.model_export_kwargs or {}
# the default implementation assumes model type is the same as the predictor type
model_type = predictor_type
model_path = predictor_path # might be sub dir for multiple models
load_kwargs = ModelExportMethodRegistry.get(model_export_method).export( predictor_path = os.path.join(output_dir, predictor_type)
model=model, PathManager.mkdirs(predictor_path)
input_args=input_args,
save_path=model_path, predictor_init_kwargs = {
**model_export_kwargs, "preprocess_info": export_config.preprocess_info,
) "postprocess_info": export_config.postprocess_info,
assert isinstance(load_kwargs, dict) # TODO: save this in predictor_info "run_func_info": export_config.run_func_info,
model_rel_path = os.path.relpath(model_path, predictor_path) }
if isinstance(export_config.model, dict):
models_info = {}
for name, model in export_config.model.items():
save_path = os.path.join(predictor_path, name)
model_info = _export_single_model(
predictor_path=predictor_path,
model=model,
input_args=model_inputs[name] if model_inputs is not None else None,
save_path=save_path,
model_export_method=(
predictor_type
if export_config.model_export_method is None
else export_config.model_export_method[name]
),
model_export_kwargs=(
{}
if export_config.model_export_kwargs is None
else export_config.model_export_kwargs[name]
),
predictor_type=predictor_type,
)
models_info[name] = model_info
predictor_init_kwargs["models"] = models_info
else:
save_path = predictor_path # for single model exported files are put under `predictor_path` together with predictor_info.json
model_info = _export_single_model(
predictor_path=predictor_path,
model=export_config.model,
input_args=model_inputs,
save_path=save_path,
model_export_method=export_config.model_export_method or predictor_type,
model_export_kwargs=export_config.model_export_kwargs or {},
predictor_type=predictor_type,
)
predictor_init_kwargs["model"] = model_info
# assemble predictor # assemble predictor
predictor_info = PredictorInfo( predictor_info = PredictorInfo(**predictor_init_kwargs)
model=ModelInfo(path=model_rel_path, type=model_type),
preprocess_info=export_config.preprocess_info,
postprocess_info=export_config.postprocess_info,
run_func_info=export_config.run_func_info,
)
with PathManager.open( with PathManager.open(
os.path.join(predictor_path, "predictor_info.json"), "w" os.path.join(predictor_path, "predictor_info.json"), "w"
) as f: ) as f:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment