Commit 30fb79b6 authored by Yanghan Wang's avatar Yanghan Wang Committed by Facebook GitHub Bot
Browse files

add prepare_for_export for D2's SemanticSegmentor

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/83

- Implement `prepare_for_export` for `SemanticSegmentor`
- Add unit test comparing numerical matching

Reviewed By: zhanghang1989

Differential Revision: D29088421

fbshipit-source-id: ccb86ac4b4b90a63eeebdbf76b2bf31c1da65a8b
parent 32dbb035
......@@ -6,7 +6,8 @@ import logging
from functools import lru_cache
from d2go.modeling.meta_arch.rcnn import GeneralizedRCNNPatch
from detectron2.modeling import GeneralizedRCNN
from d2go.modeling.meta_arch.semantic_seg import SemanticSegmentorPatch
from detectron2.modeling import GeneralizedRCNN, SemanticSegmentor
logger = logging.getLogger(__name__)
......@@ -32,4 +33,5 @@ def patch_d2_meta_arch():
_check_and_set(dst_cls, method_name, getattr(src_cls, method_name))
_apply_patch(GeneralizedRCNN, GeneralizedRCNNPatch)
_apply_patch(SemanticSegmentor, SemanticSegmentorPatch)
# TODO: patch other meta-archs defined in D2
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import torch.nn as nn
from d2go.export.api import PredictorExportConfig
from detectron2.modeling.postprocessing import sem_seg_postprocess
from detectron2.structures import ImageList
from mobile_cv.predictor.api import FuncInfo
class SemanticSegmentorPatch:
    """Mixin patched onto D2's SemanticSegmentor to add predictor-export support."""

    # Method names copied onto the target meta-arch by the patching machinery.
    METHODS_TO_PATCH = [
        "prepare_for_export",
    ]

    def prepare_for_export(self, cfg, inputs, predictor_type):
        """Describe how to trace/export this model as a standalone predictor.

        Returns a PredictorExportConfig bundling the traceable model wrapper
        with serializable pre/post-processing function infos.
        """
        preprocess_params = {
            "size_divisibility": self.backbone.size_divisibility,
            "device": str(self.device),
        }
        preprocess_info = FuncInfo.gen_func_info(
            PreprocessFunc, params=preprocess_params
        )
        postprocess_info = FuncInfo.gen_func_info(PostprocessFunc, params={})

        run_preprocess = preprocess_info.instantiate()
        return PredictorExportConfig(
            model=ModelWrapper(self),
            # data_generator produces the (single) positional arg tuple fed to
            # the wrapped model during tracing.
            data_generator=lambda batch: (run_preprocess(batch),),
            preprocess_info=preprocess_info,
            postprocess_info=postprocess_info,
        )
class ModelWrapper(nn.Module):
    """Traceable core of the segmentor: normalize, backbone, head — no pre/post-processing."""

    def __init__(self, segmentor):
        super().__init__()
        self.segmentor = segmentor

    def forward(self, x):
        # Per-channel normalization using the segmentor's registered statistics.
        normalized = (x - self.segmentor.pixel_mean) / self.segmentor.pixel_std
        features = self.segmentor.backbone(normalized)
        # Head returns (results, losses); losses are empty/ignored at inference.
        results, _ = self.segmentor.sem_seg_head(features, targets=None)
        return results
class PreprocessFunc(object):
    """Batch a list of D2-style input dicts into one padded image tensor."""

    def __init__(self, size_divisibility, device):
        # size_divisibility: padding multiple required by the backbone.
        # device: where the batched tensor should live (string form, e.g. "cpu").
        self.size_divisibility = size_divisibility
        self.device = device

    def __call__(self, inputs):
        tensors = [entry["image"].to(self.device) for entry in inputs]
        batched = ImageList.from_tensors(tensors, self.size_divisibility)
        return batched.tensor
class PostprocessFunc(object):
    """Resize raw NCHW segmentation logits back to each input's original resolution."""

    def __call__(self, inputs, tensor_inputs, tensor_outputs):
        # tensor_outputs is NCHW; pair each per-image result with its input record.
        outputs = []
        for logits, record in zip(tensor_outputs, inputs):
            target_h = record.get("height")
            target_w = record.get("width")
            # The padded/resized size the model actually saw (H, W of the tensor).
            shape = record["image"].shape
            resized = sem_seg_postprocess(
                logits, (shape[1], shape[2]), target_h, target_w
            )
            outputs.append({"sem_seg": resized})
        return outputs
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import os
import shutil
import tempfile
import unittest
import torch
from d2go.export.api import convert_and_export_predictor
from d2go.export.d2_meta_arch import patch_d2_meta_arch
from d2go.runner import Detectron2GoRunner
from mobile_cv.predictor.api import create_predictor
# Add APIs to D2's meta arch. This is usually called in the runner's setup; however,
# in unit tests it needs to be called separately. (Maybe we should apply this by default.)
patch_d2_meta_arch()
def _get_batch(height, width, is_train):
def _get_frame():
random_image = torch.rand(3, height, width).to(torch.float32)
ret = {"image": random_image}
if is_train:
mask_size = (height, width)
random_mask = torch.randint(low=0, high=2, size=mask_size).to(torch.int64)
ret["sem_seg"] = random_mask
return ret
batch_size = 2 if is_train else 1
return [
{"filename": "some_file", "width": 100, "height": 100, **_get_frame()}
for _ in range(batch_size)
]
def _get_data_loader(height, width, is_train):
    """Return an endless iterator that repeatedly yields the same random batch."""
    batch = _get_batch(height, width, is_train)

    def _loop_forever():
        while True:
            yield batch

    return _loop_forever()
def _get_input_dim(model):
h = w = max(model.backbone.size_divisibility, 1)
return h, w
class BaseSemanticSegTestCase:
    """Namespace wrapper so unittest does not collect the template case directly."""

    class TemplateTestCase(unittest.TestCase):
        """Shared inference/train/export checks; subclasses supply the model config."""

        def setUp(self):
            # Temp dir for export artifacts; cleaned up automatically after the test.
            self.test_dir = tempfile.mkdtemp(prefix="test_meta_arch_semantic_seg_")
            self.addCleanup(shutil.rmtree, self.test_dir)

            runner = Detectron2GoRunner()
            self.cfg = runner.get_default_cfg()
            self.setup_custom_test()
            # Force CPU so the test runs on machines without GPUs.
            self.cfg.merge_from_list(["MODEL.DEVICE", "cpu"])
            self.test_model = runner.build_model(self.cfg, eval_only=True)

        def setup_custom_test(self):
            """Subclasses must mutate self.cfg (e.g. merge a model zoo config)."""
            raise NotImplementedError()

        def test_inference(self):
            h, w = _get_input_dim(self.test_model)
            inputs = _get_batch(h, w, False)
            with torch.no_grad():
                self.test_model(inputs)

        def test_train(self):
            h, w = _get_input_dim(self.test_model)
            inputs = _get_batch(h, w, True)
            self.test_model.train()
            loss_dict = self.test_model(inputs)
            losses = sum(loss_dict.values())
            losses.backward()

        def _test_export(self, predictor_type, compare_match=True):
            """Export the model as `predictor_type`; optionally check numerical parity.

            Args:
                predictor_type: export format passed to convert_and_export_predictor.
                compare_match: when True, assert the exported predictor's "sem_seg"
                    outputs match the eager PyTorch model's outputs.
            """
            h, w = _get_input_dim(self.test_model)
            dl = _get_data_loader(h, w, False)
            inputs = next(iter(dl))

            output_dir = os.path.join(self.test_dir, "test_export")
            predictor_path = convert_and_export_predictor(
                self.cfg, self.test_model, predictor_type, output_dir, dl
            )

            predictor = create_predictor(predictor_path)
            # Fixed misspelled local variable: was "predicotr_outputs".
            predictor_outputs = predictor(inputs)
            self.assertEqual(len(predictor_outputs), len(inputs))

            with torch.no_grad():
                pytorch_outputs = self.test_model(inputs)
                self.assertEqual(len(pytorch_outputs), len(inputs))

            if compare_match:
                for predictor_output, pytorch_output in zip(
                    predictor_outputs, pytorch_outputs
                ):
                    torch.testing.assert_allclose(
                        predictor_output["sem_seg"], pytorch_output["sem_seg"]
                    )
class TestR50FPN(BaseSemanticSegTestCase.TemplateTestCase):
    """Runs the template checks against the R50-FPN semantic segmentation config."""

    def setup_custom_test(self):
        # Pull the standard config from the detectron2 model zoo.
        config_path = "detectron2://Misc/semantic_R_50_FPN_1x.yaml"
        self.cfg.merge_from_file(config_path)

    def test_export_torchscript(self):
        self._test_export("torchscript", compare_match=True)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment