Commit 18dc1374 authored by Hang Zhang's avatar Hang Zhang Committed by Facebook GitHub Bot
Browse files

hide caffe2 related code from oss

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/59

* We have an internal dependency:
```
d2go/export/logfiledb.py", line 8, in <module>
    from mobile_cv.torch.utils_caffe2.ws_utils import ScopedWS
    ModuleNotFoundError: No module named 'mobile_cv.torch'
```
This causes the unittest failure on GitHub
https://github.com/facebookresearch/d2go/pull/58/checks?check_run_id=2471727763

* use python 3.8 because of another unittest failure on GitHub CI
```
from typing import final
ImportError: cannot import name 'final' from 'typing' (/usr/share/miniconda/lib/python3.7/typing.py)
```

Reviewed By: wat3rBro

Differential Revision: D28109444

fbshipit-source-id: 95e9774bdaa94f622267aeaac06d7448f37a103f
parent 95e1fa6e
......@@ -14,7 +14,7 @@ jobs:
- name: Set up Python
uses: s-weigand/setup-conda@v1
with:
python-version: 3.7
python-version: 3.8
- name: Install Dependencies
run: |
......
......@@ -2,5 +2,4 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
# enable registry
from . import caffe2 # noqa
from . import torchscript # noqa
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import logging
import os
from typing import Dict, Tuple
import torch
from d2go.export.api import ModelExportMethodRegistry, ModelExportMethod
from d2go.export.logfiledb import export_to_logfiledb
from detectron2.export.api import Caffe2Model
from detectron2.export.caffe2_export import (
export_caffe2_detection_model,
run_and_save_graph,
)
from torch import nn
logger = logging.getLogger(__name__)
def export_caffe2(
    caffe2_compatible_model: nn.Module,
    tensor_inputs: Tuple[str, torch.Tensor],
    output_dir: str,
    save_pb: bool = True,
    save_logdb: bool = False,
) -> Tuple[Caffe2Model, Dict[str, str]]:
    """Export a caffe2-traceable model to protobuf (and optionally logfiledb).

    Args:
        caffe2_compatible_model: model already converted to caffe2-compatible form.
        tensor_inputs: example inputs used to trace/export the model.
        output_dir: directory where all artifacts are written.
        save_pb: if True, save model.pb / model_init.pb and the graph svg.
        save_logdb: if True, additionally save a logfiledb; requires save_pb
            because the workspace blobs come from run_and_save_graph.

    Returns:
        Tuple of (Caffe2Model, dict mapping artifact names to file paths).

    Raises:
        ValueError: if save_logdb is True while save_pb is False.
    """
    predict_net, init_net = export_caffe2_detection_model(
        caffe2_compatible_model,
        # pyre-fixme[6]: Expected `List[torch.Tensor]` for 2nd param but got
        # `Tuple[str, torch.Tensor]`.
        tensor_inputs,
    )
    caffe2_model = Caffe2Model(predict_net, init_net)

    caffe2_export_paths = {}
    # ws_blobs is produced by run_and_save_graph below; previously this was
    # only bound under `if save_pb:`, causing a NameError when
    # save_logdb=True and save_pb=False.
    ws_blobs = None
    if save_pb:
        caffe2_model.save_protobuf(output_dir)
        caffe2_export_paths.update(
            {
                "predict_net_path": os.path.join(output_dir, "model.pb"),
                "init_net_path": os.path.join(output_dir, "model_init.pb"),
            }
        )

        graph_save_path = os.path.join(output_dir, "model_def.svg")
        ws_blobs = run_and_save_graph(
            predict_net,
            init_net,
            tensor_inputs,
            graph_save_path=graph_save_path,
        )
        caffe2_export_paths.update(
            {
                "model_def_path": graph_save_path,
            }
        )

    if save_logdb:
        if ws_blobs is None:
            raise ValueError(
                "save_logdb=True requires save_pb=True, since the logfiledb "
                "export needs the workspace blobs from run_and_save_graph"
            )
        logfiledb_path = os.path.join(output_dir, "model.logfiledb")
        export_to_logfiledb(predict_net, init_net, logfiledb_path, ws_blobs)
        caffe2_export_paths.update(
            {
                # the `if save_logdb else None` guard here was redundant: this
                # branch only runs when save_logdb is True
                "logfiledb_path": logfiledb_path,
            }
        )

    return caffe2_model, caffe2_export_paths
@ModelExportMethodRegistry.register("caffe2")
class DefaultCaffe2Export(ModelExportMethod):
    """Registry entry implementing export/load for the "caffe2" backend."""

    @classmethod
    def export(cls, model, input_args, save_path, export_method, **export_kwargs):
        # HACK: workaround the current caffe2 export API, which expects the
        # model to provide an `encode_additional_info` hook; install a no-op
        # one when the model does not define it.
        if not hasattr(model, "encode_additional_info"):
            model.encode_additional_info = lambda predict_net, init_net: None

        export_caffe2(model, input_args[0], save_path, **export_kwargs)
        return {}

    @classmethod
    def load(cls, save_path, **load_kwargs):
        # deferred import: mobile_cv is an optional dependency at import time
        from mobile_cv.predictor.model_wrappers import load_model

        return load_model(save_path, "caffe2")
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import logging
from functools import partial
from detectron2.export.caffe2_inference import ProtobufDetectionModel
from d2go.config import temp_defrost
logger = logging.getLogger(__name__)
def infer_mask_on(model: ProtobufDetectionModel):
    """Heuristically detect whether the protobuf net emits mask-head outputs."""
    # the real self.assembler should tell about this, currently use heuristic
    marker_blobs = {"mask_fcn_probs"}
    net_ops = model.protobuf_model.net.Proto().op
    return any(not marker_blobs.isdisjoint(op.output) for op in net_ops)
def infer_keypoint_on(model: ProtobufDetectionModel):
    """Heuristically detect whether the protobuf net emits keypoint outputs."""
    # the real self.assembler should tell about this, currently use heuristic
    marker_blobs = {"kps_score"}
    net_ops = model.protobuf_model.net.Proto().op
    return any(not marker_blobs.isdisjoint(op.output) for op in net_ops)
def infer_densepose_on(model: ProtobufDetectionModel):
    """Heuristically detect whether the protobuf net emits DensePose outputs."""
    marker_blobs = {"AnnIndex", "Index_UV", "U_estimated", "V_estimated"}
    net_ops = model.protobuf_model.net.Proto().op
    return any(not marker_blobs.isdisjoint(op.output) for op in net_ops)
def _update_if_true(cfg, key, value):
    """Set dotted `key` in cfg to `value` when value is truthy.

    Falsy values are ignored; a warning is logged when the existing config
    value conflicts with the one being written.
    """
    if not value:
        return

    # walk the dotted path to find the current value
    current = cfg
    for part in key.split("."):
        current = getattr(current, part)

    if current != value:
        logger.warning(
            "There's conflict between cfg and model, overwrite config {} from {} to {}"
            .format(key, current, value)
        )
    cfg.merge_from_list([key, value])
def update_cfg_from_pb_model(cfg, model):
    """
    Statically update cfg based on the given caffe2 model; in case of conflict
    between the caffe2 model and the cfg, the caffe2 model has higher priority.
    """
    # each config key is paired with the heuristic that infers it from the net
    inferred_flags = [
        ("MODEL.MASK_ON", infer_mask_on),
        ("MODEL.KEYPOINT_ON", infer_keypoint_on),
        ("MODEL.DENSEPOSE_ON", infer_densepose_on),
    ]
    with temp_defrost(cfg):
        for key, infer_fn in inferred_flags:
            _update_if_true(cfg, key, infer_fn(model))
    return cfg
def _deprecated_build_caffe2_model(runner, predict_net, init_net):
    """Build a ProtobufDetectionModel, letting the runner override construction.

    If the runner defines its own `_deprecated_build_caffe2_model`, delegate to
    it; otherwise construct the default ProtobufDetectionModel and attach a
    `validate_cfg` hook that syncs a cfg against the model.
    """
    if hasattr(runner, "_deprecated_build_caffe2_model"):
        return runner._deprecated_build_caffe2_model(predict_net, init_net)

    model = ProtobufDetectionModel(predict_net, init_net)
    model.validate_cfg = partial(update_cfg_from_pb_model, model=model)
    return model
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import logging
import numpy as np
from mobile_cv.torch.utils_caffe2.ws_utils import ScopedWS
logger = logging.getLogger(__name__)
# NOTE: specific export_to_db for (data, im_info) dual inputs.
# modified from mobile-vision/common/utils/model_utils.py
# NOTE: specific export_to_db for (data, im_info) dual inputs.
# modified from mobile-vision/common/utils/model_utils.py
def export_to_db(net, params, inputs, outputs, out_file, net_type=None, shapes=None):
    """Save a caffe2 predict net and its parameters to a log_file_db.

    Args:
        net: caffe2 predict NetDef.
        params: names of the initialized (parameter) blobs.
        inputs: exactly two external input names: (data, im_info).
        outputs: external output blob names.
        out_file: destination logfiledb path.
        net_type: optional caffe2 net type forwarded to the exporter.
        shapes: dict mapping blob name -> shape; required, and must contain
            entries for both input blobs.

    Raises:
        ValueError: if `shapes` is not provided (the default None would
            otherwise crash with a TypeError on the first subscript below).
    """
    if shapes is None:
        raise ValueError("`shapes` is required and must cover the input blobs")

    from caffe2.python import core

    # NOTE: special handling for im_info: by default the "predict_init_net"
    # will zero_fill inputs/outputs (https://fburl.com/diffusion/nvksomrt),
    # however the actual value of "im_info" also matters, so we need use
    # extra_init_net to handle this.
    assert len(inputs) == 2
    data_name, im_info_name = inputs
    data_shape = shapes[data_name]  # assume NCHW
    extra_init_net = core.Net("extra_init_net")
    # one (height, width, scale=1.0) row per batch entry
    # (the redundant function-local `import numpy as np` was removed; np is
    # imported at module level)
    im_info = np.array(
        [[data_shape[2], data_shape[3], 1.0] for _ in range(data_shape[0])],
        dtype=np.float32,
    )
    extra_init_net.GivenTensorFill(
        [], im_info_name, shape=shapes[im_info_name], values=im_info
    )

    from caffe2.caffe2.fb.predictor import predictor_exporter  # NOTE: slow import

    predictor_export_meta = predictor_exporter.PredictorExportMeta(
        predict_net=net,
        parameters=params,
        inputs=inputs,
        outputs=outputs,
        net_type=net_type,
        shapes=shapes,
        extra_init_net=extra_init_net,
    )
    logger.info("Writing logdb {} ...".format(out_file))
    predictor_exporter.save_to_db(
        db_type="log_file_db",
        db_destination=out_file,
        predictor_export_meta=predictor_export_meta,
    )
def export_to_logfiledb(predict_net, init_net, outfile, ws_blobs):
    """Export a caffe2 (predict_net, init_net) pair to a logfiledb file.

    `ws_blobs` maps blob names to their workspace values and is used only to
    derive per-blob shapes for the exporter.
    """
    logger.info("Exporting Caffe2 model to {}".format(outfile))

    # provide a dummy shape if it could not be inferred
    shapes = {}
    for blob_name, blob_value in ws_blobs.items():
        if isinstance(blob_value, np.ndarray):
            shapes[blob_name] = blob_value.shape
        else:
            shapes[blob_name] = [1]

    with ScopedWS("__ws_tmp__", is_reset=True) as ws:
        ws.RunNetOnce(init_net)
        initialized_blobs = set(ws.Blobs())
        # external inputs not produced by init_net are true model inputs
        uninitialized = [
            inp for inp in predict_net.external_input if inp not in initialized_blobs
        ]
        params = list(initialized_blobs)
        output_names = list(predict_net.external_output)
        export_to_db(
            predict_net, params, uninitialized, output_names, outfile, shapes=shapes
        )
......@@ -112,13 +112,13 @@ class TestD2GoDatasets(unittest.TestCase):
str(x)
for x in [
"D2GO_DATA.DATASETS.COCO_INJECTION.NAMES",
["inj_ds"],
["inj_ds3"],
"D2GO_DATA.DATASETS.COCO_INJECTION.IM_DIRS",
[image_dir],
"D2GO_DATA.DATASETS.COCO_INJECTION.JSON_FILES",
[json_file],
"DATASETS.TEST",
("inj_ds",),
("inj_ds3",),
"D2GO_DATA.TEST.MAX_IMAGES",
1,
]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment