Commit 731d98f9 authored by Owen Wang, committed by Facebook GitHub Bot
Browse files

d2go/semantic_seg Pre- and PostprocessFunc

Summary: Add documentation on the pre and post processing functions for segmentation.

Reviewed By: XiaoliangDai

Differential Revision: D34882165

fbshipit-source-id: 375c62d0ad632a40b6557065b3362e333df8c55f
parent 4f651f97
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from typing import Any, Dict, List
import torch
import torch.nn as nn
from d2go.export.api import PredictorExportConfig
from detectron2.modeling.postprocessing import sem_seg_postprocess
......@@ -49,27 +52,66 @@ class ModelWrapper(nn.Module):
class PreprocessFunc(object):
    """
    A common preprocessing module for semantic segmentation models.

    Moves each image of a batch to the configured device and packs the images
    into a single tensor via ``ImageList.from_tensors``.
    """

    def __init__(self, size_divisibility, device):
        """
        Args:
            size_divisibility: forwarded as the second argument of
                ``ImageList.from_tensors`` when batching the images.
            device: target passed to ``Tensor.to`` for every input image.
        """
        self.size_divisibility = size_divisibility
        self.device = device

    def __call__(self, batched_inputs: List[Dict[str, Any]]) -> torch.Tensor:
        """
        Retrieve the image tensor from dataloader batches.

        Args:
            batched_inputs (List[Dict[str, Any]]): output from a D2Go train
                or test data loader. Each dict must provide an image tensor
                under the "image" key.

        Returns:
            input images (torch.Tensor): ImageList-wrapped NCHW tensor
                (i.e. with padding and divisibility alignment) of the
                batches' images.
        """
        images = [x["image"].to(self.device) for x in batched_inputs]
        # ImageList.from_tensors merges the per-image tensors into one batch;
        # only the underlying tensor is returned, not the ImageList wrapper.
        images = ImageList.from_tensors(images, self.size_divisibility)
        return images.tensor
class PostprocessFunc(object):
    """
    A common postprocessing module for semantic segmentation models.
    """

    def __call__(
        self,
        batched_inputs: List[Dict[str, Any]],
        tensor_inputs: torch.Tensor,
        tensor_outputs: torch.Tensor,
    ) -> List[Dict[str, Any]]:
        """
        Rescales sem_seg logits to original image input resolution,
        and packages the logits into D2Go's expected output format.

        Args:
            batched_inputs (List[Dict[str, Any]]): batched inputs from the
                dataloader. Each dict must provide the image tensor under
                "image"; "height" and "width" are read with ``dict.get`` and
                are therefore passed on as None when absent.
            tensor_inputs (Tensor): tensorized inputs, e.g. from
                `PreprocessFunc`. Not read by this implementation.
            tensor_outputs (Tensor): sem seg logits tensor from the model to
                process.

        Returns:
            processed_results (List[Dict]): List of D2Go output dicts, one
                ``{"sem_seg": ...}`` entry per processed image, ready to be
                used downstream in an Evaluator, for export, etc.
        """
        results = tensor_outputs  # nchw
        processed_results = []
        # zip pairs each per-image output with its input dict and stops at the
        # shorter of the two sequences.
        for result, input_per_image in zip(results, batched_inputs):
            height = input_per_image.get("height")
            width = input_per_image.get("width")
            # The last two dims of the image tensor are taken as its (H, W)
            # size (assumes a CHW image tensor).
            image_tensor_shape = input_per_image["image"].shape
            image_size = (image_tensor_shape[1], image_tensor_shape[2])
            # D2's sem_seg_postprocess rescales sem seg masks to the
            # provided original input resolution.
            r = sem_seg_postprocess(result, image_size, height, width)
            processed_results.append({"sem_seg": r})
        return processed_results
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment