Commit 5aadaaa4 authored by Hang Zhang, committed by Facebook GitHub Bot

DETR Model Export

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/169

Make d2go DETR exportable (TorchScript compatible).
Move padding-mask generation out of the backbone and into preprocessing, so the scripted DETR module receives a NestedTensor (see the usage sketch below).
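Not part of the diff itself, but a rough usage sketch of what this enables, mirroring the unit test added below; `model` is assumed to be a Detr meta-arch built through DETRRunner, and the config setup is omitted:

# Hedged sketch: script the inner DETR module and run it on a NestedTensor input.
import torch
from detr.util.misc import nested_tensor_from_tensor_list

scripted = torch.jit.script(model.detr)
x = nested_tensor_from_tensor_list([torch.rand(3, 200, 200), torch.rand(3, 200, 250)])
out = scripted(x)
print(out["pred_logits"].shape, out["pred_boxes"].shape)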

Reviewed By: sstsai-adl

Differential Revision: D33798073

fbshipit-source-id: d629b0c9cbdb67060982be717c7138a0e7e9adbc
parent 6791682f
@@ -124,8 +124,10 @@ class Detr(nn.Module):
             dict[str: Tensor]:
                 mapping from a named loss to a tensor storing the loss. Used during training only.
         """
-        images = self.preprocess_image(batched_inputs)
-        output = self.detr(images)
+        images_lists = self.preprocess_image(batched_inputs)
+        # convert images_lists to Nested Tensor?
+        nested_images = self.imagelist_to_nestedtensor(images_lists)
+        output = self.detr(nested_images)
         if self.training:
             gt_instances = [x["instances"].to(self.device) for x in batched_inputs]
@@ -144,10 +146,12 @@ class Detr(nn.Module):
             box_cls = output["pred_logits"]
             box_pred = output["pred_boxes"]
             mask_pred = output["pred_masks"] if self.mask_on else None
-            results = self.inference(box_cls, box_pred, mask_pred, images.image_sizes)
+            results = self.inference(
+                box_cls, box_pred, mask_pred, images_lists.image_sizes
+            )
             processed_results = []
             for results_per_image, input_per_image, image_size in zip(
-                results, batched_inputs, images.image_sizes
+                results, batched_inputs, images_lists.image_sizes
             ):
                 height = input_per_image.get("height", image_size[0])
                 width = input_per_image.get("width", image_size[1])
@@ -239,3 +243,12 @@ class Detr(nn.Module):
         images = [self.normalizer(x["image"].to(self.device)) for x in batched_inputs]
         images = ImageList.from_tensors(images)
         return images
+
+    def imagelist_to_nestedtensor(self, images):
+        tensor = images.tensor
+        device = tensor.device
+        N, _, H, W = tensor.shape
+        masks = torch.ones((N, H, W), dtype=torch.bool, device=device)
+        for idx, (h, w) in enumerate(images.image_sizes):
+            masks[idx, :h, :w] = False
+        return NestedTensor(tensor, masks)
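The mask convention above follows DETR's NestedTensor: True marks padded pixels, False marks real image content. A small standalone check of that convention (a sketch assuming detectron2's ImageList and detr.util.misc.NestedTensor; the image sizes are illustrative):

import torch
from detectron2.structures import ImageList
from detr.util.misc import NestedTensor

# Two images of different widths get padded to a common (H, W) by ImageList.
imgs = [torch.rand(3, 200, 200), torch.rand(3, 200, 250)]
image_list = ImageList.from_tensors(imgs)

tensor = image_list.tensor                      # shape (2, 3, 200, 250)
n, _, h, w = tensor.shape
masks = torch.ones((n, h, w), dtype=torch.bool, device=tensor.device)
for idx, (ih, iw) in enumerate(image_list.image_sizes):
    masks[idx, :ih, :iw] = False                # False = real pixels, True = padding

nested = NestedTensor(tensor, masks)
assert not nested.mask[0, :200, :200].any()     # real region of image 0 is unmasked
assert nested.mask[0, :, 200:].all()            # padded columns of image 0 are masked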
+from typing import Dict
 import numpy as np
 import torch
+import torch.nn.functional as F
 from detectron2.modeling import build_backbone
 from detectron2.utils.registry import Registry
 from detr.models.backbone import Joiner
@@ -59,35 +62,15 @@ class ResNetMaskedBackbone(nn.Module):
         self.feature_strides = [backbone_shape[f].stride for f in backbone_shape.keys()]
         self.num_channels = [backbone_shape[k].channels for k in backbone_shape.keys()]

-    def forward(self, images):
-        features = self.backbone(images.tensor)
-        # one tensor per feature level. Each tensor has shape (B, maxH, maxW)
-        masks = self.mask_out_padding(
-            [features_per_level.shape for features_per_level in features.values()],
-            images.image_sizes,
-            images.tensor.device,
-        )
-        assert len(features) == len(masks)
-        for i, k in enumerate(features.keys()):
-            features[k] = NestedTensor(features[k], masks[i])
-        return features
-
-    def mask_out_padding(self, feature_shapes, image_sizes, device):
-        masks = []
-        assert len(feature_shapes) == len(self.feature_strides)
-        for idx, shape in enumerate(feature_shapes):
-            N, _, H, W = shape
-            masks_per_feature_level = torch.ones(
-                (N, H, W), dtype=torch.bool, device=device
-            )
-            for img_idx, (h, w) in enumerate(image_sizes):
-                masks_per_feature_level[
-                    img_idx,
-                    : int(np.ceil(float(h) / self.feature_strides[idx])),
-                    : int(np.ceil(float(w) / self.feature_strides[idx])),
-                ] = 0
-            masks.append(masks_per_feature_level)
-        return masks
+    def forward(self, tensor_list: NestedTensor):
+        xs = self.backbone(tensor_list.tensors)
+        out: Dict[str, NestedTensor] = {}
+        for name, x in xs.items():
+            m = tensor_list.mask
+            assert m is not None
+            mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0]
+            out[name] = NestedTensor(x, mask)
+        return out
 class FBNetMaskedBackbone(ResNetMaskedBackbone):
@@ -105,20 +88,6 @@ class FBNetMaskedBackbone(ResNetMaskedBackbone):
             self.backbone._out_feature_strides[k] for k in self.out_features
         ]

-    def forward(self, images):
-        features = self.backbone(images.tensor)
-        masks = self.mask_out_padding(
-            [features_per_level.shape for features_per_level in features.values()],
-            images.image_sizes,
-            images.tensor.device,
-        )
-        assert len(features) == len(masks)
-        ret_features = {}
-        for i, k in enumerate(features.keys()):
-            if k in self.out_features:
-                ret_features[k] = NestedTensor(features[k], masks[i])
-        return ret_features

 class SimpleSingleStageBackbone(ResNetMaskedBackbone):
     """This is a simple wrapper for single stage backbone,
@@ -135,15 +104,3 @@ class SimpleSingleStageBackbone(ResNetMaskedBackbone):
         self.feature_strides = [cfg.MODEL.BACKBONE.STRIDE]
         self.num_channels = [cfg.MODEL.BACKBONE.CHANNEL]
         self.strides = [cfg.MODEL.BACKBONE.STRIDE]

-    def forward(self, images):
-        y = self.backbone(images.tensor)
-        masks = self.mask_out_padding(
-            [y.shape],
-            images.image_sizes,
-            images.tensor.device,
-        )
-        assert len(masks) == 1
-        ret_features = {}
-        ret_features[self.out_features[0]] = NestedTensor(y, masks[0])
-        return ret_features
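With mask_out_padding removed, per-level masks are no longer derived from feature strides; instead, the image-level padding mask carried in the NestedTensor is resized to each feature map with F.interpolate, as in the upstream DETR backbone. A rough standalone illustration of that step (the batch shape and the stride-16 level are made up for the example):

import torch
import torch.nn.functional as F

# Image-level padding mask for a batch of 2 padded to 64x96; image 0 occupies only the left 64 columns.
mask = torch.ones((2, 64, 96), dtype=torch.bool)
mask[0, :64, :64] = False
mask[1, :, :] = False

# Downsample the mask to a hypothetical stride-16 feature map (4x6), as the new forward() does.
feat_mask = F.interpolate(mask[None].float(), size=(4, 6)).to(torch.bool)[0]
print(feat_mask.shape)  # torch.Size([2, 4, 6])
print(feat_mask[0])     # the two rightmost columns (padding of image 0) stay True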
 #!/usr/bin/env python3
 # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
 # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
 import io
 import unittest
 from typing import List
 ...
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

import unittest

import torch
from d2go.runner import create_runner
from detr.util.misc import nested_tensor_from_tensor_list
from fvcore.nn import flop_count_table, FlopCountAnalysis


class Tester(unittest.TestCase):
    @staticmethod
    def _set_detr_cfg(cfg, enc_layers, dec_layers, num_queries, dim_feedforward):
        cfg.MODEL.META_ARCHITECTURE = "Detr"
        cfg.MODEL.DETR.NUM_OBJECT_QUERIES = num_queries
        cfg.MODEL.DETR.ENC_LAYERS = enc_layers
        cfg.MODEL.DETR.DEC_LAYERS = dec_layers
        cfg.MODEL.DETR.DEEP_SUPERVISION = False
        cfg.MODEL.DETR.DIM_FEEDFORWARD = dim_feedforward  # 2048

    def _assert_model_output(self, model, scripted_model):
        x = nested_tensor_from_tensor_list(
            [torch.rand(3, 200, 200), torch.rand(3, 200, 250)]
        )
        out = model(x)
        out_script = scripted_model(x)
        self.assertTrue(out["pred_logits"].equal(out_script["pred_logits"]))
        self.assertTrue(out["pred_boxes"].equal(out_script["pred_boxes"]))

    def test_detr_res50_export(self):
        runner = create_runner("d2go.projects.detr.runner.DETRRunner")
        cfg = runner.get_default_cfg()
        cfg.MODEL.DEVICE = "cpu"
        # DETR
        self._set_detr_cfg(cfg, 6, 6, 100, 2048)
        # backbone
        cfg.MODEL.BACKBONE.NAME = "build_resnet_backbone"
        cfg.MODEL.RESNETS.DEPTH = 50
        cfg.MODEL.RESNETS.STRIDE_IN_1X1 = False
        cfg.MODEL.RESNETS.OUT_FEATURES = ["res2", "res3", "res4", "res5"]
        # build model
        model = runner.build_model(cfg).eval()
        model = model.detr
        scripted_model = torch.jit.script(model)
        self._assert_model_output(model, scripted_model)

    def test_detr_fbnet_export(self):
        runner = create_runner("d2go.projects.detr.runner.DETRRunner")
        cfg = runner.get_default_cfg()
        cfg.MODEL.DEVICE = "cpu"
        # DETR
        self._set_detr_cfg(cfg, 3, 3, 50, 256)
        # backbone
        cfg.MODEL.BACKBONE.NAME = "FBNetV2C4Backbone"
        cfg.MODEL.FBNET_V2.ARCH = "FBNetV3_A_dsmask_C5"
        cfg.MODEL.FBNET_V2.WIDTH_DIVISOR = 8
        cfg.MODEL.FBNET_V2.OUT_FEATURES = ["trunk4"]
        # build model
        model = runner.build_model(cfg).eval()
        model = model.detr
        print(model)
        scripted_model = torch.jit.script(model)
        self._assert_model_output(model, scripted_model)
        # print flops
        table = flop_count_table(FlopCountAnalysis(model, ([torch.rand(3, 224, 320)],)))
        print(table)
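Serialization is not exercised by the test above, but once scripting succeeds the standard TorchScript save/load path should apply; a small hedged sketch (the output path is hypothetical):

import torch

# `scripted_model` is assumed to be the torch.jit.script(...) result from the test above.
path = "/tmp/detr_scripted.pt"  # hypothetical location
torch.jit.save(scripted_model, path)
reloaded = torch.jit.load(path, map_location="cpu")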