Commit 731d98f9 authored by Owen Wang, committed by Facebook GitHub Bot

d2go/semantic_seg Pre- and PostprocessFunc

Summary: Add documentation on the pre- and post-processing functions for segmentation.

Reviewed By: XiaoliangDai

Differential Revision: D34882165

fbshipit-source-id: 375c62d0ad632a40b6557065b3362e333df8c55f
parent 4f651f97
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from typing import Any, Dict, List

import torch
import torch.nn as nn
from d2go.export.api import PredictorExportConfig
from detectron2.modeling.postprocessing import sem_seg_postprocess

@@ -49,27 +52,66 @@ class ModelWrapper(nn.Module):
class PreprocessFunc(object):
    """
    A common preprocessing module for semantic segmentation models.
    """

    def __init__(self, size_divisibility, device):
        self.size_divisibility = size_divisibility
        self.device = device
    def __call__(self, batched_inputs: List[Dict[str, Any]]) -> torch.Tensor:
        """
        Retrieve image tensors from dataloader batches.

        Args:
            batched_inputs (List[Dict[str, Tensor]]): output from a
                D2Go train or test data loader.

        Returns:
            input images (torch.Tensor): ImageList-wrapped NCHW tensor
                (i.e. with padding and divisibility alignment) of the batch's images.
        """
        images = [x["image"].to(self.device) for x in batched_inputs]
        images = ImageList.from_tensors(images, self.size_divisibility)
        return images.tensor
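
# Illustrative usage sketch (not part of this commit): PreprocessFunc pads a
# batch of differently sized images into a single NCHW tensor. The image
# sizes and size_divisibility below are made-up examples, and the module's
# own imports (torch, ImageList) are assumed to be in scope.
#
#   preprocess = PreprocessFunc(size_divisibility=32, device="cpu")
#   batched_inputs = [
#       {"image": torch.rand(3, 100, 150), "height": 100, "width": 150},
#       {"image": torch.rand(3, 120, 160), "height": 120, "width": 160},
#   ]
#   images = preprocess(batched_inputs)
#   # Both images are padded to a common multiple-of-32 size:
#   # images.shape == torch.Size([2, 3, 128, 160])
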
class PostprocessFunc(object):
    """
    A common postprocessing module for semantic segmentation models.
    """

    def __call__(
        self,
        batched_inputs: List[Dict[str, Any]],
        tensor_inputs: torch.Tensor,
        tensor_outputs: torch.Tensor,
    ) -> List[Dict[str, Any]]:
"""
Rescales sem_seg logits to original image input resolution,
and packages the logits into D2Go's expected output format.
Args:
inputs (List[Dict[str, Tensor]]): batched inputs from the dataloader.
tensor_inputs (Tensor): tensorized inputs, e.g. from `PreprocessFunc`.
tensor_outputs (Tensor): sem seg logits tensor from the model to process.
Returns:
processed_results (List[Dict]): List of D2Go output dicts ready to be used
downstream in an Evaluator, for export, etc.
"""
        results = tensor_outputs  # NCHW
        processed_results = []
        for result, input_per_image in zip(results, batched_inputs):
            height = input_per_image.get("height")
            width = input_per_image.get("width")
            image_tensor_shape = input_per_image["image"].shape
            image_size = (image_tensor_shape[1], image_tensor_shape[2])
            # D2's sem_seg_postprocess rescales sem seg masks to the
            # provided original input resolution.
            r = sem_seg_postprocess(result, image_size, height, width)
            processed_results.append({"sem_seg": r})
        return processed_results
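
For context, a minimal end-to-end sketch (not part of this commit) of how the two helpers compose around a segmentation model: preprocessing pads the batch into a single NCHW tensor, the model produces per-pixel logits, and postprocessing crops away the padding and rescales each result to the original `height`/`width` given in the input dicts. The 1x1 conv, class count, image sizes, and the `d2go.modeling.meta_arch.semantic_seg` import path are illustrative assumptions, not taken from the diff.

import torch
import torch.nn as nn

# Assumed import path for the helpers shown in the diff above.
from d2go.modeling.meta_arch.semantic_seg import PostprocessFunc, PreprocessFunc

NUM_CLASSES = 5  # hypothetical number of semantic classes
toy_model = nn.Conv2d(3, NUM_CLASSES, kernel_size=1)  # stand-in for a real seg head

preprocess = PreprocessFunc(size_divisibility=32, device="cpu")
postprocess = PostprocessFunc()

# "image" is the (possibly resized) CHW tensor; "height"/"width" are the
# original resolutions the outputs should be rescaled back to.
batched_inputs = [
    {"image": torch.rand(3, 100, 150), "height": 200, "width": 300},
    {"image": torch.rand(3, 120, 160), "height": 240, "width": 320},
]

tensor_inputs = preprocess(batched_inputs)   # padded NCHW: (2, 3, 128, 160)
tensor_outputs = toy_model(tensor_inputs)    # logits: (2, NUM_CLASSES, 128, 160)
results = postprocess(batched_inputs, tensor_inputs, tensor_outputs)

# Each entry is cropped to its unpadded size and resized to the original
# resolution: (NUM_CLASSES, 200, 300) and (NUM_CLASSES, 240, 320).
print([r["sem_seg"].shape for r in results])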