Commit 2d09c5c3 authored by Yuxin Wu's avatar Yuxin Wu Committed by Facebook GitHub Bot
Browse files

avoid caffe2 imports in d2go/modeling/meta_arch/rcnn.py

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/146

OSSpytorch package removed caffe2. This causes https://github.com/facebookresearch/d2go/issues/137

This diff only makes d2go importable without caffe2. But export is still broken when caffe2 is not available.

Reviewed By: zhanghang1989

Differential Revision: D32690938

fbshipit-source-id: d345687bd720b4b1376494478f1fa44f4c591ccf
parent 51b7be17
......@@ -6,14 +6,8 @@ import logging
import torch
import torch.nn as nn
from caffe2.proto import caffe2_pb2
from d2go.export.api import PredictorExportConfig
from d2go.utils.qat_utils import get_qat_qconfig
from detectron2.export.caffe2_modeling import (
META_ARCH_CAFFE2_EXPORT_TYPE_MAP,
convert_batched_inputs_to_c2_format,
)
from detectron2.export.shared import get_pb_arg_vali, get_pb_arg_vals
from detectron2.modeling import GeneralizedRCNN
from detectron2.modeling.postprocessing import detector_postprocess
from detectron2.projects.point_rend import PointRendMaskHead
......@@ -63,6 +57,8 @@ class GeneralizedRCNNPatch:
@RCNN_PREPARE_FOR_EXPORT_REGISTRY.register()
def default_rcnn_prepare_for_export(self, cfg, inputs, predictor_type):
from detectron2.export.caffe2_modeling import META_ARCH_CAFFE2_EXPORT_TYPE_MAP
if (
"@c2_ops" in predictor_type
or "caffe2" in predictor_type
......@@ -297,6 +293,10 @@ class D2Caffe2MetaArchPreprocessFunc(object):
self.device = device
def __call__(self, inputs):
from detectron2.export.caffe2_modeling import (
convert_batched_inputs_to_c2_format,
)
data, im_info = convert_batched_inputs_to_c2_format(
inputs, self.size_divisibility, self.device
)
......@@ -304,6 +304,9 @@ class D2Caffe2MetaArchPreprocessFunc(object):
@staticmethod
def get_params(cfg, model):
from caffe2.proto import caffe2_pb2
from detectron2.export.shared import get_pb_arg_vali, get_pb_arg_vals
fake_predict_net = caffe2_pb2.NetDef()
model.encode_additional_info(fake_predict_net, None)
size_divisibility = get_pb_arg_vali(fake_predict_net, "size_divisibility", 0)
......@@ -321,6 +324,10 @@ class D2Caffe2MetaArchPostprocessFunc(object):
self.encoded_info = encoded_info
def __call__(self, inputs, tensor_inputs, tensor_outputs):
from caffe2.proto import caffe2_pb2
from detectron2.export.caffe2_modeling import META_ARCH_CAFFE2_EXPORT_TYPE_MAP
from detectron2.export.shared import get_pb_arg_vals
encoded_info = self.encoded_info.encode("ascii")
fake_predict_net = caffe2_pb2.NetDef().FromString(encoded_info)
meta_architecture = get_pb_arg_vals(fake_predict_net, "meta_architecture", None)
......@@ -334,6 +341,8 @@ class D2Caffe2MetaArchPostprocessFunc(object):
@staticmethod
def get_params(cfg, model):
from caffe2.proto import caffe2_pb2
# NOTE: the post processing has different values for different meta
# architectures, here simply relying Caffe2 meta architecture to encode info
# into a NetDef and storing it as whole.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment