# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
# pyre-unsafe
import torch
from torch import nn
from detectron2.config import CfgNode
from detectron2.layers import ConvTranspose2d, interpolate
from ...structures import DensePoseEmbeddingPredictorOutput
from ..utils import initialize_module_params
from .registry import DENSEPOSE_PREDICTOR_REGISTRY
@DENSEPOSE_PREDICTOR_REGISTRY.register()
class DensePoseEmbeddingPredictor(nn.Module):
"""
Last layers of a DensePose model that take DensePose head outputs as an input
and produce model outputs for continuous surface embeddings (CSE).
"""
def __init__(self, cfg: CfgNode, input_channels: int):
"""
Initialize predictor using configuration options
Args:
cfg (CfgNode): configuration options
input_channels (int): input tensor size along the channel dimension
"""
super().__init__()
dim_in = input_channels
n_segm_chan = cfg.MODEL.ROI_DENSEPOSE_HEAD.NUM_COARSE_SEGM_CHANNELS
embed_size = cfg.MODEL.ROI_DENSEPOSE_HEAD.CSE.EMBED_SIZE
kernel_size = cfg.MODEL.ROI_DENSEPOSE_HEAD.DECONV_KERNEL
# coarse segmentation
self.coarse_segm_lowres = ConvTranspose2d(
dim_in, n_segm_chan, kernel_size, stride=2, padding=int(kernel_size / 2 - 1)
)
# embedding
self.embed_lowres = ConvTranspose2d(
dim_in, embed_size, kernel_size, stride=2, padding=int(kernel_size / 2 - 1)
)
self.scale_factor = cfg.MODEL.ROI_DENSEPOSE_HEAD.UP_SCALE
initialize_module_params(self)
def interp2d(self, tensor_nchw: torch.Tensor):
"""
Bilinear interpolation method to be used for upscaling
Args:
tensor_nchw (tensor): tensor of shape (N, C, H, W)
Return:
tensor of shape (N, C, Hout, Wout), where Hout and Wout are computed
by applying the scale factor to H and W
"""
return interpolate(
tensor_nchw, scale_factor=self.scale_factor, mode="bilinear", align_corners=False
)
def forward(self, head_outputs):
"""
Perform forward step on DensePose head outputs
Args:
head_outputs (tensor): DensePose head outputs, tensor of shape [N, D, H, W]
"""
embed_lowres = self.embed_lowres(head_outputs)
coarse_segm_lowres = self.coarse_segm_lowres(head_outputs)
embed = self.interp2d(embed_lowres)
coarse_segm = self.interp2d(coarse_segm_lowres)
return DensePoseEmbeddingPredictorOutput(embedding=embed, coarse_segm=coarse_segm)
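# Editor's sketch (not part of the original file): a minimal shape check for the
# predictor above. It assumes the stock DensePose config defaults installed by
# add_densepose_config (DECONV_KERNEL=4, UP_SCALE=2); the input channel count
# of 256 is an illustrative assumption.
if __name__ == "__main__":
    from detectron2.config import get_cfg
    from densepose import add_densepose_config

    cfg = get_cfg()
    add_densepose_config(cfg)
    predictor = DensePoseEmbeddingPredictor(cfg, input_channels=256)
    head_outputs = torch.rand(2, 256, 28, 28)  # [N, D, H, W]
    outputs = predictor(head_outputs)
    # the stride-2 deconv doubles H and W; interp2d upscales by UP_SCALE again
    print(outputs.embedding.shape, outputs.coarse_segm.shape)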
# Copyright (c) Facebook, Inc. and its affiliates.
# pyre-unsafe
from typing import Any
import torch
from torch.nn import functional as F
from detectron2.config import CfgNode
from detectron2.layers import ConvTranspose2d
from densepose.modeling.confidence import DensePoseConfidenceModelConfig
from densepose.modeling.utils import initialize_module_params
from densepose.structures import decorate_cse_predictor_output_class_with_confidences
class DensePoseEmbeddingConfidencePredictorMixin:
"""
Predictor contains the last layers of a DensePose model that take DensePose head
outputs as an input and produce model outputs. Confidence predictor mixin is used
to generate confidences for coarse segmentation estimated by some
base predictor. Several assumptions need to hold for the base predictor:
1) the `forward` method must return CSE predictor outputs that include
a `coarse_segm` attribute
2) `interp2d` method must be defined to perform bilinear interpolation;
the same method is typically used for masks and confidences
Confidence predictor mixin provides confidence estimates, as described in:
N. Neverova et al., Correlated Uncertainty for Learning Dense Correspondences
from Noisy Labels, NeurIPS 2019
A. Sanakoyeu et al., Transferring Dense Pose to Proximal Animal Classes, CVPR 2020
"""
def __init__(self, cfg: CfgNode, input_channels: int):
"""
Initialize confidence predictor using configuration options.
Args:
cfg (CfgNode): configuration options
input_channels (int): number of input channels
"""
# we rely on base predictor to call nn.Module.__init__
super().__init__(cfg, input_channels) # pyre-ignore[19]
self.confidence_model_cfg = DensePoseConfidenceModelConfig.from_cfg(cfg)
self._initialize_confidence_estimation_layers(cfg, input_channels)
self._registry = {}
initialize_module_params(self) # pyre-ignore[6]
def _initialize_confidence_estimation_layers(self, cfg: CfgNode, dim_in: int):
"""
Initialize confidence estimation layers based on configuration options
Args:
cfg (CfgNode): configuration options
dim_in (int): number of input channels
"""
kernel_size = cfg.MODEL.ROI_DENSEPOSE_HEAD.DECONV_KERNEL
if self.confidence_model_cfg.segm_confidence.enabled:
self.coarse_segm_confidence_lowres = ConvTranspose2d( # pyre-ignore[16]
dim_in, 1, kernel_size, stride=2, padding=int(kernel_size / 2 - 1)
)
def forward(self, head_outputs: torch.Tensor):
"""
Perform forward operation on head outputs used as inputs for the predictor.
Calls forward method from the base predictor and uses its outputs to compute
confidences.
Args:
head_outputs (Tensor): head outputs used as predictor inputs
Return:
An instance of outputs with confidences,
see `decorate_cse_predictor_output_class_with_confidences`
"""
# the base predictor is assumed to return CSE predictor outputs
base_predictor_outputs = super().forward(head_outputs) # pyre-ignore[16]
# create output instance by extending base predictor outputs:
output = self._create_output_instance(base_predictor_outputs)
if self.confidence_model_cfg.segm_confidence.enabled:
# base predictor outputs are assumed to have `coarse_segm` attribute
# base predictor is assumed to define `interp2d` method for bilinear interpolation
output.coarse_segm_confidence = (
F.softplus(
self.interp2d( # pyre-ignore[16]
self.coarse_segm_confidence_lowres(head_outputs) # pyre-ignore[16]
)
)
+ self.confidence_model_cfg.segm_confidence.epsilon
)
output.coarse_segm = base_predictor_outputs.coarse_segm * torch.repeat_interleave(
output.coarse_segm_confidence, base_predictor_outputs.coarse_segm.shape[1], dim=1
)
return output
def _create_output_instance(self, base_predictor_outputs: Any):
"""
Create an instance of predictor outputs by copying the outputs from the
base predictor and initializing confidence
Args:
base_predictor_outputs: an instance of base predictor outputs
(the outputs type is assumed to be a dataclass)
Return:
An instance of outputs with confidences
"""
PredictorOutput = decorate_cse_predictor_output_class_with_confidences(
type(base_predictor_outputs) # pyre-ignore[6]
)
# base_predictor_outputs is assumed to be a dataclass
# reassign all the fields from base_predictor_outputs (no deep copy!), add new fields
output = PredictorOutput(
**base_predictor_outputs.__dict__,
coarse_segm_confidence=None,
)
return output
# Copyright (c) Facebook, Inc. and its affiliates.
# pyre-unsafe
from . import DensePoseEmbeddingConfidencePredictorMixin, DensePoseEmbeddingPredictor
from .registry import DENSEPOSE_PREDICTOR_REGISTRY
@DENSEPOSE_PREDICTOR_REGISTRY.register()
class DensePoseEmbeddingWithConfidencePredictor(
DensePoseEmbeddingConfidencePredictorMixin, DensePoseEmbeddingPredictor
):
"""
Predictor that combines CSE and CSE confidence estimation
"""
pass
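# Editor's note (a hedged sketch, not from this commit): predictors are built
# through the registry above; the standard DensePose builder looks up
# cfg.MODEL.ROI_DENSEPOSE_HEAD.PREDICTOR_NAME, so selecting this class only
# requires setting that config key:
#
#   cfg.MODEL.ROI_DENSEPOSE_HEAD.PREDICTOR_NAME = "DensePoseEmbeddingWithConfidencePredictor"
#   predictor = DENSEPOSE_PREDICTOR_REGISTRY.get(
#       cfg.MODEL.ROI_DENSEPOSE_HEAD.PREDICTOR_NAME
#   )(cfg, input_channels)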
# Copyright (c) Facebook, Inc. and its affiliates.
# pyre-unsafe
from detectron2.utils.registry import Registry
DENSEPOSE_PREDICTOR_REGISTRY = Registry("DENSEPOSE_PREDICTOR")
# Copyright (c) Facebook, Inc. and its affiliates.
# pyre-unsafe
from .v1convx import DensePoseV1ConvXHead
from .deeplab import DensePoseDeepLabHead
from .registry import ROI_DENSEPOSE_HEAD_REGISTRY
from .roi_head import Decoder, DensePoseROIHeads
# Copyright (c) Facebook, Inc. and its affiliates.
# pyre-unsafe
import fvcore.nn.weight_init as weight_init
import torch
from torch import nn
from torch.nn import functional as F
from detectron2.config import CfgNode
from detectron2.layers import Conv2d
from .registry import ROI_DENSEPOSE_HEAD_REGISTRY
@ROI_DENSEPOSE_HEAD_REGISTRY.register()
class DensePoseDeepLabHead(nn.Module):
"""
DensePose head using DeepLabV3 model from
"Rethinking Atrous Convolution for Semantic Image Segmentation"
<https://arxiv.org/abs/1706.05587>.
"""
def __init__(self, cfg: CfgNode, input_channels: int):
super(DensePoseDeepLabHead, self).__init__()
# fmt: off
hidden_dim = cfg.MODEL.ROI_DENSEPOSE_HEAD.CONV_HEAD_DIM
kernel_size = cfg.MODEL.ROI_DENSEPOSE_HEAD.CONV_HEAD_KERNEL
norm = cfg.MODEL.ROI_DENSEPOSE_HEAD.DEEPLAB.NORM
self.n_stacked_convs = cfg.MODEL.ROI_DENSEPOSE_HEAD.NUM_STACKED_CONVS
self.use_nonlocal = cfg.MODEL.ROI_DENSEPOSE_HEAD.DEEPLAB.NONLOCAL_ON
# fmt: on
pad_size = kernel_size // 2
n_channels = input_channels
self.ASPP = ASPP(input_channels, [6, 12, 56], n_channels) # 6, 12, 56
self.add_module("ASPP", self.ASPP)
if self.use_nonlocal:
self.NLBlock = NONLocalBlock2D(input_channels, bn_layer=True)
self.add_module("NLBlock", self.NLBlock)
# weight_init.c2_msra_fill(self.ASPP)
for i in range(self.n_stacked_convs):
norm_module = nn.GroupNorm(32, hidden_dim) if norm == "GN" else None
layer = Conv2d(
n_channels,
hidden_dim,
kernel_size,
stride=1,
padding=pad_size,
bias=not norm,
norm=norm_module,
)
weight_init.c2_msra_fill(layer)
n_channels = hidden_dim
layer_name = self._get_layer_name(i)
self.add_module(layer_name, layer)
self.n_out_channels = hidden_dim
# initialize_module_params(self)
def forward(self, features):
x0 = features
x = self.ASPP(x0)
if self.use_nonlocal:
x = self.NLBlock(x)
output = x
for i in range(self.n_stacked_convs):
layer_name = self._get_layer_name(i)
x = getattr(self, layer_name)(x)
x = F.relu(x)
output = x
return output
def _get_layer_name(self, i: int):
layer_name = "body_conv_fcn{}".format(i + 1)
return layer_name
# Copied from
# https://github.com/pytorch/vision/blob/master/torchvision/models/segmentation/deeplabv3.py
# See https://arxiv.org/pdf/1706.05587.pdf for details
class ASPPConv(nn.Sequential):
def __init__(self, in_channels, out_channels, dilation):
modules = [
nn.Conv2d(
in_channels, out_channels, 3, padding=dilation, dilation=dilation, bias=False
),
nn.GroupNorm(32, out_channels),
nn.ReLU(),
]
super(ASPPConv, self).__init__(*modules)
class ASPPPooling(nn.Sequential):
def __init__(self, in_channels, out_channels):
super(ASPPPooling, self).__init__(
nn.AdaptiveAvgPool2d(1),
nn.Conv2d(in_channels, out_channels, 1, bias=False),
nn.GroupNorm(32, out_channels),
nn.ReLU(),
)
def forward(self, x):
size = x.shape[-2:]
x = super(ASPPPooling, self).forward(x)
return F.interpolate(x, size=size, mode="bilinear", align_corners=False)
class ASPP(nn.Module):
def __init__(self, in_channels, atrous_rates, out_channels):
super(ASPP, self).__init__()
modules = []
modules.append(
nn.Sequential(
nn.Conv2d(in_channels, out_channels, 1, bias=False),
nn.GroupNorm(32, out_channels),
nn.ReLU(),
)
)
rate1, rate2, rate3 = tuple(atrous_rates)
modules.append(ASPPConv(in_channels, out_channels, rate1))
modules.append(ASPPConv(in_channels, out_channels, rate2))
modules.append(ASPPConv(in_channels, out_channels, rate3))
modules.append(ASPPPooling(in_channels, out_channels))
self.convs = nn.ModuleList(modules)
self.project = nn.Sequential(
nn.Conv2d(5 * out_channels, out_channels, 1, bias=False),
# nn.BatchNorm2d(out_channels),
nn.ReLU(),
# nn.Dropout(0.5)
)
def forward(self, x):
res = []
for conv in self.convs:
res.append(conv(x))
res = torch.cat(res, dim=1)
return self.project(res)
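# Editor's sketch (illustrative, not from the original file): ASPP preserves the
# spatial size; its five parallel branches (1x1 conv, three dilated 3x3 convs,
# global pooling) each emit `out_channels` maps, which are concatenated to
# 5 * out_channels before the final 1x1 projection.
if __name__ == "__main__":
    aspp = ASPP(in_channels=256, atrous_rates=[6, 12, 56], out_channels=256)
    x = torch.rand(1, 256, 28, 28)
    assert aspp(x).shape == (1, 256, 28, 28)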
# copied from
# https://github.com/AlexHex7/Non-local_pytorch/blob/master/lib/non_local_embedded_gaussian.py
# See https://arxiv.org/abs/1711.07971 for details
class _NonLocalBlockND(nn.Module):
def __init__(
self, in_channels, inter_channels=None, dimension=3, sub_sample=True, bn_layer=True
):
super(_NonLocalBlockND, self).__init__()
assert dimension in [1, 2, 3]
self.dimension = dimension
self.sub_sample = sub_sample
self.in_channels = in_channels
self.inter_channels = inter_channels
if self.inter_channels is None:
self.inter_channels = in_channels // 2
if self.inter_channels == 0:
self.inter_channels = 1
if dimension == 3:
conv_nd = nn.Conv3d
max_pool_layer = nn.MaxPool3d(kernel_size=(1, 2, 2))
bn = nn.GroupNorm # (32, hidden_dim) #nn.BatchNorm3d
elif dimension == 2:
conv_nd = nn.Conv2d
max_pool_layer = nn.MaxPool2d(kernel_size=(2, 2))
bn = nn.GroupNorm # (32, hidden_dim)nn.BatchNorm2d
else:
conv_nd = nn.Conv1d
max_pool_layer = nn.MaxPool1d(kernel_size=2)
bn = nn.GroupNorm # (32, hidden_dim)nn.BatchNorm1d
self.g = conv_nd(
in_channels=self.in_channels,
out_channels=self.inter_channels,
kernel_size=1,
stride=1,
padding=0,
)
if bn_layer:
self.W = nn.Sequential(
conv_nd(
in_channels=self.inter_channels,
out_channels=self.in_channels,
kernel_size=1,
stride=1,
padding=0,
),
bn(32, self.in_channels),
)
nn.init.constant_(self.W[1].weight, 0)
nn.init.constant_(self.W[1].bias, 0)
else:
self.W = conv_nd(
in_channels=self.inter_channels,
out_channels=self.in_channels,
kernel_size=1,
stride=1,
padding=0,
)
nn.init.constant_(self.W.weight, 0)
nn.init.constant_(self.W.bias, 0)
self.theta = conv_nd(
in_channels=self.in_channels,
out_channels=self.inter_channels,
kernel_size=1,
stride=1,
padding=0,
)
self.phi = conv_nd(
in_channels=self.in_channels,
out_channels=self.inter_channels,
kernel_size=1,
stride=1,
padding=0,
)
if sub_sample:
self.g = nn.Sequential(self.g, max_pool_layer)
self.phi = nn.Sequential(self.phi, max_pool_layer)
def forward(self, x):
"""
:param x: (b, c, t, h, w)
:return:
"""
batch_size = x.size(0)
g_x = self.g(x).view(batch_size, self.inter_channels, -1)
g_x = g_x.permute(0, 2, 1)
theta_x = self.theta(x).view(batch_size, self.inter_channels, -1)
theta_x = theta_x.permute(0, 2, 1)
phi_x = self.phi(x).view(batch_size, self.inter_channels, -1)
f = torch.matmul(theta_x, phi_x)
f_div_C = F.softmax(f, dim=-1)
y = torch.matmul(f_div_C, g_x)
y = y.permute(0, 2, 1).contiguous()
y = y.view(batch_size, self.inter_channels, *x.size()[2:])
W_y = self.W(y)
z = W_y + x
return z
class NONLocalBlock2D(_NonLocalBlockND):
def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True):
super(NONLocalBlock2D, self).__init__(
in_channels,
inter_channels=inter_channels,
dimension=2,
sub_sample=sub_sample,
bn_layer=bn_layer,
)
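# Editor's sketch (illustrative): the non-local block is residual and
# shape-preserving; sub_sample only shrinks the key/value paths internally.
if __name__ == "__main__":
    block = NONLocalBlock2D(in_channels=256, bn_layer=True)
    x = torch.rand(1, 256, 14, 14)
    assert block(x).shape == x.shape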
# Copyright (c) Facebook, Inc. and its affiliates.
# pyre-unsafe
from detectron2.utils.registry import Registry
ROI_DENSEPOSE_HEAD_REGISTRY = Registry("ROI_DENSEPOSE_HEAD")
# Copyright (c) Facebook, Inc. and its affiliates.
# pyre-unsafe
import numpy as np
from typing import Dict, List, Optional
import fvcore.nn.weight_init as weight_init
import torch
import torch.nn as nn
from torch.nn import functional as F
from detectron2.layers import Conv2d, ShapeSpec, get_norm
from detectron2.modeling import ROI_HEADS_REGISTRY, StandardROIHeads
from detectron2.modeling.poolers import ROIPooler
from detectron2.modeling.roi_heads import select_foreground_proposals
from detectron2.structures import ImageList, Instances
from .. import (
build_densepose_data_filter,
build_densepose_embedder,
build_densepose_head,
build_densepose_losses,
build_densepose_predictor,
densepose_inference,
)
class Decoder(nn.Module):
"""
A semantic segmentation head described in detail in the Panoptic Feature Pyramid Networks paper
(https://arxiv.org/abs/1901.02446). It takes FPN features as input and merges information from
all levels of the FPN into single output.
"""
def __init__(self, cfg, input_shape: Dict[str, ShapeSpec], in_features):
super(Decoder, self).__init__()
# fmt: off
self.in_features = in_features
feature_strides = {k: v.stride for k, v in input_shape.items()}
feature_channels = {k: v.channels for k, v in input_shape.items()}
num_classes = cfg.MODEL.ROI_DENSEPOSE_HEAD.DECODER_NUM_CLASSES
conv_dims = cfg.MODEL.ROI_DENSEPOSE_HEAD.DECODER_CONV_DIMS
self.common_stride = cfg.MODEL.ROI_DENSEPOSE_HEAD.DECODER_COMMON_STRIDE
norm = cfg.MODEL.ROI_DENSEPOSE_HEAD.DECODER_NORM
# fmt: on
self.scale_heads = []
for in_feature in self.in_features:
head_ops = []
head_length = max(
1, int(np.log2(feature_strides[in_feature]) - np.log2(self.common_stride))
)
for k in range(head_length):
conv = Conv2d(
feature_channels[in_feature] if k == 0 else conv_dims,
conv_dims,
kernel_size=3,
stride=1,
padding=1,
bias=not norm,
norm=get_norm(norm, conv_dims),
activation=F.relu,
)
weight_init.c2_msra_fill(conv)
head_ops.append(conv)
if feature_strides[in_feature] != self.common_stride:
head_ops.append(
nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False)
)
self.scale_heads.append(nn.Sequential(*head_ops))
self.add_module(in_feature, self.scale_heads[-1])
self.predictor = Conv2d(conv_dims, num_classes, kernel_size=1, stride=1, padding=0)
weight_init.c2_msra_fill(self.predictor)
def forward(self, features: List[torch.Tensor]):
for i, _ in enumerate(self.in_features):
if i == 0:
x = self.scale_heads[i](features[i])
else:
x = x + self.scale_heads[i](features[i])
x = self.predictor(x)
return x
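# Editor's sketch (hedged; the feature names and channel counts are illustrative
# and the DECODER_* keys take their add_densepose_config defaults): each scale
# head brings its FPN level to the common stride, the per-level results are
# summed, and the 1x1 predictor maps the sum to DECODER_NUM_CLASSES channels.
if __name__ == "__main__":
    from detectron2.config import get_cfg
    from densepose import add_densepose_config

    cfg = get_cfg()
    add_densepose_config(cfg)
    shapes = {
        "p2": ShapeSpec(channels=256, stride=4),
        "p3": ShapeSpec(channels=256, stride=8),
    }
    decoder = Decoder(cfg, shapes, in_features=["p2", "p3"])
    outputs = decoder([torch.rand(1, 256, 64, 64), torch.rand(1, 256, 32, 32)])
    print(outputs.shape)  # e.g. [1, DECODER_NUM_CLASSES, 64, 64]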
@ROI_HEADS_REGISTRY.register()
class DensePoseROIHeads(StandardROIHeads):
"""
A Standard ROIHeads which contains an addition of DensePose head.
"""
def __init__(self, cfg, input_shape):
super().__init__(cfg, input_shape)
self._init_densepose_head(cfg, input_shape)
def _init_densepose_head(self, cfg, input_shape):
# fmt: off
self.densepose_on = cfg.MODEL.DENSEPOSE_ON
if not self.densepose_on:
return
self.densepose_data_filter = build_densepose_data_filter(cfg)
dp_pooler_resolution = cfg.MODEL.ROI_DENSEPOSE_HEAD.POOLER_RESOLUTION
dp_pooler_sampling_ratio = cfg.MODEL.ROI_DENSEPOSE_HEAD.POOLER_SAMPLING_RATIO
dp_pooler_type = cfg.MODEL.ROI_DENSEPOSE_HEAD.POOLER_TYPE
self.use_decoder = cfg.MODEL.ROI_DENSEPOSE_HEAD.DECODER_ON
# fmt: on
if self.use_decoder:
dp_pooler_scales = (1.0 / input_shape[self.in_features[0]].stride,)
else:
dp_pooler_scales = tuple(1.0 / input_shape[k].stride for k in self.in_features)
in_channels = [input_shape[f].channels for f in self.in_features][0]
if self.use_decoder:
self.decoder = Decoder(cfg, input_shape, self.in_features)
self.densepose_pooler = ROIPooler(
output_size=dp_pooler_resolution,
scales=dp_pooler_scales,
sampling_ratio=dp_pooler_sampling_ratio,
pooler_type=dp_pooler_type,
)
self.densepose_head = build_densepose_head(cfg, in_channels)
self.densepose_predictor = build_densepose_predictor(
cfg, self.densepose_head.n_out_channels
)
self.densepose_losses = build_densepose_losses(cfg)
self.embedder = build_densepose_embedder(cfg)
def _forward_densepose(self, features: Dict[str, torch.Tensor], instances: List[Instances]):
"""
Forward logic of the densepose prediction branch.
Args:
features (dict[str, Tensor]): input data as a mapping from feature
map name to tensor. Axis 0 represents the number of images `N` in
the input data; axes 1-3 are channels, height, and width, which may
vary between feature maps (e.g., if a feature pyramid is used).
instances (list[Instances]): length `N` list of `Instances`. The i-th
`Instances` contains instances for the i-th input image.
In training, they can be the proposals.
In inference, they can be the predicted boxes.
Returns:
In training, a dict of losses.
In inference, update `instances` with new fields "densepose" and return it.
"""
if not self.densepose_on:
return {} if self.training else instances
features_list = [features[f] for f in self.in_features]
if self.training:
proposals, _ = select_foreground_proposals(instances, self.num_classes)
features_list, proposals = self.densepose_data_filter(features_list, proposals)
if len(proposals) > 0:
proposal_boxes = [x.proposal_boxes for x in proposals]
if self.use_decoder:
features_list = [self.decoder(features_list)]
features_dp = self.densepose_pooler(features_list, proposal_boxes)
densepose_head_outputs = self.densepose_head(features_dp)
densepose_predictor_outputs = self.densepose_predictor(densepose_head_outputs)
densepose_loss_dict = self.densepose_losses(
proposals, densepose_predictor_outputs, embedder=self.embedder
)
return densepose_loss_dict
else:
pred_boxes = [x.pred_boxes for x in instances]
if self.use_decoder:
features_list = [self.decoder(features_list)]
features_dp = self.densepose_pooler(features_list, pred_boxes)
if len(features_dp) > 0:
densepose_head_outputs = self.densepose_head(features_dp)
densepose_predictor_outputs = self.densepose_predictor(densepose_head_outputs)
else:
densepose_predictor_outputs = None
densepose_inference(densepose_predictor_outputs, instances)
return instances
def forward(
self,
images: ImageList,
features: Dict[str, torch.Tensor],
proposals: List[Instances],
targets: Optional[List[Instances]] = None,
):
instances, losses = super().forward(images, features, proposals, targets)
del targets, images
if self.training:
losses.update(self._forward_densepose(features, instances))
return instances, losses
def forward_with_given_boxes(
self, features: Dict[str, torch.Tensor], instances: List[Instances]
):
"""
Use the given boxes in `instances` to produce other (non-box) per-ROI outputs.
This is useful for downstream tasks where boxes are already known, but
other per-ROI attributes (outputs of other heads) still need to be obtained.
Test-time augmentation also uses this.
Args:
features: same as in `forward()`
instances (list[Instances]): instances to predict other outputs. Expect the keys
"pred_boxes" and "pred_classes" to exist.
Returns:
instances (list[Instances]):
the same `Instances` objects, with extra
fields such as `pred_masks` or `pred_keypoints`.
"""
instances = super().forward_with_given_boxes(features, instances)
instances = self._forward_densepose(features, instances)
return instances
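# Editor's note (a hedged sketch, not from this commit): this ROI head is picked
# up through the standard detectron2 config mechanism, e.g.
#
#   cfg.MODEL.ROI_HEADS.NAME = "DensePoseROIHeads"
#   cfg.MODEL.DENSEPOSE_ON = True
#   model = detectron2.modeling.build_model(cfg)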
# Copyright (c) Facebook, Inc. and its affiliates.
# pyre-unsafe
import torch
from torch import nn
from torch.nn import functional as F
from detectron2.config import CfgNode
from detectron2.layers import Conv2d
from ..utils import initialize_module_params
from .registry import ROI_DENSEPOSE_HEAD_REGISTRY
@ROI_DENSEPOSE_HEAD_REGISTRY.register()
class DensePoseV1ConvXHead(nn.Module):
"""
Fully convolutional DensePose head.
"""
def __init__(self, cfg: CfgNode, input_channels: int):
"""
Initialize DensePose fully convolutional head
Args:
cfg (CfgNode): configuration options
input_channels (int): number of input channels
"""
super(DensePoseV1ConvXHead, self).__init__()
# fmt: off
hidden_dim = cfg.MODEL.ROI_DENSEPOSE_HEAD.CONV_HEAD_DIM
kernel_size = cfg.MODEL.ROI_DENSEPOSE_HEAD.CONV_HEAD_KERNEL
self.n_stacked_convs = cfg.MODEL.ROI_DENSEPOSE_HEAD.NUM_STACKED_CONVS
# fmt: on
pad_size = kernel_size // 2
n_channels = input_channels
for i in range(self.n_stacked_convs):
layer = Conv2d(n_channels, hidden_dim, kernel_size, stride=1, padding=pad_size)
layer_name = self._get_layer_name(i)
self.add_module(layer_name, layer)
n_channels = hidden_dim
self.n_out_channels = n_channels
initialize_module_params(self)
def forward(self, features: torch.Tensor):
"""
Apply DensePose fully convolutional head to the input features
Args:
features (tensor): input features
Result:
A tensor of DensePose head outputs
"""
x = features
output = x
for i in range(self.n_stacked_convs):
layer_name = self._get_layer_name(i)
x = getattr(self, layer_name)(x)
x = F.relu(x)
output = x
return output
def _get_layer_name(self, i: int):
layer_name = "body_conv_fcn{}".format(i + 1)
return layer_name
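# Editor's sketch (illustrative input channel count; config keys use their
# add_densepose_config defaults, e.g. CONV_HEAD_DIM=512): the head keeps the
# spatial size and maps the pooled ROI features to n_out_channels.
if __name__ == "__main__":
    from detectron2.config import get_cfg
    from densepose import add_densepose_config

    cfg = get_cfg()
    add_densepose_config(cfg)
    head = DensePoseV1ConvXHead(cfg, input_channels=256)
    features = torch.rand(2, 256, 28, 28)
    assert head(features).shape == (2, head.n_out_channels, 28, 28)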
# Copyright (c) Facebook, Inc. and its affiliates.
# pyre-unsafe
import copy
import numpy as np
import torch
from fvcore.transforms import HFlipTransform, TransformList
from torch.nn import functional as F
from detectron2.data.transforms import RandomRotation, RotationTransform, apply_transform_gens
from detectron2.modeling.postprocessing import detector_postprocess
from detectron2.modeling.test_time_augmentation import DatasetMapperTTA, GeneralizedRCNNWithTTA
from ..converters import HFlipConverter
class DensePoseDatasetMapperTTA(DatasetMapperTTA):
def __init__(self, cfg):
super().__init__(cfg=cfg)
self.angles = cfg.TEST.AUG.ROTATION_ANGLES
def __call__(self, dataset_dict):
ret = super().__call__(dataset_dict=dataset_dict)
numpy_image = dataset_dict["image"].permute(1, 2, 0).numpy()
for angle in self.angles:
rotate = RandomRotation(angle=angle, expand=True)
new_numpy_image, tfms = apply_transform_gens([rotate], np.copy(numpy_image))
torch_image = torch.from_numpy(np.ascontiguousarray(new_numpy_image.transpose(2, 0, 1)))
dic = copy.deepcopy(dataset_dict)
# In DatasetMapperTTA, there is a pre_tfm transform (resize or no-op) that is
# added at the beginning of each TransformList. That's '.transforms[0]'.
dic["transforms"] = TransformList(
[ret[-1]["transforms"].transforms[0]] + tfms.transforms
)
dic["image"] = torch_image
ret.append(dic)
return ret
class DensePoseGeneralizedRCNNWithTTA(GeneralizedRCNNWithTTA):
def __init__(self, cfg, model, transform_data, tta_mapper=None, batch_size=1):
"""
Args:
cfg (CfgNode):
model (GeneralizedRCNN): a GeneralizedRCNN to apply TTA on.
transform_data (DensePoseTransformData): contains symmetry label
transforms used for horizontal flip
tta_mapper (callable): takes a dataset dict and returns a list of
augmented versions of the dataset dict. Defaults to
`DatasetMapperTTA(cfg)`.
batch_size (int): batch the augmented images into this batch size for inference.
"""
self._transform_data = transform_data.to(model.device)
super().__init__(cfg=cfg, model=model, tta_mapper=tta_mapper, batch_size=batch_size)
# the implementation follows closely the one from detectron2/modeling
def _inference_one_image(self, input):
"""
Args:
input (dict): one dataset dict with "image" field being a CHW tensor
Returns:
dict: one output dict
"""
orig_shape = (input["height"], input["width"])
# For some reason, resize with uint8 slightly increases box AP but decreases densepose AP
input["image"] = input["image"].to(torch.uint8)
augmented_inputs, tfms = self._get_augmented_inputs(input)
# Detect boxes from all augmented versions
with self._turn_off_roi_heads(["mask_on", "keypoint_on", "densepose_on"]):
# temporarily disable roi heads
all_boxes, all_scores, all_classes = self._get_augmented_boxes(augmented_inputs, tfms)
merged_instances = self._merge_detections(all_boxes, all_scores, all_classes, orig_shape)
if self.cfg.MODEL.MASK_ON or self.cfg.MODEL.DENSEPOSE_ON:
# Use the detected boxes to obtain new fields
augmented_instances = self._rescale_detected_boxes(
augmented_inputs, merged_instances, tfms
)
# run forward on the detected boxes
outputs = self._batch_inference(augmented_inputs, augmented_instances)
# Delete now useless variables to avoid being out of memory
del augmented_inputs, augmented_instances
# average the predictions
if self.cfg.MODEL.MASK_ON:
merged_instances.pred_masks = self._reduce_pred_masks(outputs, tfms)
if self.cfg.MODEL.DENSEPOSE_ON:
merged_instances.pred_densepose = self._reduce_pred_densepose(outputs, tfms)
# postprocess
merged_instances = detector_postprocess(merged_instances, *orig_shape)
return {"instances": merged_instances}
else:
return {"instances": merged_instances}
def _get_augmented_boxes(self, augmented_inputs, tfms):
# Heavily based on detectron2/modeling/test_time_augmentation.py
# Only difference is that RotationTransform is excluded from bbox computation
# 1: forward with all augmented images
outputs = self._batch_inference(augmented_inputs)
# 2: union the results
all_boxes = []
all_scores = []
all_classes = []
for output, tfm in zip(outputs, tfms):
# Need to inverse the transforms on boxes, to obtain results on original image
if not any(isinstance(t, RotationTransform) for t in tfm.transforms):
# Some transforms can't compute bbox correctly
pred_boxes = output.pred_boxes.tensor
original_pred_boxes = tfm.inverse().apply_box(pred_boxes.cpu().numpy())
all_boxes.append(torch.from_numpy(original_pred_boxes).to(pred_boxes.device))
all_scores.extend(output.scores)
all_classes.extend(output.pred_classes)
all_boxes = torch.cat(all_boxes, dim=0)
return all_boxes, all_scores, all_classes
def _reduce_pred_densepose(self, outputs, tfms):
# Should apply inverse transforms on densepose preds.
# We assume only rotation, resize & flip are used. pred_masks is a scale-invariant
# representation, so we handle the other ones specially
for idx, (output, tfm) in enumerate(zip(outputs, tfms)):
for t in tfm.transforms:
for attr in ["coarse_segm", "fine_segm", "u", "v"]:
setattr(
output.pred_densepose,
attr,
_inverse_rotation(
getattr(output.pred_densepose, attr), output.pred_boxes.tensor, t
),
)
if any(isinstance(t, HFlipTransform) for t in tfm.transforms):
output.pred_densepose = HFlipConverter.convert(
output.pred_densepose, self._transform_data
)
self._incremental_avg_dp(outputs[0].pred_densepose, output.pred_densepose, idx)
return outputs[0].pred_densepose
# incrementally computed average: u_(n + 1) = u_n + (x_(n+1) - u_n) / (n + 1).
def _incremental_avg_dp(self, avg, new_el, idx):
for attr in ["coarse_segm", "fine_segm", "u", "v"]:
setattr(avg, attr, (getattr(avg, attr) * idx + getattr(new_el, attr)) / (idx + 1))
if idx:
# Deletion of the > 0 index intermediary values to prevent GPU OOM
setattr(new_el, attr, None)
return avg
def _inverse_rotation(densepose_attrs, boxes, transform):
# resample outputs to image size and rotate back the densepose preds
# on the rotated images to the space of the original image
if len(boxes) == 0 or not isinstance(transform, RotationTransform):
return densepose_attrs
boxes = boxes.int().cpu().numpy()
wh_boxes = boxes[:, 2:] - boxes[:, :2] # bboxes in the rotated space
inv_boxes = rotate_box_inverse(transform, boxes).astype(int) # bboxes in original image
wh_diff = (inv_boxes[:, 2:] - inv_boxes[:, :2] - wh_boxes) // 2 # diff between new/old bboxes
rotation_matrix = torch.tensor([transform.rm_image]).to(device=densepose_attrs.device).float()
rotation_matrix[:, :, -1] = 0
# To apply grid_sample for rotation, we need to have enough space to fit the original and
# rotated bboxes. l_bds and r_bds are the left/right bounds that will be used to
# crop the difference once the rotation is done
l_bds = np.maximum(0, -wh_diff)
for i in range(len(densepose_attrs)):
if min(wh_boxes[i]) <= 0:
continue
densepose_attr = densepose_attrs[[i]].clone()
# 1. Interpolate densepose attribute to size of the rotated bbox
densepose_attr = F.interpolate(densepose_attr, wh_boxes[i].tolist()[::-1], mode="bilinear")
# 2. Pad the interpolated attribute so it has room for the original + rotated bbox
densepose_attr = F.pad(densepose_attr, tuple(np.repeat(np.maximum(0, wh_diff[i]), 2)))
# 3. Compute rotation grid and transform
grid = F.affine_grid(rotation_matrix, size=densepose_attr.shape)
densepose_attr = F.grid_sample(densepose_attr, grid)
# 4. Compute right bounds and crop the densepose_attr to the size of the original bbox
r_bds = densepose_attr.shape[2:][::-1] - l_bds[i]
densepose_attr = densepose_attr[:, :, l_bds[i][1] : r_bds[1], l_bds[i][0] : r_bds[0]]
if min(densepose_attr.shape) > 0:
# Interpolate back to the original size of the densepose attribute
densepose_attr = F.interpolate(
densepose_attr, densepose_attrs.shape[-2:], mode="bilinear"
)
# Adding a very small probability to the background class to fill padded zones
densepose_attr[:, 0] += 1e-10
densepose_attrs[i] = densepose_attr
return densepose_attrs
def rotate_box_inverse(rot_tfm, rotated_box):
"""
rotated_box is a N * 4 array of [x0, y0, x1, y1] boxes
When a bbox is rotated, it gets bigger, because the transform returns the
axis-aligned box that encloses the tilted one. So when a bbox is rotated and then
inverse-rotated, it is much bigger than the original. This function inverts the
rotation on the box and also resizes it back to its original size.
"""
# 1. Compute the inverse rotation of the rotated bboxes (still bigger than the originals)
invrot_box = rot_tfm.inverse().apply_box(rotated_box)
h, w = rotated_box[:, 3] - rotated_box[:, 1], rotated_box[:, 2] - rotated_box[:, 0]
ih, iw = invrot_box[:, 3] - invrot_box[:, 1], invrot_box[:, 2] - invrot_box[:, 0]
assert 2 * rot_tfm.abs_sin**2 != 1, "45 degrees angle can't be inverted"
# 2. Inverse the corresponding computation in the rotation transform
# to get the original height/width of the rotated boxes
orig_h = (h * rot_tfm.abs_cos - w * rot_tfm.abs_sin) / (1 - 2 * rot_tfm.abs_sin**2)
orig_w = (w * rot_tfm.abs_cos - h * rot_tfm.abs_sin) / (1 - 2 * rot_tfm.abs_sin**2)
# 3. Resize the inverse-rotated bboxes to their original size
invrot_box[:, 0] += (iw - orig_w) / 2
invrot_box[:, 1] += (ih - orig_h) / 2
invrot_box[:, 2] -= (iw - orig_w) / 2
invrot_box[:, 3] -= (ih - orig_h) / 2
return invrot_box
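# Editor's note on the algebra in step 2 above (a sketch of the derivation, not
# from the original source): RotationTransform encloses a box of size
# (orig_w, orig_h) in an axis-aligned box of size
#     w = orig_w * abs_cos + orig_h * abs_sin
#     h = orig_h * abs_cos + orig_w * abs_sin
# Solving this linear system for (orig_w, orig_h), e.g.
#     w * abs_cos - h * abs_sin = orig_w * (abs_cos**2 - abs_sin**2)
# and using abs_cos**2 - abs_sin**2 = 1 - 2 * abs_sin**2 yields the expressions
# for orig_w and orig_h; at 45 degrees the system is singular, hence the assert.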
# Copyright (c) Facebook, Inc. and its affiliates.
# pyre-unsafe
from torch import nn
def initialize_module_params(module: nn.Module) -> None:
for name, param in module.named_parameters():
if "bias" in name:
nn.init.constant_(param, 0)
elif "weight" in name:
nn.init.kaiming_normal_(param, mode="fan_out", nonlinearity="relu")
# Copyright (c) Facebook, Inc. and its affiliates.
# pyre-unsafe
from .chart import DensePoseChartPredictorOutput
from .chart_confidence import decorate_predictor_output_class_with_confidences
from .cse_confidence import decorate_cse_predictor_output_class_with_confidences
from .chart_result import (
DensePoseChartResult,
DensePoseChartResultWithConfidences,
quantize_densepose_chart_result,
compress_quantized_densepose_chart_result,
decompress_compressed_densepose_chart_result,
)
from .cse import DensePoseEmbeddingPredictorOutput
from .data_relative import DensePoseDataRelative
from .list import DensePoseList
from .mesh import Mesh, create_mesh
from .transform_data import DensePoseTransformData, normalized_coords_transform
# Copyright (c) Facebook, Inc. and its affiliates.
# pyre-unsafe
from dataclasses import dataclass
from typing import Union
import torch
@dataclass
class DensePoseChartPredictorOutput:
"""
Predictor output that contains segmentation and inner coordinates predictions for predefined
body parts:
* coarse segmentation, a tensor of shape [N, K, Hout, Wout]
* fine segmentation, a tensor of shape [N, C, Hout, Wout]
* U coordinates, a tensor of shape [N, C, Hout, Wout]
* V coordinates, a tensor of shape [N, C, Hout, Wout]
where
- N is the number of instances
- K is the number of coarse segmentation channels (
2 = foreground / background,
15 = one of 14 body parts / background)
- C is the number of fine segmentation channels (
24 fine body parts / background)
- Hout and Wout are height and width of predictions
"""
coarse_segm: torch.Tensor
fine_segm: torch.Tensor
u: torch.Tensor
v: torch.Tensor
def __len__(self):
"""
Number of instances (N) in the output
"""
return self.coarse_segm.size(0)
def __getitem__(
self, item: Union[int, slice, torch.BoolTensor]
) -> "DensePoseChartPredictorOutput":
"""
Get outputs for the selected instance(s)
Args:
item (int or slice or tensor): selected items
"""
if isinstance(item, int):
return DensePoseChartPredictorOutput(
coarse_segm=self.coarse_segm[item].unsqueeze(0),
fine_segm=self.fine_segm[item].unsqueeze(0),
u=self.u[item].unsqueeze(0),
v=self.v[item].unsqueeze(0),
)
else:
return DensePoseChartPredictorOutput(
coarse_segm=self.coarse_segm[item],
fine_segm=self.fine_segm[item],
u=self.u[item],
v=self.v[item],
)
def to(self, device: torch.device):
"""
Transfers all tensors to the given device
"""
coarse_segm = self.coarse_segm.to(device)
fine_segm = self.fine_segm.to(device)
u = self.u.to(device)
v = self.v.to(device)
return DensePoseChartPredictorOutput(coarse_segm=coarse_segm, fine_segm=fine_segm, u=u, v=v)
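# Editor's sketch (illustrative shapes): indexing keeps the leading instance
# dimension, so an int yields a batch of size one and a bool mask filters rows.
if __name__ == "__main__":
    outputs = DensePoseChartPredictorOutput(
        coarse_segm=torch.rand(4, 2, 56, 56),
        fine_segm=torch.rand(4, 25, 56, 56),
        u=torch.rand(4, 25, 56, 56),
        v=torch.rand(4, 25, 56, 56),
    )
    mask = torch.tensor([True, False, True, False])
    assert len(outputs[0]) == 1 and len(outputs[mask]) == 2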
# Copyright (c) Facebook, Inc. and its affiliates.
# pyre-unsafe
from dataclasses import make_dataclass
from functools import lru_cache
from typing import Any, Optional
import torch
@lru_cache(maxsize=None)
def decorate_predictor_output_class_with_confidences(BasePredictorOutput: type) -> type:
"""
Create a new output class from an existing one by adding new attributes
related to confidence estimation:
- sigma_1 (tensor)
- sigma_2 (tensor)
- kappa_u (tensor)
- kappa_v (tensor)
- fine_segm_confidence (tensor)
- coarse_segm_confidence (tensor)
Details on confidence estimation parameters can be found in:
N. Neverova, D. Novotny, A. Vedaldi "Correlated Uncertainty for Learning
Dense Correspondences from Noisy Labels", p. 918--926, in Proc. NIPS 2019
A. Sanakoyeu et al., Transferring Dense Pose to Proximal Animal Classes, CVPR 2020
The new class inherits from the provided `BasePredictorOutput` class;
its name is the name of the provided class with the
"WithConfidences" suffix appended.
Args:
BasePredictorOutput (type): output type to which confidence data
is to be added, assumed to be a dataclass
Return:
New dataclass derived from the provided one that has attributes
for confidence estimation
"""
PredictorOutput = make_dataclass(
BasePredictorOutput.__name__ + "WithConfidences",
fields=[
("sigma_1", Optional[torch.Tensor], None),
("sigma_2", Optional[torch.Tensor], None),
("kappa_u", Optional[torch.Tensor], None),
("kappa_v", Optional[torch.Tensor], None),
("fine_segm_confidence", Optional[torch.Tensor], None),
("coarse_segm_confidence", Optional[torch.Tensor], None),
],
bases=(BasePredictorOutput,),
)
# add possibility to index PredictorOutput
def slice_if_not_none(data, item):
if data is None:
return None
if isinstance(item, int):
return data[item].unsqueeze(0)
return data[item]
def PredictorOutput_getitem(self, item):
PredictorOutput = type(self)
base_predictor_output_sliced = super(PredictorOutput, self).__getitem__(item)
return PredictorOutput(
**base_predictor_output_sliced.__dict__,
coarse_segm_confidence=slice_if_not_none(self.coarse_segm_confidence, item),
fine_segm_confidence=slice_if_not_none(self.fine_segm_confidence, item),
sigma_1=slice_if_not_none(self.sigma_1, item),
sigma_2=slice_if_not_none(self.sigma_2, item),
kappa_u=slice_if_not_none(self.kappa_u, item),
kappa_v=slice_if_not_none(self.kappa_v, item),
)
PredictorOutput.__getitem__ = PredictorOutput_getitem
def PredictorOutput_to(self, device: torch.device):
"""
Transfers all tensors to the given device
"""
PredictorOutput = type(self)
base_predictor_output_to = super(PredictorOutput, self).to(device) # pyre-ignore[16]
def to_device_if_tensor(var: Any):
if isinstance(var, torch.Tensor):
return var.to(device)
return var
return PredictorOutput(
**base_predictor_output_to.__dict__,
sigma_1=to_device_if_tensor(self.sigma_1),
sigma_2=to_device_if_tensor(self.sigma_2),
kappa_u=to_device_if_tensor(self.kappa_u),
kappa_v=to_device_if_tensor(self.kappa_v),
fine_segm_confidence=to_device_if_tensor(self.fine_segm_confidence),
coarse_segm_confidence=to_device_if_tensor(self.coarse_segm_confidence),
)
PredictorOutput.to = PredictorOutput_to
return PredictorOutput
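# Editor's sketch (assumes DensePoseChartPredictorOutput is importable from
# densepose.structures.chart): the decorated class keeps the base fields, adds
# optional confidence tensors, and is memoized by lru_cache.
if __name__ == "__main__":
    from densepose.structures.chart import DensePoseChartPredictorOutput

    Decorated = decorate_predictor_output_class_with_confidences(
        DensePoseChartPredictorOutput
    )
    assert Decorated.__name__ == "DensePoseChartPredictorOutputWithConfidences"
    assert Decorated is decorate_predictor_output_class_with_confidences(
        DensePoseChartPredictorOutput
    )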
# Copyright (c) Facebook, Inc. and its affiliates.
# pyre-unsafe
from dataclasses import dataclass
from typing import Any, Optional, Tuple
import torch
@dataclass
class DensePoseChartResult:
"""
DensePose results for chart-based methods represented by labels and inner
coordinates (U, V) of individual charts. Each chart is a 2D manifold
that has an associated label and is parameterized by two coordinates U and V.
Both U and V take values in [0, 1].
Thus the results are represented by two tensors:
- labels (tensor [H, W] of long): contains estimated label for each pixel of
the detection bounding box of size (H, W)
- uv (tensor [2, H, W] of float): contains estimated U and V coordinates
for each pixel of the detection bounding box of size (H, W)
"""
labels: torch.Tensor
uv: torch.Tensor
def to(self, device: torch.device):
"""
Transfers all tensors to the given device
"""
labels = self.labels.to(device)
uv = self.uv.to(device)
return DensePoseChartResult(labels=labels, uv=uv)
@dataclass
class DensePoseChartResultWithConfidences:
"""
We add confidence values to DensePoseChartResult
Thus the results are represented by two tensors:
- labels (tensor [H, W] of long): contains estimated label for each pixel of
the detection bounding box of size (H, W)
- uv (tensor [2, H, W] of float): contains estimated U and V coordinates
for each pixel of the detection bounding box of size (H, W)
Plus one [H, W] tensor of float for each confidence type
"""
labels: torch.Tensor
uv: torch.Tensor
sigma_1: Optional[torch.Tensor] = None
sigma_2: Optional[torch.Tensor] = None
kappa_u: Optional[torch.Tensor] = None
kappa_v: Optional[torch.Tensor] = None
fine_segm_confidence: Optional[torch.Tensor] = None
coarse_segm_confidence: Optional[torch.Tensor] = None
def to(self, device: torch.device):
"""
Transfers all tensors to the given device, except if their value is None
"""
def to_device_if_tensor(var: Any):
if isinstance(var, torch.Tensor):
return var.to(device)
return var
return DensePoseChartResultWithConfidences(
labels=self.labels.to(device),
uv=self.uv.to(device),
sigma_1=to_device_if_tensor(self.sigma_1),
sigma_2=to_device_if_tensor(self.sigma_2),
kappa_u=to_device_if_tensor(self.kappa_u),
kappa_v=to_device_if_tensor(self.kappa_v),
fine_segm_confidence=to_device_if_tensor(self.fine_segm_confidence),
coarse_segm_confidence=to_device_if_tensor(self.coarse_segm_confidence),
)
@dataclass
class DensePoseChartResultQuantized:
"""
DensePose results for chart-based methods represented by labels and quantized
inner coordinates (U, V) of individual charts. Each chart is a 2D manifold
that has an associated label and is parameterized by two coordinates U and V.
Both U and V take values in [0, 1].
Quantized coordinates Uq and Vq have uint8 values which are obtained as:
Uq = U * 255 (hence 0 <= Uq <= 255)
Vq = V * 255 (hence 0 <= Vq <= 255)
Thus the results are represented by one tensor:
- labels_uv_uint8 (tensor [3, H, W] of uint8): contains estimated label
and quantized coordinates Uq and Vq for each pixel of the detection
bounding box of size (H, W)
"""
labels_uv_uint8: torch.Tensor
def to(self, device: torch.device):
"""
Transfers all tensors to the given device
"""
labels_uv_uint8 = self.labels_uv_uint8.to(device)
return DensePoseChartResultQuantized(labels_uv_uint8=labels_uv_uint8)
@dataclass
class DensePoseChartResultCompressed:
"""
DensePose results for chart-based methods represented by a PNG-encoded string.
The tensor of quantized DensePose results of size [3, H, W] is considered
as an image with 3 color channels. PNG compression is applied and the result
is stored as a Base64-encoded string. The following attributes are defined:
- shape_chw (tuple of 3 int): contains shape of the result tensor
(number of channels, height, width)
- labels_uv_str (str): contains Base64-encoded results tensor of size
[3, H, W] compressed with PNG compression methods
"""
shape_chw: Tuple[int, int, int]
labels_uv_str: str
def quantize_densepose_chart_result(result: DensePoseChartResult) -> DensePoseChartResultQuantized:
"""
Applies quantization to DensePose chart-based result.
Args:
result (DensePoseChartResult): DensePose chart-based result
Return:
Quantized DensePose chart-based result (DensePoseChartResultQuantized)
"""
h, w = result.labels.shape
labels_uv_uint8 = torch.zeros([3, h, w], dtype=torch.uint8, device=result.labels.device)
labels_uv_uint8[0] = result.labels
labels_uv_uint8[1:] = (result.uv * 255).clamp(0, 255).byte()
return DensePoseChartResultQuantized(labels_uv_uint8=labels_uv_uint8)
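# Editor's sketch (illustrative sizes): labels and 8-bit U/V are packed into one
# uint8 tensor; uv can be recovered up to 1/255 as labels_uv_uint8[1:].float() / 255.
if __name__ == "__main__":
    result = DensePoseChartResult(
        labels=torch.randint(0, 25, (32, 32)), uv=torch.rand(2, 32, 32)
    )
    quantized = quantize_densepose_chart_result(result)
    assert quantized.labels_uv_uint8.shape == (3, 32, 32)
    assert quantized.labels_uv_uint8.dtype == torch.uint8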
def compress_quantized_densepose_chart_result(
result: DensePoseChartResultQuantized,
) -> DensePoseChartResultCompressed:
"""
Compresses quantized DensePose chart-based result
Args:
result (DensePoseChartResultQuantized): quantized DensePose chart-based result
Return:
Compressed DensePose chart-based result (DensePoseChartResultCompressed)
"""
import base64
import numpy as np
from io import BytesIO
from PIL import Image
labels_uv_uint8_np_chw = result.labels_uv_uint8.cpu().numpy()
labels_uv_uint8_np_hwc = np.moveaxis(labels_uv_uint8_np_chw, 0, -1)
im = Image.fromarray(labels_uv_uint8_np_hwc)
fstream = BytesIO()
im.save(fstream, format="png", optimize=True)
labels_uv_str = base64.encodebytes(fstream.getvalue()).decode()
shape_chw = labels_uv_uint8_np_chw.shape
return DensePoseChartResultCompressed(labels_uv_str=labels_uv_str, shape_chw=shape_chw)
def decompress_compressed_densepose_chart_result(
result: DensePoseChartResultCompressed,
) -> DensePoseChartResultQuantized:
"""
Decompresses DensePose chart-based result encoded into a base64 string
Args:
result (DensePoseChartResultCompressed): compressed DensePose chart result
Return:
Quantized DensePose chart-based result (DensePoseChartResultQuantized)
"""
import base64
import numpy as np
from io import BytesIO
from PIL import Image
fstream = BytesIO(base64.decodebytes(result.labels_uv_str.encode()))
im = Image.open(fstream)
labels_uv_uint8_np_chw = np.moveaxis(np.array(im, dtype=np.uint8), -1, 0)
return DensePoseChartResultQuantized(
labels_uv_uint8=torch.from_numpy(labels_uv_uint8_np_chw.reshape(result.shape_chw))
)
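# Editor's sketch (illustrative sizes): the compress/decompress pair round-trips
# losslessly, since PNG compression is lossless.
if __name__ == "__main__":
    quantized = DensePoseChartResultQuantized(
        labels_uv_uint8=torch.randint(0, 256, (3, 24, 16), dtype=torch.uint8)
    )
    restored = decompress_compressed_densepose_chart_result(
        compress_quantized_densepose_chart_result(quantized)
    )
    assert torch.equal(quantized.labels_uv_uint8, restored.labels_uv_uint8)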
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
# pyre-unsafe
from dataclasses import dataclass
from typing import Union
import torch
@dataclass
class DensePoseEmbeddingPredictorOutput:
"""
Predictor output that contains embedding and coarse segmentation data:
* embedding: float tensor of size [N, D, H, W], contains estimated embeddings
* coarse_segm: float tensor of size [N, K, H, W]
Here D = MODEL.ROI_DENSEPOSE_HEAD.CSE.EMBED_SIZE
K = MODEL.ROI_DENSEPOSE_HEAD.NUM_COARSE_SEGM_CHANNELS
"""
embedding: torch.Tensor
coarse_segm: torch.Tensor
def __len__(self):
"""
Number of instances (N) in the output
"""
return self.coarse_segm.size(0)
def __getitem__(
self, item: Union[int, slice, torch.BoolTensor]
) -> "DensePoseEmbeddingPredictorOutput":
"""
Get outputs for the selected instance(s)
Args:
item (int or slice or tensor): selected items
"""
if isinstance(item, int):
return DensePoseEmbeddingPredictorOutput(
coarse_segm=self.coarse_segm[item].unsqueeze(0),
embedding=self.embedding[item].unsqueeze(0),
)
else:
return DensePoseEmbeddingPredictorOutput(
coarse_segm=self.coarse_segm[item], embedding=self.embedding[item]
)
def to(self, device: torch.device):
"""
Transfers all tensors to the given device
"""
coarse_segm = self.coarse_segm.to(device)
embedding = self.embedding.to(device)
return DensePoseEmbeddingPredictorOutput(coarse_segm=coarse_segm, embedding=embedding)
# Copyright (c) Facebook, Inc. and its affiliates.
# pyre-unsafe
from dataclasses import make_dataclass
from functools import lru_cache
from typing import Any, Optional
import torch
@lru_cache(maxsize=None)
def decorate_cse_predictor_output_class_with_confidences(BasePredictorOutput: type) -> type:
"""
Create a new output class from an existing one by adding new attributes
related to confidence estimation:
- coarse_segm_confidence (tensor)
Details on confidence estimation parameters can be found in:
N. Neverova, D. Novotny, A. Vedaldi "Correlated Uncertainty for Learning
Dense Correspondences from Noisy Labels", p. 918--926, in Proc. NIPS 2019
A. Sanakoyeu et al., Transferring Dense Pose to Proximal Animal Classes, CVPR 2020
The new class inherits from the provided `BasePredictorOutput` class;
its name is the name of the provided class with the
"WithConfidences" suffix appended.
Args:
BasePredictorOutput (type): output type to which confidence data
is to be added, assumed to be a dataclass
Return:
New dataclass derived from the provided one that has attributes
for confidence estimation
"""
PredictorOutput = make_dataclass(
BasePredictorOutput.__name__ + "WithConfidences",
fields=[
("coarse_segm_confidence", Optional[torch.Tensor], None),
],
bases=(BasePredictorOutput,),
)
# add possibility to index PredictorOutput
def slice_if_not_none(data, item):
if data is None:
return None
if isinstance(item, int):
return data[item].unsqueeze(0)
return data[item]
def PredictorOutput_getitem(self, item):
PredictorOutput = type(self)
base_predictor_output_sliced = super(PredictorOutput, self).__getitem__(item)
return PredictorOutput(
**base_predictor_output_sliced.__dict__,
coarse_segm_confidence=slice_if_not_none(self.coarse_segm_confidence, item),
)
PredictorOutput.__getitem__ = PredictorOutput_getitem
def PredictorOutput_to(self, device: torch.device):
"""
Transfers all tensors to the given device
"""
PredictorOutput = type(self)
base_predictor_output_to = super(PredictorOutput, self).to(device) # pyre-ignore[16]
def to_device_if_tensor(var: Any):
if isinstance(var, torch.Tensor):
return var.to(device)
return var
return PredictorOutput(
**base_predictor_output_to.__dict__,
coarse_segm_confidence=to_device_if_tensor(self.coarse_segm_confidence),
)
PredictorOutput.to = PredictorOutput_to
return PredictorOutput
# Copyright (c) Facebook, Inc. and its affiliates.
# pyre-unsafe
import numpy as np
import torch
from torch.nn import functional as F
from densepose.data.meshes.catalog import MeshCatalog
from densepose.structures.mesh import load_mesh_symmetry
from densepose.structures.transform_data import DensePoseTransformData
class DensePoseDataRelative:
"""
Dense pose relative annotations that can be applied to any bounding box:
x - normalized X coordinates [0, 255] of annotated points
y - normalized Y coordinates [0, 255] of annotated points
i - body part labels 0,...,24 for annotated points
u - body part U coordinates [0, 1] for annotated points
v - body part V coordinates [0, 1] for annotated points
segm - 256x256 segmentation mask with values 0,...,14
To obtain absolute x and y data wrt some bounding box one needs to first
divide the data by 256, multiply by the respective bounding box size
and add bounding box offset:
x_img = x0 + x_norm * w / 256.0
y_img = y0 + y_norm * h / 256.0
Segmentation masks are typically sampled to get image-based masks.
"""
# Key for normalized X coordinates in annotation dict
X_KEY = "dp_x"
# Key for normalized Y coordinates in annotation dict
Y_KEY = "dp_y"
# Key for U part coordinates in annotation dict (used in chart-based annotations)
U_KEY = "dp_U"
# Key for V part coordinates in annotation dict (used in chart-based annotations)
V_KEY = "dp_V"
# Key for I point labels in annotation dict (used in chart-based annotations)
I_KEY = "dp_I"
# Key for segmentation mask in annotation dict
S_KEY = "dp_masks"
# Key for vertex ids (used in continuous surface embeddings annotations)
VERTEX_IDS_KEY = "dp_vertex"
# Key for mesh id (used in continuous surface embeddings annotations)
MESH_NAME_KEY = "ref_model"
# Number of body parts in segmentation masks
N_BODY_PARTS = 14
# Number of parts in point labels
N_PART_LABELS = 24
MASK_SIZE = 256
def __init__(self, annotation, cleanup=False):
self.x = torch.as_tensor(annotation[DensePoseDataRelative.X_KEY])
self.y = torch.as_tensor(annotation[DensePoseDataRelative.Y_KEY])
if (
DensePoseDataRelative.I_KEY in annotation
and DensePoseDataRelative.U_KEY in annotation
and DensePoseDataRelative.V_KEY in annotation
):
self.i = torch.as_tensor(annotation[DensePoseDataRelative.I_KEY])
self.u = torch.as_tensor(annotation[DensePoseDataRelative.U_KEY])
self.v = torch.as_tensor(annotation[DensePoseDataRelative.V_KEY])
if (
DensePoseDataRelative.VERTEX_IDS_KEY in annotation
and DensePoseDataRelative.MESH_NAME_KEY in annotation
):
self.vertex_ids = torch.as_tensor(
annotation[DensePoseDataRelative.VERTEX_IDS_KEY], dtype=torch.long
)
self.mesh_id = MeshCatalog.get_mesh_id(annotation[DensePoseDataRelative.MESH_NAME_KEY])
if DensePoseDataRelative.S_KEY in annotation:
self.segm = DensePoseDataRelative.extract_segmentation_mask(annotation)
self.device = torch.device("cpu")
if cleanup:
DensePoseDataRelative.cleanup_annotation(annotation)
def to(self, device):
if self.device == device:
return self
new_data = DensePoseDataRelative.__new__(DensePoseDataRelative)
new_data.x = self.x.to(device)
new_data.y = self.y.to(device)
for attr in ["i", "u", "v", "vertex_ids", "segm"]:
if hasattr(self, attr):
setattr(new_data, attr, getattr(self, attr).to(device))
if hasattr(self, "mesh_id"):
new_data.mesh_id = self.mesh_id
new_data.device = device
return new_data
@staticmethod
def extract_segmentation_mask(annotation):
import pycocotools.mask as mask_utils
# TODO: annotation instance is accepted if it contains either
# DensePose segmentation or instance segmentation. However, here we
# only rely on DensePose segmentation
poly_specs = annotation[DensePoseDataRelative.S_KEY]
if isinstance(poly_specs, torch.Tensor):
# data is already given as mask tensors, no need to decode
return poly_specs
segm = torch.zeros((DensePoseDataRelative.MASK_SIZE,) * 2, dtype=torch.float32)
if isinstance(poly_specs, dict):
if poly_specs:
mask = mask_utils.decode(poly_specs)
segm[mask > 0] = 1
else:
for i in range(len(poly_specs)):
poly_i = poly_specs[i]
if poly_i:
mask_i = mask_utils.decode(poly_i)
segm[mask_i > 0] = i + 1
return segm
@staticmethod
def validate_annotation(annotation):
for key in [
DensePoseDataRelative.X_KEY,
DensePoseDataRelative.Y_KEY,
]:
if key not in annotation:
return False, "no {key} data in the annotation".format(key=key)
valid_for_iuv_setting = all(
key in annotation
for key in [
DensePoseDataRelative.I_KEY,
DensePoseDataRelative.U_KEY,
DensePoseDataRelative.V_KEY,
]
)
valid_for_cse_setting = all(
key in annotation
for key in [
DensePoseDataRelative.VERTEX_IDS_KEY,
DensePoseDataRelative.MESH_NAME_KEY,
]
)
if not valid_for_iuv_setting and not valid_for_cse_setting:
return (
False,
"expected either {} (IUV setting) or {} (CSE setting) annotations".format(
", ".join(
[
DensePoseDataRelative.I_KEY,
DensePoseDataRelative.U_KEY,
DensePoseDataRelative.V_KEY,
]
),
", ".join(
[
DensePoseDataRelative.VERTEX_IDS_KEY,
DensePoseDataRelative.MESH_NAME_KEY,
]
),
),
)
return True, None
@staticmethod
def cleanup_annotation(annotation):
for key in [
DensePoseDataRelative.X_KEY,
DensePoseDataRelative.Y_KEY,
DensePoseDataRelative.I_KEY,
DensePoseDataRelative.U_KEY,
DensePoseDataRelative.V_KEY,
DensePoseDataRelative.S_KEY,
DensePoseDataRelative.VERTEX_IDS_KEY,
DensePoseDataRelative.MESH_NAME_KEY,
]:
if key in annotation:
del annotation[key]
def apply_transform(self, transforms, densepose_transform_data):
self._transform_pts(transforms, densepose_transform_data)
if hasattr(self, "segm"):
self._transform_segm(transforms, densepose_transform_data)
def _transform_pts(self, transforms, dp_transform_data):
import detectron2.data.transforms as T
# NOTE: This assumes that HFlipTransform is the only transform that flips
do_hflip = sum(isinstance(t, T.HFlipTransform) for t in transforms.transforms) % 2 == 1
if do_hflip:
self.x = self.MASK_SIZE - self.x
if hasattr(self, "i"):
self._flip_iuv_semantics(dp_transform_data)
if hasattr(self, "vertex_ids"):
self._flip_vertices()
for t in transforms.transforms:
if isinstance(t, T.RotationTransform):
xy_scale = np.array((t.w, t.h)) / DensePoseDataRelative.MASK_SIZE
xy = t.apply_coords(np.stack((self.x, self.y), axis=1) * xy_scale)
self.x, self.y = torch.tensor(xy / xy_scale, dtype=self.x.dtype).T
def _flip_iuv_semantics(self, dp_transform_data: DensePoseTransformData) -> None:
i_old = self.i.clone()
uv_symmetries = dp_transform_data.uv_symmetries
pt_label_symmetries = dp_transform_data.point_label_symmetries
for i in range(self.N_PART_LABELS):
if i + 1 in i_old:
annot_indices_i = i_old == i + 1
if pt_label_symmetries[i + 1] != i + 1:
self.i[annot_indices_i] = pt_label_symmetries[i + 1]
u_loc = (self.u[annot_indices_i] * 255).long()
v_loc = (self.v[annot_indices_i] * 255).long()
self.u[annot_indices_i] = uv_symmetries["U_transforms"][i][v_loc, u_loc].to(
device=self.u.device
)
self.v[annot_indices_i] = uv_symmetries["V_transforms"][i][v_loc, u_loc].to(
device=self.v.device
)
def _flip_vertices(self):
mesh_info = MeshCatalog[MeshCatalog.get_mesh_name(self.mesh_id)]
mesh_symmetry = (
load_mesh_symmetry(mesh_info.symmetry) if mesh_info.symmetry is not None else None
)
self.vertex_ids = mesh_symmetry["vertex_transforms"][self.vertex_ids]
def _transform_segm(self, transforms, dp_transform_data):
import detectron2.data.transforms as T
# NOTE: This assumes that HFlipTransform is the only transform that flips
do_hflip = sum(isinstance(t, T.HFlipTransform) for t in transforms.transforms) % 2 == 1
if do_hflip:
self.segm = torch.flip(self.segm, [1])
self._flip_segm_semantics(dp_transform_data)
for t in transforms.transforms:
if isinstance(t, T.RotationTransform):
self._transform_segm_rotation(t)
def _flip_segm_semantics(self, dp_transform_data):
old_segm = self.segm.clone()
mask_label_symmetries = dp_transform_data.mask_label_symmetries
for i in range(self.N_BODY_PARTS):
if mask_label_symmetries[i + 1] != i + 1:
self.segm[old_segm == i + 1] = mask_label_symmetries[i + 1]
def _transform_segm_rotation(self, rotation):
self.segm = F.interpolate(self.segm[None, None, :], (rotation.h, rotation.w)).numpy()
self.segm = torch.tensor(rotation.apply_segmentation(self.segm[0, 0]))[None, None, :]
self.segm = F.interpolate(self.segm, [DensePoseDataRelative.MASK_SIZE] * 2)[0, 0]
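# Editor's sketch of the coordinate convention from the class docstring above
# (the box offset and size are illustrative numbers):
if __name__ == "__main__":
    x0, y0, w, h = 10.0, 20.0, 64.0, 128.0  # hypothetical box, absolute XYWH
    x_norm, y_norm = torch.tensor([128.0]), torch.tensor([64.0])
    x_img = x0 + x_norm * w / 256.0  # tensor([42.])
    y_img = y0 + y_norm * h / 256.0  # tensor([52.])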
# Copyright (c) Facebook, Inc. and its affiliates.
# pyre-unsafe
import torch
from densepose.structures.data_relative import DensePoseDataRelative
class DensePoseList:
_TORCH_DEVICE_CPU = torch.device("cpu")
def __init__(self, densepose_datas, boxes_xyxy_abs, image_size_hw, device=_TORCH_DEVICE_CPU):
assert len(densepose_datas) == len(
boxes_xyxy_abs
), "Attempt to initialize DensePoseList with {} DensePose datas " "and {} boxes".format(
len(densepose_datas), len(boxes_xyxy_abs)
)
self.densepose_datas = []
for densepose_data in densepose_datas:
            assert isinstance(densepose_data, DensePoseDataRelative) or densepose_data is None, (
                "Attempt to initialize DensePoseList with DensePose data "
                "of type {}, expected DensePoseDataRelative".format(type(densepose_data))
            )
densepose_data_ondevice = (
densepose_data.to(device) if densepose_data is not None else None
)
self.densepose_datas.append(densepose_data_ondevice)
self.boxes_xyxy_abs = boxes_xyxy_abs.to(device)
self.image_size_hw = image_size_hw
self.device = device
def to(self, device):
if self.device == device:
return self
return DensePoseList(self.densepose_datas, self.boxes_xyxy_abs, self.image_size_hw, device)
def __iter__(self):
return iter(self.densepose_datas)
def __len__(self):
return len(self.densepose_datas)
def __repr__(self):
s = self.__class__.__name__ + "("
s += "num_instances={}, ".format(len(self.densepose_datas))
s += "image_width={}, ".format(self.image_size_hw[1])
s += "image_height={})".format(self.image_size_hw[0])
return s
def __getitem__(self, item):
if isinstance(item, int):
densepose_data_rel = self.densepose_datas[item]
return densepose_data_rel
elif isinstance(item, slice):
densepose_datas_rel = self.densepose_datas[item]
boxes_xyxy_abs = self.boxes_xyxy_abs[item]
return DensePoseList(
densepose_datas_rel, boxes_xyxy_abs, self.image_size_hw, self.device
)
elif isinstance(item, torch.Tensor) and (item.dtype == torch.bool):
densepose_datas_rel = [self.densepose_datas[i] for i, x in enumerate(item) if x > 0]
boxes_xyxy_abs = self.boxes_xyxy_abs[item]
return DensePoseList(
densepose_datas_rel, boxes_xyxy_abs, self.image_size_hw, self.device
)
else:
densepose_datas_rel = [self.densepose_datas[i] for i in item]
boxes_xyxy_abs = self.boxes_xyxy_abs[item]
return DensePoseList(
densepose_datas_rel, boxes_xyxy_abs, self.image_size_hw, self.device
)
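
# --- Illustrative usage (a hypothetical sketch, not part of the original file) ---
# Constructing and indexing a DensePoseList. Entries may be None for instances
# without DensePose annotations; boxes are absolute XYXY coordinates, one row
# per instance, and must match the number of entries.
def _example_densepose_list(densepose_datas, image_size_hw=(480, 640)):
    boxes_xyxy_abs = torch.tensor(
        [[10.0, 20.0, 110.0, 220.0]] * len(densepose_datas)
    )
    dp_list = DensePoseList(densepose_datas, boxes_xyxy_abs, image_size_hw)
    # Boolean-mask indexing returns a new DensePoseList with the selected
    # instances and their boxes; an integer index returns a single entry.
    keep = torch.ones(len(dp_list), dtype=torch.bool)
    return dp_list[keep]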
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
# pyre-unsafe
import pickle
from functools import lru_cache
from typing import Dict, Optional
import torch
from detectron2.utils.file_io import PathManager
from densepose.data.meshes.catalog import MeshCatalog, MeshInfo
def _maybe_copy_to_device(
attribute: Optional[torch.Tensor], device: torch.device
) -> Optional[torch.Tensor]:
if attribute is None:
return None
return attribute.to(device)
class Mesh:
def __init__(
self,
vertices: Optional[torch.Tensor] = None,
faces: Optional[torch.Tensor] = None,
geodists: Optional[torch.Tensor] = None,
symmetry: Optional[Dict[str, torch.Tensor]] = None,
texcoords: Optional[torch.Tensor] = None,
mesh_info: Optional[MeshInfo] = None,
device: Optional[torch.device] = None,
):
"""
Args:
vertices (tensor [N, 3] of float32): vertex coordinates in 3D
            faces (tensor [M, 3] of long): triangular faces, each represented
                by 3 vertex indices
geodists (tensor [N, N] of float32): geodesic distances from
vertex `i` to vertex `j` (optional, default: None)
symmetry (dict: str -> tensor): various mesh symmetry data:
- "vertex_transforms": vertex mapping under horizontal flip,
tensor of size [N] of type long; vertex `i` is mapped to
vertex `tensor[i]` (optional, default: None)
texcoords (tensor [N, 2] of float32): texture coordinates, i.e. global
and normalized mesh UVs (optional, default: None)
            mesh_info (MeshInfo type): used to lazily load the attributes on
                first access; can be provided instead of the tensors above
device (torch.device): device of the Mesh. If not provided, will use
the device of the vertices
"""
self._vertices = vertices
self._faces = faces
self._geodists = geodists
self._symmetry = symmetry
self._texcoords = texcoords
self.mesh_info = mesh_info
self.device = device
assert self._vertices is not None or self.mesh_info is not None
all_fields = [self._vertices, self._faces, self._geodists, self._texcoords]
if self.device is None:
for field in all_fields:
if field is not None:
self.device = field.device
break
if self.device is None and symmetry is not None:
for key in symmetry:
self.device = symmetry[key].device
break
self.device = torch.device("cpu") if self.device is None else self.device
assert all([var.device == self.device for var in all_fields if var is not None])
if symmetry:
assert all(symmetry[key].device == self.device for key in symmetry)
        if texcoords is not None and vertices is not None:
            assert len(vertices) == len(texcoords)
def to(self, device: torch.device):
device_symmetry = self._symmetry
if device_symmetry:
device_symmetry = {key: value.to(device) for key, value in device_symmetry.items()}
return Mesh(
_maybe_copy_to_device(self._vertices, device),
_maybe_copy_to_device(self._faces, device),
_maybe_copy_to_device(self._geodists, device),
device_symmetry,
_maybe_copy_to_device(self._texcoords, device),
self.mesh_info,
device,
)
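    # The properties below load their data lazily from `mesh_info` on first
    # access and cache the result, so unused fields (e.g. the [N, N] geodesic
    # distance matrix) are never read from disk.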
@property
def vertices(self):
if self._vertices is None and self.mesh_info is not None:
self._vertices = load_mesh_data(self.mesh_info.data, "vertices", self.device)
return self._vertices
@property
def faces(self):
if self._faces is None and self.mesh_info is not None:
self._faces = load_mesh_data(self.mesh_info.data, "faces", self.device)
return self._faces
@property
def geodists(self):
if self._geodists is None and self.mesh_info is not None:
self._geodists = load_mesh_auxiliary_data(self.mesh_info.geodists, self.device)
return self._geodists
@property
def symmetry(self):
if self._symmetry is None and self.mesh_info is not None:
self._symmetry = load_mesh_symmetry(self.mesh_info.symmetry, self.device)
return self._symmetry
@property
def texcoords(self):
if self._texcoords is None and self.mesh_info is not None:
self._texcoords = load_mesh_auxiliary_data(self.mesh_info.texcoords, self.device)
return self._texcoords
    def get_geodists(self):
        if self.geodists is None:
            # Cache in the backing field; `geodists` is a read-only property.
            self._geodists = self._compute_geodists()
        return self.geodists
def _compute_geodists(self):
# TODO: compute using Laplace-Beltrami
geodists = None
return geodists
def load_mesh_data(
    mesh_fpath: str, field: str, device: Optional[torch.device] = None
) -> torch.Tensor:
    with PathManager.open(mesh_fpath, "rb") as hFile:
        return torch.as_tensor(pickle.load(hFile)[field], dtype=torch.float).to(device)
def load_mesh_auxiliary_data(
    fpath: str, device: Optional[torch.device] = None
) -> torch.Tensor:
    fpath_local = PathManager.get_local_path(fpath)
    with PathManager.open(fpath_local, "rb") as hFile:
        return torch.as_tensor(pickle.load(hFile), dtype=torch.float).to(device)
@lru_cache()
def load_mesh_symmetry(
    symmetry_fpath: str, device: Optional[torch.device] = None
) -> Dict[str, torch.Tensor]:
    with PathManager.open(symmetry_fpath, "rb") as hFile:
        symmetry_loaded = pickle.load(hFile)
        symmetry = {
            "vertex_transforms": torch.as_tensor(
                symmetry_loaded["vertex_transforms"], dtype=torch.long
            ).to(device),
        }
        return symmetry
@lru_cache()
def create_mesh(mesh_name: str, device: Optional[torch.device] = None) -> Mesh:
return Mesh(mesh_info=MeshCatalog[mesh_name], device=device)
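
# --- Illustrative usage (a hypothetical sketch, not part of the original file) ---
# Lazy mesh loading. "my_mesh" is a placeholder name: it must correspond to an
# entry registered in MeshCatalog whose MeshInfo provides the data (and, for
# the symmetry access below, a symmetry file). Attributes are only read from
# disk on first access and then cached on the Mesh instance.
def _example_load_mesh(mesh_name: str = "my_mesh"):
    mesh = create_mesh(mesh_name, torch.device("cpu"))
    vertices = mesh.vertices  # first access triggers load_mesh_data
    symmetry = mesh.symmetry  # first access triggers load_mesh_symmetry
    return vertices, symmetry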