# Copyright (c) Facebook, Inc. and its affiliates.
# ------------------------------------------------------------------------------
# Copyright (c) Microsoft
# Licensed under the MIT License.
# Written by Bin Xiao (leoxiaobin@gmail.com)
# Modified by Bowen Cheng (bcheng9@illinois.edu)
# Adapted from https://github.com/HRNet/Higher-HRNet-Human-Pose-Estimation/blob/master/lib/models/pose_higher_hrnet.py # noqa
# ------------------------------------------------------------------------------
# pyre-unsafe
from __future__ import absolute_import, division, print_function
import logging
import torch.nn as nn
from detectron2.layers import ShapeSpec
from detectron2.modeling.backbone import BACKBONE_REGISTRY
from detectron2.modeling.backbone.backbone import Backbone
BN_MOMENTUM = 0.1
logger = logging.getLogger(__name__)
__all__ = ["build_pose_hrnet_backbone", "PoseHigherResolutionNet"]
def conv3x3(in_planes, out_planes, stride=1):
"""3x3 convolution with padding"""
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)
class BasicBlock(nn.Module):
expansion = 1
def __init__(self, inplanes, planes, stride=1, downsample=None):
super(BasicBlock, self).__init__()
self.conv1 = conv3x3(inplanes, planes, stride)
self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
self.relu = nn.ReLU(inplace=True)
self.conv2 = conv3x3(planes, planes)
self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
self.downsample = downsample
self.stride = stride
def forward(self, x):
residual = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
if self.downsample is not None:
residual = self.downsample(x)
out += residual
out = self.relu(out)
return out
class Bottleneck(nn.Module):
expansion = 4
def __init__(self, inplanes, planes, stride=1, downsample=None):
super(Bottleneck, self).__init__()
self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False)
self.bn3 = nn.BatchNorm2d(planes * self.expansion, momentum=BN_MOMENTUM)
self.relu = nn.ReLU(inplace=True)
self.downsample = downsample
self.stride = stride
def forward(self, x):
residual = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
out = self.relu(out)
out = self.conv3(out)
out = self.bn3(out)
if self.downsample is not None:
residual = self.downsample(x)
out += residual
out = self.relu(out)
return out
class HighResolutionModule(nn.Module):
"""HighResolutionModule
Building block of the PoseHigherResolutionNet (see below)
arXiv: https://arxiv.org/abs/1908.10357
Args:
num_branches (int): number of branches of the module
blocks (type): block class used by the module (BasicBlock or Bottleneck)
num_blocks (list): number of blocks per branch
num_inchannels (list): number of input channels per branch
num_channels (list): number of channels of each branch
multi_scale_output (bool): only used by the last module of PoseHigherResolutionNet
"""
def __init__(
self,
num_branches,
blocks,
num_blocks,
num_inchannels,
num_channels,
multi_scale_output=True,
):
super(HighResolutionModule, self).__init__()
self._check_branches(num_branches, blocks, num_blocks, num_inchannels, num_channels)
self.num_inchannels = num_inchannels
self.num_branches = num_branches
self.multi_scale_output = multi_scale_output
self.branches = self._make_branches(num_branches, blocks, num_blocks, num_channels)
self.fuse_layers = self._make_fuse_layers()
self.relu = nn.ReLU(True)
def _check_branches(self, num_branches, blocks, num_blocks, num_inchannels, num_channels):
if num_branches != len(num_blocks):
error_msg = "NUM_BRANCHES({}) <> NUM_BLOCKS({})".format(num_branches, len(num_blocks))
logger.error(error_msg)
raise ValueError(error_msg)
if num_branches != len(num_channels):
error_msg = "NUM_BRANCHES({}) <> NUM_CHANNELS({})".format(
num_branches, len(num_channels)
)
logger.error(error_msg)
raise ValueError(error_msg)
if num_branches != len(num_inchannels):
error_msg = "NUM_BRANCHES({}) <> NUM_INCHANNELS({})".format(
num_branches, len(num_inchannels)
)
logger.error(error_msg)
raise ValueError(error_msg)
def _make_one_branch(self, branch_index, block, num_blocks, num_channels, stride=1):
downsample = None
if (
stride != 1
or self.num_inchannels[branch_index] != num_channels[branch_index] * block.expansion
):
downsample = nn.Sequential(
nn.Conv2d(
self.num_inchannels[branch_index],
num_channels[branch_index] * block.expansion,
kernel_size=1,
stride=stride,
bias=False,
),
nn.BatchNorm2d(num_channels[branch_index] * block.expansion, momentum=BN_MOMENTUM),
)
layers = []
layers.append(
block(self.num_inchannels[branch_index], num_channels[branch_index], stride, downsample)
)
self.num_inchannels[branch_index] = num_channels[branch_index] * block.expansion
for _ in range(1, num_blocks[branch_index]):
layers.append(block(self.num_inchannels[branch_index], num_channels[branch_index]))
return nn.Sequential(*layers)
def _make_branches(self, num_branches, block, num_blocks, num_channels):
branches = []
for i in range(num_branches):
branches.append(self._make_one_branch(i, block, num_blocks, num_channels))
return nn.ModuleList(branches)
def _make_fuse_layers(self):
if self.num_branches == 1:
return None
num_branches = self.num_branches
num_inchannels = self.num_inchannels
fuse_layers = []
for i in range(num_branches if self.multi_scale_output else 1):
fuse_layer = []
for j in range(num_branches):
if j > i:
fuse_layer.append(
nn.Sequential(
nn.Conv2d(num_inchannels[j], num_inchannels[i], 1, 1, 0, bias=False),
nn.BatchNorm2d(num_inchannels[i]),
nn.Upsample(scale_factor=2 ** (j - i), mode="nearest"),
)
)
elif j == i:
fuse_layer.append(None)
else:
conv3x3s = []
for k in range(i - j):
if k == i - j - 1:
num_outchannels_conv3x3 = num_inchannels[i]
conv3x3s.append(
nn.Sequential(
nn.Conv2d(
num_inchannels[j],
num_outchannels_conv3x3,
3,
2,
1,
bias=False,
),
nn.BatchNorm2d(num_outchannels_conv3x3),
)
)
else:
num_outchannels_conv3x3 = num_inchannels[j]
conv3x3s.append(
nn.Sequential(
nn.Conv2d(
num_inchannels[j],
num_outchannels_conv3x3,
3,
2,
1,
bias=False,
),
nn.BatchNorm2d(num_outchannels_conv3x3),
nn.ReLU(True),
)
)
fuse_layer.append(nn.Sequential(*conv3x3s))
fuse_layers.append(nn.ModuleList(fuse_layer))
return nn.ModuleList(fuse_layers)
def get_num_inchannels(self):
return self.num_inchannels
def forward(self, x):
if self.num_branches == 1:
return [self.branches[0](x[0])]
for i in range(self.num_branches):
x[i] = self.branches[i](x[i])
x_fuse = []
for i in range(len(self.fuse_layers)):
y = x[0] if i == 0 else self.fuse_layers[i][0](x[0])
for j in range(1, self.num_branches):
if i == j:
y = y + x[j]
else:
z = self.fuse_layers[i][j](x[j])[:, :, : y.shape[2], : y.shape[3]]
y = y + z
x_fuse.append(self.relu(y))
return x_fuse
blocks_dict = {"BASIC": BasicBlock, "BOTTLENECK": Bottleneck}
class PoseHigherResolutionNet(Backbone):
"""PoseHigherResolutionNet
Composed of several HighResolutionModules tied together with ConvNets
Adapted from the GitHub version to fit with HRFPN and the Detectron2 infrastructure
arXiv: https://arxiv.org/abs/1908.10357
"""
def __init__(self, cfg, **kwargs):
self.inplanes = cfg.MODEL.HRNET.STEM_INPLANES
super(PoseHigherResolutionNet, self).__init__()
# stem net
self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1, bias=False)
self.bn1 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM)
self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM)
self.relu = nn.ReLU(inplace=True)
self.layer1 = self._make_layer(Bottleneck, 64, 4)
self.stage2_cfg = cfg.MODEL.HRNET.STAGE2
num_channels = self.stage2_cfg.NUM_CHANNELS
block = blocks_dict[self.stage2_cfg.BLOCK]
num_channels = [num_channels[i] * block.expansion for i in range(len(num_channels))]
self.transition1 = self._make_transition_layer([256], num_channels)
self.stage2, pre_stage_channels = self._make_stage(self.stage2_cfg, num_channels)
self.stage3_cfg = cfg.MODEL.HRNET.STAGE3
num_channels = self.stage3_cfg.NUM_CHANNELS
block = blocks_dict[self.stage3_cfg.BLOCK]
num_channels = [num_channels[i] * block.expansion for i in range(len(num_channels))]
self.transition2 = self._make_transition_layer(pre_stage_channels, num_channels)
self.stage3, pre_stage_channels = self._make_stage(self.stage3_cfg, num_channels)
self.stage4_cfg = cfg.MODEL.HRNET.STAGE4
num_channels = self.stage4_cfg.NUM_CHANNELS
block = blocks_dict[self.stage4_cfg.BLOCK]
num_channels = [num_channels[i] * block.expansion for i in range(len(num_channels))]
self.transition3 = self._make_transition_layer(pre_stage_channels, num_channels)
self.stage4, pre_stage_channels = self._make_stage(
self.stage4_cfg, num_channels, multi_scale_output=True
)
self._out_features = []
self._out_feature_channels = {}
self._out_feature_strides = {}
for i in range(cfg.MODEL.HRNET.STAGE4.NUM_BRANCHES):
self._out_features.append("p%d" % (i + 1))
self._out_feature_channels.update(
{self._out_features[-1]: cfg.MODEL.HRNET.STAGE4.NUM_CHANNELS[i]}
)
self._out_feature_strides.update({self._out_features[-1]: 1})
def _get_deconv_cfg(self, deconv_kernel):
if deconv_kernel == 4:
padding = 1
output_padding = 0
elif deconv_kernel == 3:
padding = 1
output_padding = 1
elif deconv_kernel == 2:
padding = 0
output_padding = 0
return deconv_kernel, padding, output_padding
def _make_transition_layer(self, num_channels_pre_layer, num_channels_cur_layer):
num_branches_cur = len(num_channels_cur_layer)
num_branches_pre = len(num_channels_pre_layer)
transition_layers = []
for i in range(num_branches_cur):
if i < num_branches_pre:
if num_channels_cur_layer[i] != num_channels_pre_layer[i]:
transition_layers.append(
nn.Sequential(
nn.Conv2d(
num_channels_pre_layer[i],
num_channels_cur_layer[i],
3,
1,
1,
bias=False,
),
nn.BatchNorm2d(num_channels_cur_layer[i]),
nn.ReLU(inplace=True),
)
)
else:
transition_layers.append(None)
else:
conv3x3s = []
for j in range(i + 1 - num_branches_pre):
inchannels = num_channels_pre_layer[-1]
outchannels = (
num_channels_cur_layer[i] if j == i - num_branches_pre else inchannels
)
conv3x3s.append(
nn.Sequential(
nn.Conv2d(inchannels, outchannels, 3, 2, 1, bias=False),
nn.BatchNorm2d(outchannels),
nn.ReLU(inplace=True),
)
)
transition_layers.append(nn.Sequential(*conv3x3s))
return nn.ModuleList(transition_layers)
def _make_layer(self, block, planes, blocks, stride=1):
downsample = None
if stride != 1 or self.inplanes != planes * block.expansion:
downsample = nn.Sequential(
nn.Conv2d(
self.inplanes,
planes * block.expansion,
kernel_size=1,
stride=stride,
bias=False,
),
nn.BatchNorm2d(planes * block.expansion, momentum=BN_MOMENTUM),
)
layers = []
layers.append(block(self.inplanes, planes, stride, downsample))
self.inplanes = planes * block.expansion
for _ in range(1, blocks):
layers.append(block(self.inplanes, planes))
return nn.Sequential(*layers)
def _make_stage(self, layer_config, num_inchannels, multi_scale_output=True):
num_modules = layer_config["NUM_MODULES"]
num_branches = layer_config["NUM_BRANCHES"]
num_blocks = layer_config["NUM_BLOCKS"]
num_channels = layer_config["NUM_CHANNELS"]
block = blocks_dict[layer_config["BLOCK"]]
modules = []
for i in range(num_modules):
# multi_scale_output is only used by the last module
if not multi_scale_output and i == num_modules - 1:
reset_multi_scale_output = False
else:
reset_multi_scale_output = True
modules.append(
HighResolutionModule(
num_branches,
block,
num_blocks,
num_inchannels,
num_channels,
reset_multi_scale_output,
)
)
num_inchannels = modules[-1].get_num_inchannels()
return nn.Sequential(*modules), num_inchannels
def forward(self, x):
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.conv2(x)
x = self.bn2(x)
x = self.relu(x)
x = self.layer1(x)
x_list = []
for i in range(self.stage2_cfg.NUM_BRANCHES):
if self.transition1[i] is not None:
x_list.append(self.transition1[i](x))
else:
x_list.append(x)
y_list = self.stage2(x_list)
x_list = []
for i in range(self.stage3_cfg.NUM_BRANCHES):
if self.transition2[i] is not None:
x_list.append(self.transition2[i](y_list[-1]))
else:
x_list.append(y_list[i])
y_list = self.stage3(x_list)
x_list = []
for i in range(self.stage4_cfg.NUM_BRANCHES):
if self.transition3[i] is not None:
x_list.append(self.transition3[i](y_list[-1]))
else:
x_list.append(y_list[i])
y_list = self.stage4(x_list)
assert len(self._out_features) == len(y_list)
return dict(zip(self._out_features, y_list)) # final_outputs
@BACKBONE_REGISTRY.register()
def build_pose_hrnet_backbone(cfg, input_shape: ShapeSpec):
model = PoseHigherResolutionNet(cfg)
return model
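# Illustrative usage sketch: a minimal config with the MODEL.HRNET fields this
# backbone reads. The channel / block counts below follow the common HRNet-W32
# layout and are assumptions for illustration; real values come from the yaml
# configs of the project.
import torch
from detectron2.config import CfgNode as CN

_hrnet_cfg = CN()
_hrnet_cfg.MODEL = CN()
_hrnet_cfg.MODEL.HRNET = CN()
_hrnet_cfg.MODEL.HRNET.STEM_INPLANES = 64
_hrnet_cfg.MODEL.HRNET.STAGE2 = CN(
    dict(NUM_MODULES=1, NUM_BRANCHES=2, BLOCK="BASIC", NUM_BLOCKS=[4, 4], NUM_CHANNELS=[32, 64])
)
_hrnet_cfg.MODEL.HRNET.STAGE3 = CN(
    dict(
        NUM_MODULES=4, NUM_BRANCHES=3, BLOCK="BASIC",
        NUM_BLOCKS=[4, 4, 4], NUM_CHANNELS=[32, 64, 128],
    )
)
_hrnet_cfg.MODEL.HRNET.STAGE4 = CN(
    dict(
        NUM_MODULES=3, NUM_BRANCHES=4, BLOCK="BASIC",
        NUM_BLOCKS=[4, 4, 4, 4], NUM_CHANNELS=[32, 64, 128, 256],
    )
)

_backbone = PoseHigherResolutionNet(_hrnet_cfg)
# The backbone returns a dict of multi-resolution features keyed "p1".."p4".
_features = _backbone(torch.randn(1, 3, 256, 256))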
# Copyright (c) Facebook, Inc. and its affiliates.
# pyre-unsafe
from dataclasses import fields
from typing import Any, List
import torch
from detectron2.structures import Instances
def densepose_inference(densepose_predictor_output: Any, detections: List[Instances]) -> None:
"""
Splits DensePose predictor outputs into chunks, each chunk corresponds to
detections on one image. Predictor output chunks are stored in `pred_densepose`
attribute of the corresponding `Instances` object.
Args:
densepose_predictor_output: a dataclass instance (can be of different types,
depending on the predictor used for inference). Each field can be `None`
(if the corresponding output was not inferred) or a tensor of size
[N, ...], where N = N_1 + N_2 + ... + N_k is the total number of
detections on all images, N_1 is the number of detections on image 1,
N_2 is the number of detections on image 2, etc.
detections: a list of objects of type `Instances`, the k-th object corresponds
to detections on the k-th image.
"""
k = 0
for detection_i in detections:
if densepose_predictor_output is None:
# don't add `pred_densepose` attribute
continue
n_i = detection_i.__len__()
PredictorOutput = type(densepose_predictor_output)
output_i_dict = {}
# we assume here that `densepose_predictor_output` is a dataclass object
for field in fields(densepose_predictor_output):
field_value = getattr(densepose_predictor_output, field.name)
# slice tensors
if isinstance(field_value, torch.Tensor):
output_i_dict[field.name] = field_value[k : k + n_i]
# leave others as is
else:
output_i_dict[field.name] = field_value
detection_i.pred_densepose = PredictorOutput(**output_i_dict)
k += n_i
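# Illustrative sketch of how densepose_inference splits a packed predictor output
# across per-image Instances. `_ToyOutput` is a hypothetical stand-in for a real
# DensePose predictor output dataclass (real outputs also define __len__).
import torch
from dataclasses import dataclass
from detectron2.structures import Instances

@dataclass
class _ToyOutput:
    coarse_segm: torch.Tensor  # [N, ...], packed over all detections

    def __len__(self):
        return self.coarse_segm.size(0)

def _toy_instances(n):
    inst = Instances((480, 640))
    inst.scores = torch.zeros(n)
    return inst

_dets = [_toy_instances(2), _toy_instances(1)]              # 2 + 1 = 3 detections
_packed = _ToyOutput(coarse_segm=torch.arange(3).float())   # N = 3
densepose_inference(_packed, _dets)
assert _dets[0].pred_densepose.coarse_segm.tolist() == [0.0, 1.0]
assert _dets[1].pred_densepose.coarse_segm.tolist() == [2.0]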
# Copyright (c) Facebook, Inc. and its affiliates.
# pyre-unsafe
from .chart import DensePoseChartLoss
from .chart_with_confidences import DensePoseChartWithConfidenceLoss
from .cse import DensePoseCseLoss
from .registry import DENSEPOSE_LOSS_REGISTRY
__all__ = [
"DensePoseChartLoss",
"DensePoseChartWithConfidenceLoss",
"DensePoseCseLoss",
"DENSEPOSE_LOSS_REGISTRY",
]
# Copyright (c) Facebook, Inc. and its affiliates.
# pyre-unsafe
from typing import Any, List
import torch
from torch.nn import functional as F
from detectron2.config import CfgNode
from detectron2.structures import Instances
from .mask_or_segm import MaskOrSegmentationLoss
from .registry import DENSEPOSE_LOSS_REGISTRY
from .utils import (
BilinearInterpolationHelper,
ChartBasedAnnotationsAccumulator,
LossDict,
extract_packed_annotations_from_matches,
)
@DENSEPOSE_LOSS_REGISTRY.register()
class DensePoseChartLoss:
"""
DensePose loss for chart-based training. A mesh is split into charts,
each chart is given a label (I) and parametrized by 2 coordinates referred to
as U and V. Ground truth consists of a number of points annotated with
I, U and V values and coarse segmentation S defined for all pixels of the
object bounding box. In some cases (see `COARSE_SEGM_TRAINED_BY_MASKS`),
semantic segmentation annotations can be used as ground truth inputs as well.
Estimated values are tensors:
* U coordinates, tensor of shape [N, C, S, S]
* V coordinates, tensor of shape [N, C, S, S]
* fine segmentation estimates, tensor of shape [N, C, S, S] with raw unnormalized
scores for each fine segmentation label at each location
* coarse segmentation estimates, tensor of shape [N, D, S, S] with raw unnormalized
scores for each coarse segmentation label at each location
where N is the number of detections, C is the number of fine segmentation
labels, S is the estimate size ( = width = height) and D is the number of
coarse segmentation channels.
The losses are:
* regression (smooth L1) loss for U and V coordinates
* cross entropy loss for fine (I) and coarse (S) segmentations
Each loss has an associated weight
"""
def __init__(self, cfg: CfgNode):
"""
Initialize chart-based loss from configuration options
Args:
cfg (CfgNode): configuration options
"""
# fmt: off
self.heatmap_size = cfg.MODEL.ROI_DENSEPOSE_HEAD.HEATMAP_SIZE
self.w_points = cfg.MODEL.ROI_DENSEPOSE_HEAD.POINT_REGRESSION_WEIGHTS
self.w_part = cfg.MODEL.ROI_DENSEPOSE_HEAD.PART_WEIGHTS
self.w_segm = cfg.MODEL.ROI_DENSEPOSE_HEAD.INDEX_WEIGHTS
self.n_segm_chan = cfg.MODEL.ROI_DENSEPOSE_HEAD.NUM_COARSE_SEGM_CHANNELS
# fmt: on
self.segm_trained_by_masks = cfg.MODEL.ROI_DENSEPOSE_HEAD.COARSE_SEGM_TRAINED_BY_MASKS
self.segm_loss = MaskOrSegmentationLoss(cfg)
def __call__(
self, proposals_with_gt: List[Instances], densepose_predictor_outputs: Any, **kwargs
) -> LossDict:
"""
Produce chart-based DensePose losses
Args:
proposals_with_gt (list of Instances): detections with associated ground truth data
densepose_predictor_outputs: an object of a dataclass that contains predictor outputs
with estimated values; assumed to have the following attributes:
* coarse_segm - coarse segmentation estimates, tensor of shape [N, D, S, S]
* fine_segm - fine segmentation estimates, tensor of shape [N, C, S, S]
* u - U coordinate estimates per fine labels, tensor of shape [N, C, S, S]
* v - V coordinate estimates per fine labels, tensor of shape [N, C, S, S]
where N is the number of detections, C is the number of fine segmentation
labels, S is the estimate size ( = width = height) and D is the number of
coarse segmentation channels.
Return:
dict: str -> tensor: dict of losses with the following entries:
* `loss_densepose_U`: smooth L1 loss for U coordinate estimates
* `loss_densepose_V`: smooth L1 loss for V coordinate estimates
* `loss_densepose_I`: cross entropy for raw unnormalized scores for fine
segmentation estimates given ground truth labels;
* `loss_densepose_S`: cross entropy for raw unnormalized scores for coarse
segmentation estimates given ground truth labels;
"""
# densepose outputs are computed for all images and all bounding boxes;
# i.e. if a batch has 4 images with (3, 1, 2, 1) proposals respectively,
# the outputs will have size(0) == 3+1+2+1 == 7
if not len(proposals_with_gt):
return self.produce_fake_densepose_losses(densepose_predictor_outputs)
accumulator = ChartBasedAnnotationsAccumulator()
packed_annotations = extract_packed_annotations_from_matches(proposals_with_gt, accumulator)
# NOTE: we need to keep the same computation graph on all the GPUs to
# perform reduction properly. Hence even if we have no data on one
# of the GPUs, we still need to generate the computation graph.
# Add fake (zero) loss in the form Tensor.sum() * 0
if packed_annotations is None:
return self.produce_fake_densepose_losses(densepose_predictor_outputs)
h, w = densepose_predictor_outputs.u.shape[2:]
interpolator = BilinearInterpolationHelper.from_matches(
packed_annotations,
(h, w),
)
j_valid_fg = interpolator.j_valid * ( # pyre-ignore[16]
packed_annotations.fine_segm_labels_gt > 0
)
# pyre-fixme[6]: For 1st param expected `Tensor` but got `int`.
if not torch.any(j_valid_fg):
return self.produce_fake_densepose_losses(densepose_predictor_outputs)
losses_uv = self.produce_densepose_losses_uv(
proposals_with_gt,
densepose_predictor_outputs,
packed_annotations,
interpolator,
j_valid_fg, # pyre-ignore[6]
)
losses_segm = self.produce_densepose_losses_segm(
proposals_with_gt,
densepose_predictor_outputs,
packed_annotations,
interpolator,
j_valid_fg, # pyre-ignore[6]
)
return {**losses_uv, **losses_segm}
def produce_fake_densepose_losses(self, densepose_predictor_outputs: Any) -> LossDict:
"""
Fake losses for fine segmentation and U/V coordinates. These are used when
no suitable ground truth data was found in a batch. The loss has a value 0
and is primarily used to construct the computation graph, so that
`DistributedDataParallel` has similar graphs on all GPUs and can perform
reduction properly.
Args:
densepose_predictor_outputs: DensePose predictor outputs, an object
of a dataclass that is assumed to have the following attributes:
* fine_segm - fine segmentation estimates, tensor of shape [N, C, S, S]
* u - U coordinate estimates per fine labels, tensor of shape [N, C, S, S]
* v - V coordinate estimates per fine labels, tensor of shape [N, C, S, S]
Return:
dict: str -> tensor: dict of losses with the following entries:
* `loss_densepose_U`: has value 0
* `loss_densepose_V`: has value 0
* `loss_densepose_I`: has value 0
* `loss_densepose_S`: has value 0
"""
losses_uv = self.produce_fake_densepose_losses_uv(densepose_predictor_outputs)
losses_segm = self.produce_fake_densepose_losses_segm(densepose_predictor_outputs)
return {**losses_uv, **losses_segm}
def produce_fake_densepose_losses_uv(self, densepose_predictor_outputs: Any) -> LossDict:
"""
Fake losses for U/V coordinates. These are used when no suitable ground
truth data was found in a batch. The loss has a value 0
and is primarily used to construct the computation graph, so that
`DistributedDataParallel` has similar graphs on all GPUs and can perform
reduction properly.
Args:
densepose_predictor_outputs: DensePose predictor outputs, an object
of a dataclass that is assumed to have the following attributes:
* u - U coordinate estimates per fine labels, tensor of shape [N, C, S, S]
* v - V coordinate estimates per fine labels, tensor of shape [N, C, S, S]
Return:
dict: str -> tensor: dict of losses with the following entries:
* `loss_densepose_U`: has value 0
* `loss_densepose_V`: has value 0
"""
return {
"loss_densepose_U": densepose_predictor_outputs.u.sum() * 0,
"loss_densepose_V": densepose_predictor_outputs.v.sum() * 0,
}
def produce_fake_densepose_losses_segm(self, densepose_predictor_outputs: Any) -> LossDict:
"""
Fake losses for fine / coarse segmentation. These are used when
no suitable ground truth data was found in a batch. The loss has a value 0
and is primarily used to construct the computation graph, so that
`DistributedDataParallel` has similar graphs on all GPUs and can perform
reduction properly.
Args:
densepose_predictor_outputs: DensePose predictor outputs, an object
of a dataclass that is assumed to have the following attributes:
* fine_segm - fine segmentation estimates, tensor of shape [N, C, S, S]
* coarse_segm - coarse segmentation estimates, tensor of shape [N, D, S, S]
Return:
dict: str -> tensor: dict of losses with the following entries:
* `loss_densepose_I`: has value 0
* `loss_densepose_S`: has value 0, added only if `segm_trained_by_masks` is False
"""
losses = {
"loss_densepose_I": densepose_predictor_outputs.fine_segm.sum() * 0,
"loss_densepose_S": self.segm_loss.fake_value(densepose_predictor_outputs),
}
return losses
def produce_densepose_losses_uv(
self,
proposals_with_gt: List[Instances],
densepose_predictor_outputs: Any,
packed_annotations: Any,
interpolator: BilinearInterpolationHelper,
j_valid_fg: torch.Tensor,
) -> LossDict:
"""
Compute losses for U/V coordinates: smooth L1 loss between
estimated coordinates and the ground truth.
Args:
proposals_with_gt (list of Instances): detections with associated ground truth data
densepose_predictor_outputs: DensePose predictor outputs, an object
of a dataclass that is assumed to have the following attributes:
* u - U coordinate estimates per fine labels, tensor of shape [N, C, S, S]
* v - V coordinate estimates per fine labels, tensor of shape [N, C, S, S]
Return:
dict: str -> tensor: dict of losses with the following entries:
* `loss_densepose_U`: smooth L1 loss for U coordinate estimates
* `loss_densepose_V`: smooth L1 loss for V coordinate estimates
"""
u_gt = packed_annotations.u_gt[j_valid_fg]
u_est = interpolator.extract_at_points(densepose_predictor_outputs.u)[j_valid_fg]
v_gt = packed_annotations.v_gt[j_valid_fg]
v_est = interpolator.extract_at_points(densepose_predictor_outputs.v)[j_valid_fg]
return {
"loss_densepose_U": F.smooth_l1_loss(u_est, u_gt, reduction="sum") * self.w_points,
"loss_densepose_V": F.smooth_l1_loss(v_est, v_gt, reduction="sum") * self.w_points,
}
def produce_densepose_losses_segm(
self,
proposals_with_gt: List[Instances],
densepose_predictor_outputs: Any,
packed_annotations: Any,
interpolator: BilinearInterpolationHelper,
j_valid_fg: torch.Tensor,
) -> LossDict:
"""
Losses for fine / coarse segmentation: cross-entropy
for segmentation unnormalized scores given ground truth labels at
annotated points for fine segmentation and dense mask annotations
for coarse segmentation.
Args:
proposals_with_gt (list of Instances): detections with associated ground truth data
densepose_predictor_outputs: DensePose predictor outputs, an object
of a dataclass that is assumed to have the following attributes:
* fine_segm - fine segmentation estimates, tensor of shape [N, C, S, S]
* coarse_segm - coarse segmentation estimates, tensor of shape [N, D, S, S]
Return:
dict: str -> tensor: dict of losses with the following entries:
* `loss_densepose_I`: cross entropy for raw unnormalized scores for fine
segmentation estimates given ground truth labels
* `loss_densepose_S`: cross entropy for raw unnormalized scores for coarse
segmentation estimates given ground truth labels;
may be included if coarse segmentation is only trained
using DensePose ground truth; if additional supervision through
instance segmentation data is performed (`segm_trained_by_masks` is True),
this loss is handled by `produce_mask_losses` instead
"""
fine_segm_gt = packed_annotations.fine_segm_labels_gt[
interpolator.j_valid # pyre-ignore[16]
]
fine_segm_est = interpolator.extract_at_points(
densepose_predictor_outputs.fine_segm,
slice_fine_segm=slice(None),
w_ylo_xlo=interpolator.w_ylo_xlo[:, None], # pyre-ignore[16]
w_ylo_xhi=interpolator.w_ylo_xhi[:, None], # pyre-ignore[16]
w_yhi_xlo=interpolator.w_yhi_xlo[:, None], # pyre-ignore[16]
w_yhi_xhi=interpolator.w_yhi_xhi[:, None], # pyre-ignore[16]
)[interpolator.j_valid, :]
return {
"loss_densepose_I": F.cross_entropy(fine_segm_est, fine_segm_gt.long()) * self.w_part,
"loss_densepose_S": self.segm_loss(
proposals_with_gt, densepose_predictor_outputs, packed_annotations
)
* self.w_segm,
}
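# Illustrative sketch of the zero-valued "fake loss" trick used above: summing a
# prediction and multiplying by 0 yields a loss of value 0 that is still attached
# to the computation graph, so every parameter gets a (zero) gradient and
# DistributedDataParallel reductions stay consistent across GPUs even when a
# batch contains no usable ground truth.
import torch

_pred = torch.randn(2, 25, 56, 56, requires_grad=True)
_fake_loss = _pred.sum() * 0
_fake_loss.backward()
assert float(_fake_loss) == 0.0
assert _pred.grad is not None and bool(torch.all(_pred.grad == 0))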
# Copyright (c) Facebook, Inc. and its affiliates.
# pyre-unsafe
import math
from typing import Any, List
import torch
from torch import nn
from torch.nn import functional as F
from detectron2.config import CfgNode
from detectron2.structures import Instances
from .. import DensePoseConfidenceModelConfig, DensePoseUVConfidenceType
from .chart import DensePoseChartLoss
from .registry import DENSEPOSE_LOSS_REGISTRY
from .utils import BilinearInterpolationHelper, LossDict
@DENSEPOSE_LOSS_REGISTRY.register()
class DensePoseChartWithConfidenceLoss(DensePoseChartLoss):
""" """
def __init__(self, cfg: CfgNode):
super().__init__(cfg)
self.confidence_model_cfg = DensePoseConfidenceModelConfig.from_cfg(cfg)
if self.confidence_model_cfg.uv_confidence.type == DensePoseUVConfidenceType.IID_ISO:
self.uv_loss_with_confidences = IIDIsotropicGaussianUVLoss(
self.confidence_model_cfg.uv_confidence.epsilon
)
elif self.confidence_model_cfg.uv_confidence.type == DensePoseUVConfidenceType.INDEP_ANISO:
self.uv_loss_with_confidences = IndepAnisotropicGaussianUVLoss(
self.confidence_model_cfg.uv_confidence.epsilon
)
def produce_fake_densepose_losses_uv(self, densepose_predictor_outputs: Any) -> LossDict:
"""
Overrides fake losses for fine segmentation and U/V coordinates to
include computation graphs for additional confidence parameters.
These are used when no suitable ground truth data was found in a batch.
The loss has a value 0 and is primarily used to construct the computation graph,
so that `DistributedDataParallel` has similar graphs on all GPUs and can
perform reduction properly.
Args:
densepose_predictor_outputs: DensePose predictor outputs, an object
of a dataclass that is assumed to have the following attributes:
* fine_segm - fine segmentation estimates, tensor of shape [N, C, S, S]
* u - U coordinate estimates per fine labels, tensor of shape [N, C, S, S]
* v - V coordinate estimates per fine labels, tensor of shape [N, C, S, S]
Return:
dict: str -> tensor: dict of losses with the following entries:
* `loss_densepose_U`: has value 0
* `loss_densepose_V`: has value 0
* `loss_densepose_I`: has value 0
"""
conf_type = self.confidence_model_cfg.uv_confidence.type
if self.confidence_model_cfg.uv_confidence.enabled:
loss_uv = (
densepose_predictor_outputs.u.sum() + densepose_predictor_outputs.v.sum()
) * 0
if conf_type == DensePoseUVConfidenceType.IID_ISO:
loss_uv += densepose_predictor_outputs.sigma_2.sum() * 0
elif conf_type == DensePoseUVConfidenceType.INDEP_ANISO:
loss_uv += (
densepose_predictor_outputs.sigma_2.sum()
+ densepose_predictor_outputs.kappa_u.sum()
+ densepose_predictor_outputs.kappa_v.sum()
) * 0
return {"loss_densepose_UV": loss_uv}
else:
return super().produce_fake_densepose_losses_uv(densepose_predictor_outputs)
def produce_densepose_losses_uv(
self,
proposals_with_gt: List[Instances],
densepose_predictor_outputs: Any,
packed_annotations: Any,
interpolator: BilinearInterpolationHelper,
j_valid_fg: torch.Tensor,
) -> LossDict:
conf_type = self.confidence_model_cfg.uv_confidence.type
if self.confidence_model_cfg.uv_confidence.enabled:
u_gt = packed_annotations.u_gt[j_valid_fg]
u_est = interpolator.extract_at_points(densepose_predictor_outputs.u)[j_valid_fg]
v_gt = packed_annotations.v_gt[j_valid_fg]
v_est = interpolator.extract_at_points(densepose_predictor_outputs.v)[j_valid_fg]
sigma_2_est = interpolator.extract_at_points(densepose_predictor_outputs.sigma_2)[
j_valid_fg
]
if conf_type == DensePoseUVConfidenceType.IID_ISO:
return {
"loss_densepose_UV": (
self.uv_loss_with_confidences(u_est, v_est, sigma_2_est, u_gt, v_gt)
* self.w_points
)
}
elif conf_type in [DensePoseUVConfidenceType.INDEP_ANISO]:
kappa_u_est = interpolator.extract_at_points(densepose_predictor_outputs.kappa_u)[
j_valid_fg
]
kappa_v_est = interpolator.extract_at_points(densepose_predictor_outputs.kappa_v)[
j_valid_fg
]
return {
"loss_densepose_UV": (
self.uv_loss_with_confidences(
u_est, v_est, sigma_2_est, kappa_u_est, kappa_v_est, u_gt, v_gt
)
* self.w_points
)
}
return super().produce_densepose_losses_uv(
proposals_with_gt,
densepose_predictor_outputs,
packed_annotations,
interpolator,
j_valid_fg,
)
class IIDIsotropicGaussianUVLoss(nn.Module):
"""
Loss for the case of iid residuals with isotropic covariance:
$Sigma_i = sigma_i^2 I$
The loss (negative log likelihood) is then:
$1/2 sum_{i=1}^n (log(2 pi) + 2 log sigma_i^2 + ||delta_i||^2 / sigma_i^2)$,
where $delta_i=(u - u', v - v')$ is a 2D vector containing UV coordinates
difference between estimated and ground truth UV values
For details, see:
N. Neverova, D. Novotny, A. Vedaldi "Correlated Uncertainty for Learning
Dense Correspondences from Noisy Labels", p. 918--926, in Proc. NIPS 2019
"""
def __init__(self, sigma_lower_bound: float):
super(IIDIsotropicGaussianUVLoss, self).__init__()
self.sigma_lower_bound = sigma_lower_bound
self.log2pi = math.log(2 * math.pi)
def forward(
self,
u: torch.Tensor,
v: torch.Tensor,
sigma_u: torch.Tensor,
target_u: torch.Tensor,
target_v: torch.Tensor,
):
# compute $\sigma_i^2$
# use sigma_lower_bound to avoid degenerate solution for variance
# (sigma -> 0)
sigma2 = F.softplus(sigma_u) + self.sigma_lower_bound
# compute \|delta_i\|^2
# pyre-fixme[58]: `**` is not supported for operand types `Tensor` and `int`.
delta_t_delta = (u - target_u) ** 2 + (v - target_v) ** 2
# the total loss from the formula above:
loss = 0.5 * (self.log2pi + 2 * torch.log(sigma2) + delta_t_delta / sigma2)
return loss.sum()
class IndepAnisotropicGaussianUVLoss(nn.Module):
"""
Loss for the case of independent residuals with anisotropic covariances:
$Sigma_i = sigma_i^2 I + r_i r_i^T$
The loss (negative log likelihood) is then:
$1/2 sum_{i=1}^n (log(2 pi)
+ log sigma_i^2 (sigma_i^2 + ||r_i||^2)
+ ||delta_i||^2 / sigma_i^2
- <delta_i, r_i>^2 / (sigma_i^2 * (sigma_i^2 + ||r_i||^2)))$,
where $delta_i=(u - u', v - v')$ is a 2D vector containing UV coordinates
difference between estimated and ground truth UV values
For details, see:
N. Neverova, D. Novotny, A. Vedaldi "Correlated Uncertainty for Learning
Dense Correspondences from Noisy Labels", p. 918--926, in Proc. NIPS 2019
"""
def __init__(self, sigma_lower_bound: float):
super(IndepAnisotropicGaussianUVLoss, self).__init__()
self.sigma_lower_bound = sigma_lower_bound
self.log2pi = math.log(2 * math.pi)
def forward(
self,
u: torch.Tensor,
v: torch.Tensor,
sigma_u: torch.Tensor,
kappa_u_est: torch.Tensor,
kappa_v_est: torch.Tensor,
target_u: torch.Tensor,
target_v: torch.Tensor,
):
# compute $\sigma_i^2$
sigma2 = F.softplus(sigma_u) + self.sigma_lower_bound
# compute \|r_i\|^2
# pyre-fixme[58]: `**` is not supported for operand types `Tensor` and `int`.
r_sqnorm2 = kappa_u_est**2 + kappa_v_est**2
delta_u = u - target_u
delta_v = v - target_v
# compute \|delta_i\|^2
# pyre-fixme[58]: `**` is not supported for operand types `Tensor` and `int`.
delta_sqnorm = delta_u**2 + delta_v**2
delta_u_r_u = delta_u * kappa_u_est
delta_v_r_v = delta_v * kappa_v_est
# compute the scalar product <delta_i, r_i>
delta_r = delta_u_r_u + delta_v_r_v
# compute squared scalar product <delta_i, r_i>^2
# pyre-fixme[58]: `**` is not supported for operand types `Tensor` and `int`.
delta_r_sqnorm = delta_r**2
denom2 = sigma2 * (sigma2 + r_sqnorm2)
loss = 0.5 * (
self.log2pi + torch.log(denom2) + delta_sqnorm / sigma2 - delta_r_sqnorm / denom2
)
return loss.sum()
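# Illustrative sketch: a single-point check of IIDIsotropicGaussianUVLoss against
# its closed form 0.5 * (log(2*pi) + 2*log(sigma^2) + ||delta||^2 / sigma^2),
# where sigma^2 = softplus(sigma_u) + sigma_lower_bound. The numbers are arbitrary.
import math
import torch
from torch.nn import functional as F

_loss_fn = IIDIsotropicGaussianUVLoss(sigma_lower_bound=0.01)
_u, _v = torch.tensor([0.3]), torch.tensor([0.7])
_tu, _tv = torch.tensor([0.5]), torch.tensor([0.4])
_sigma_u = torch.tensor([0.2])

_sigma2 = F.softplus(_sigma_u) + 0.01
_delta2 = (_u - _tu) ** 2 + (_v - _tv) ** 2
_expected = 0.5 * (math.log(2 * math.pi) + 2 * torch.log(_sigma2) + _delta2 / _sigma2)
assert torch.allclose(_loss_fn(_u, _v, _sigma_u, _tu, _tv), _expected.sum())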
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
# pyre-unsafe
from typing import Any, List
from torch import nn
from detectron2.config import CfgNode
from detectron2.structures import Instances
from .cycle_pix2shape import PixToShapeCycleLoss
from .cycle_shape2shape import ShapeToShapeCycleLoss
from .embed import EmbeddingLoss
from .embed_utils import CseAnnotationsAccumulator
from .mask_or_segm import MaskOrSegmentationLoss
from .registry import DENSEPOSE_LOSS_REGISTRY
from .soft_embed import SoftEmbeddingLoss
from .utils import BilinearInterpolationHelper, LossDict, extract_packed_annotations_from_matches
@DENSEPOSE_LOSS_REGISTRY.register()
class DensePoseCseLoss:
""" """
_EMBED_LOSS_REGISTRY = {
EmbeddingLoss.__name__: EmbeddingLoss,
SoftEmbeddingLoss.__name__: SoftEmbeddingLoss,
}
def __init__(self, cfg: CfgNode):
"""
Initialize CSE loss from configuration options
Args:
cfg (CfgNode): configuration options
"""
self.w_segm = cfg.MODEL.ROI_DENSEPOSE_HEAD.INDEX_WEIGHTS
self.w_embed = cfg.MODEL.ROI_DENSEPOSE_HEAD.CSE.EMBED_LOSS_WEIGHT
self.segm_loss = MaskOrSegmentationLoss(cfg)
self.embed_loss = DensePoseCseLoss.create_embed_loss(cfg)
self.do_shape2shape = cfg.MODEL.ROI_DENSEPOSE_HEAD.CSE.SHAPE_TO_SHAPE_CYCLE_LOSS.ENABLED
if self.do_shape2shape:
self.w_shape2shape = cfg.MODEL.ROI_DENSEPOSE_HEAD.CSE.SHAPE_TO_SHAPE_CYCLE_LOSS.WEIGHT
self.shape2shape_loss = ShapeToShapeCycleLoss(cfg)
self.do_pix2shape = cfg.MODEL.ROI_DENSEPOSE_HEAD.CSE.PIX_TO_SHAPE_CYCLE_LOSS.ENABLED
if self.do_pix2shape:
self.w_pix2shape = cfg.MODEL.ROI_DENSEPOSE_HEAD.CSE.PIX_TO_SHAPE_CYCLE_LOSS.WEIGHT
self.pix2shape_loss = PixToShapeCycleLoss(cfg)
@classmethod
def create_embed_loss(cls, cfg: CfgNode):
# registry not used here, since embedding losses are currently local
# and are not used anywhere else
return cls._EMBED_LOSS_REGISTRY[cfg.MODEL.ROI_DENSEPOSE_HEAD.CSE.EMBED_LOSS_NAME](cfg)
def __call__(
self,
proposals_with_gt: List[Instances],
densepose_predictor_outputs: Any,
embedder: nn.Module,
) -> LossDict:
if not len(proposals_with_gt):
return self.produce_fake_losses(densepose_predictor_outputs, embedder)
accumulator = CseAnnotationsAccumulator()
packed_annotations = extract_packed_annotations_from_matches(proposals_with_gt, accumulator)
if packed_annotations is None:
return self.produce_fake_losses(densepose_predictor_outputs, embedder)
h, w = densepose_predictor_outputs.embedding.shape[2:]
interpolator = BilinearInterpolationHelper.from_matches(
packed_annotations,
(h, w),
)
meshid_to_embed_losses = self.embed_loss(
proposals_with_gt,
densepose_predictor_outputs,
packed_annotations,
interpolator,
embedder,
)
embed_loss_dict = {
f"loss_densepose_E{meshid}": self.w_embed * meshid_to_embed_losses[meshid]
for meshid in meshid_to_embed_losses
}
all_loss_dict = {
"loss_densepose_S": self.w_segm
* self.segm_loss(proposals_with_gt, densepose_predictor_outputs, packed_annotations),
**embed_loss_dict,
}
if self.do_shape2shape:
all_loss_dict["loss_shape2shape"] = self.w_shape2shape * self.shape2shape_loss(embedder)
if self.do_pix2shape:
all_loss_dict["loss_pix2shape"] = self.w_pix2shape * self.pix2shape_loss(
proposals_with_gt, densepose_predictor_outputs, packed_annotations, embedder
)
return all_loss_dict
def produce_fake_losses(
self, densepose_predictor_outputs: Any, embedder: nn.Module
) -> LossDict:
meshname_to_embed_losses = self.embed_loss.fake_values(
densepose_predictor_outputs, embedder=embedder
)
embed_loss_dict = {
f"loss_densepose_E{mesh_name}": meshname_to_embed_losses[mesh_name]
for mesh_name in meshname_to_embed_losses
}
all_loss_dict = {
"loss_densepose_S": self.segm_loss.fake_value(densepose_predictor_outputs),
**embed_loss_dict,
}
if self.do_shape2shape:
all_loss_dict["loss_shape2shape"] = self.shape2shape_loss.fake_value(embedder)
if self.do_pix2shape:
all_loss_dict["loss_pix2shape"] = self.pix2shape_loss.fake_value(
densepose_predictor_outputs, embedder
)
return all_loss_dict
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
# pyre-unsafe
from typing import Any, List
import torch
from torch import nn
from torch.nn import functional as F
from detectron2.config import CfgNode
from detectron2.structures import Instances
from densepose.data.meshes.catalog import MeshCatalog
from densepose.modeling.cse.utils import normalize_embeddings, squared_euclidean_distance_matrix
from .embed_utils import PackedCseAnnotations
from .mask import extract_data_for_mask_loss_from_matches
def _create_pixel_dist_matrix(grid_size: int) -> torch.Tensor:
rows = torch.arange(grid_size)
cols = torch.arange(grid_size)
# at index `i` contains [row, col], where
# row = i // grid_size
# col = i % grid_size
pix_coords = (
torch.stack(torch.meshgrid(rows, cols), -1).reshape((grid_size * grid_size, 2)).float()
)
return squared_euclidean_distance_matrix(pix_coords, pix_coords)
def _sample_fg_pixels_randperm(fg_mask: torch.Tensor, sample_size: int) -> torch.Tensor:
fg_mask_flattened = fg_mask.reshape((-1,))
num_pixels = int(fg_mask_flattened.sum().item())
fg_pixel_indices = fg_mask_flattened.nonzero(as_tuple=True)[0]
if (sample_size <= 0) or (num_pixels <= sample_size):
return fg_pixel_indices
sample_indices = torch.randperm(num_pixels, device=fg_mask.device)[:sample_size]
return fg_pixel_indices[sample_indices]
def _sample_fg_pixels_multinomial(fg_mask: torch.Tensor, sample_size: int) -> torch.Tensor:
fg_mask_flattened = fg_mask.reshape((-1,))
num_pixels = int(fg_mask_flattened.sum().item())
if (sample_size <= 0) or (num_pixels <= sample_size):
return fg_mask_flattened.nonzero(as_tuple=True)[0]
return fg_mask_flattened.float().multinomial(sample_size, replacement=False)
class PixToShapeCycleLoss(nn.Module):
"""
Cycle loss for pixel-vertex correspondence
"""
def __init__(self, cfg: CfgNode):
super().__init__()
self.shape_names = list(cfg.MODEL.ROI_DENSEPOSE_HEAD.CSE.EMBEDDERS.keys())
self.embed_size = cfg.MODEL.ROI_DENSEPOSE_HEAD.CSE.EMBED_SIZE
self.norm_p = cfg.MODEL.ROI_DENSEPOSE_HEAD.CSE.PIX_TO_SHAPE_CYCLE_LOSS.NORM_P
self.use_all_meshes_not_gt_only = (
cfg.MODEL.ROI_DENSEPOSE_HEAD.CSE.PIX_TO_SHAPE_CYCLE_LOSS.USE_ALL_MESHES_NOT_GT_ONLY
)
self.num_pixels_to_sample = (
cfg.MODEL.ROI_DENSEPOSE_HEAD.CSE.PIX_TO_SHAPE_CYCLE_LOSS.NUM_PIXELS_TO_SAMPLE
)
self.pix_sigma = cfg.MODEL.ROI_DENSEPOSE_HEAD.CSE.PIX_TO_SHAPE_CYCLE_LOSS.PIXEL_SIGMA
self.temperature_pix_to_vertex = (
cfg.MODEL.ROI_DENSEPOSE_HEAD.CSE.PIX_TO_SHAPE_CYCLE_LOSS.TEMPERATURE_PIXEL_TO_VERTEX
)
self.temperature_vertex_to_pix = (
cfg.MODEL.ROI_DENSEPOSE_HEAD.CSE.PIX_TO_SHAPE_CYCLE_LOSS.TEMPERATURE_VERTEX_TO_PIXEL
)
self.pixel_dists = _create_pixel_dist_matrix(cfg.MODEL.ROI_DENSEPOSE_HEAD.HEATMAP_SIZE)
def forward(
self,
proposals_with_gt: List[Instances],
densepose_predictor_outputs: Any,
packed_annotations: PackedCseAnnotations,
embedder: nn.Module,
):
"""
Args:
proposals_with_gt (list of Instances): detections with associated
ground truth data; each item corresponds to instances detected
on 1 image; the number of items corresponds to the number of
images in a batch
densepose_predictor_outputs: an object of a dataclass that contains predictor
outputs with estimated values; assumed to have the following attributes:
* embedding - embedding estimates, tensor of shape [N, D, S, S], where
N = number of instances (= sum N_i, where N_i is the number of
instances on image i)
D = embedding space dimensionality (MODEL.ROI_DENSEPOSE_HEAD.CSE.EMBED_SIZE)
S = output size (width and height)
packed_annotations (PackedCseAnnotations): contains various data useful
for loss computation, each data is packed into a single tensor
embedder (nn.Module): module that computes vertex embeddings for different meshes
"""
pix_embeds = densepose_predictor_outputs.embedding
if self.pixel_dists.device != pix_embeds.device:
# should normally be done only once
self.pixel_dists = self.pixel_dists.to(device=pix_embeds.device)
with torch.no_grad():
mask_loss_data = extract_data_for_mask_loss_from_matches(
proposals_with_gt, densepose_predictor_outputs.coarse_segm
)
# GT masks - tensor of shape [N, S, S] of int64
masks_gt = mask_loss_data.masks_gt.long() # pyre-ignore[16]
assert len(pix_embeds) == len(masks_gt), (
f"Number of instances with embeddings {len(pix_embeds)} != "
f"number of instances with GT masks {len(masks_gt)}"
)
losses = []
mesh_names = (
self.shape_names
if self.use_all_meshes_not_gt_only
else [
MeshCatalog.get_mesh_name(mesh_id.item())
for mesh_id in packed_annotations.vertex_mesh_ids_gt.unique()
]
)
for pixel_embeddings, mask_gt in zip(pix_embeds, masks_gt):
# pixel_embeddings [D, S, S]
# mask_gt [S, S]
for mesh_name in mesh_names:
mesh_vertex_embeddings = embedder(mesh_name)
# pixel indices [M]
pixel_indices_flattened = _sample_fg_pixels_randperm(
mask_gt, self.num_pixels_to_sample
)
# pixel distances [M, M]
pixel_dists = self.pixel_dists.to(pixel_embeddings.device)[
torch.meshgrid(pixel_indices_flattened, pixel_indices_flattened)
]
# pixel embeddings [M, D]
pixel_embeddings_sampled = normalize_embeddings(
pixel_embeddings.reshape((self.embed_size, -1))[:, pixel_indices_flattened].T
)
# pixel-vertex similarity [M, K]
sim_matrix = pixel_embeddings_sampled.mm(mesh_vertex_embeddings.T)
c_pix_vertex = F.softmax(sim_matrix / self.temperature_pix_to_vertex, dim=1)
c_vertex_pix = F.softmax(sim_matrix.T / self.temperature_vertex_to_pix, dim=1)
c_cycle = c_pix_vertex.mm(c_vertex_pix)
loss_cycle = torch.norm(pixel_dists * c_cycle, p=self.norm_p)
losses.append(loss_cycle)
if len(losses) == 0:
return pix_embeds.sum() * 0
return torch.stack(losses, dim=0).mean()
def fake_value(self, densepose_predictor_outputs: Any, embedder: nn.Module):
losses = [embedder(mesh_name).sum() * 0 for mesh_name in embedder.mesh_names]
losses.append(densepose_predictor_outputs.embedding.sum() * 0)
return torch.mean(torch.stack(losses))
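# Illustrative sketch of the pixel->vertex->pixel cycle step in isolation, with
# random stand-ins: M sampled foreground pixels with D-dimensional embeddings and
# K mesh vertex embeddings. A perfect cycle maps each pixel back to itself, so
# weighting the cycle matrix by pairwise pixel distances (zero on the diagonal)
# penalizes cycles that drift to distant pixels. The temperature value is an
# arbitrary assumption for illustration.
import torch
from torch.nn import functional as F

_M, _K, _D = 8, 32, 16
_pix = F.normalize(torch.randn(_M, _D), dim=1)    # sampled pixel embeddings [M, D]
_verts = F.normalize(torch.randn(_K, _D), dim=1)  # mesh vertex embeddings [K, D]
_pixel_dists = torch.rand(_M, _M)                 # stand-in for squared pixel distances
_pixel_dists.fill_diagonal_(0)

_sim = _pix.mm(_verts.T)                          # [M, K]
_c_pix_vertex = F.softmax(_sim / 0.05, dim=1)     # pixel -> vertex assignment
_c_vertex_pix = F.softmax(_sim.T / 0.05, dim=1)   # vertex -> pixel assignment
_c_cycle = _c_pix_vertex.mm(_c_vertex_pix)        # [M, M]
_loss_cycle = torch.norm(_pixel_dists * _c_cycle, p=2)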
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
# pyre-unsafe
import random
from typing import Tuple
import torch
from torch import nn
from torch.nn import functional as F
from detectron2.config import CfgNode
from densepose.structures.mesh import create_mesh
from .utils import sample_random_indices
class ShapeToShapeCycleLoss(nn.Module):
"""
Cycle Loss for Shapes.
Inspired by:
"Mapping in a Cycle: Sinkhorn Regularized Unsupervised Learning for Point Cloud Shapes".
"""
def __init__(self, cfg: CfgNode):
super().__init__()
self.shape_names = list(cfg.MODEL.ROI_DENSEPOSE_HEAD.CSE.EMBEDDERS.keys())
self.all_shape_pairs = [
(x, y) for i, x in enumerate(self.shape_names) for y in self.shape_names[i + 1 :]
]
random.shuffle(self.all_shape_pairs)
self.cur_pos = 0
self.norm_p = cfg.MODEL.ROI_DENSEPOSE_HEAD.CSE.SHAPE_TO_SHAPE_CYCLE_LOSS.NORM_P
self.temperature = cfg.MODEL.ROI_DENSEPOSE_HEAD.CSE.SHAPE_TO_SHAPE_CYCLE_LOSS.TEMPERATURE
self.max_num_vertices = (
cfg.MODEL.ROI_DENSEPOSE_HEAD.CSE.SHAPE_TO_SHAPE_CYCLE_LOSS.MAX_NUM_VERTICES
)
def _sample_random_pair(self) -> Tuple[str, str]:
"""
Produce a random pair of different mesh names
Return:
tuple(str, str): a pair of different mesh names
"""
if self.cur_pos >= len(self.all_shape_pairs):
random.shuffle(self.all_shape_pairs)
self.cur_pos = 0
shape_pair = self.all_shape_pairs[self.cur_pos]
self.cur_pos += 1
return shape_pair
def forward(self, embedder: nn.Module):
"""
Do a forward pass with a random (src, dst) pair of shapes
Args:
embedder (nn.Module): module that computes vertex embeddings for different meshes
"""
src_mesh_name, dst_mesh_name = self._sample_random_pair()
return self._forward_one_pair(embedder, src_mesh_name, dst_mesh_name)
def fake_value(self, embedder: nn.Module):
losses = []
for mesh_name in embedder.mesh_names:
losses.append(embedder(mesh_name).sum() * 0)
return torch.mean(torch.stack(losses))
def _get_embeddings_and_geodists_for_mesh(
self, embedder: nn.Module, mesh_name: str
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Produces embeddings and geodesic distance tensors for a given mesh. May subsample
the mesh, if it contains too many vertices (controlled by the
SHAPE_TO_SHAPE_CYCLE_LOSS.MAX_NUM_VERTICES config option).
Args:
embedder (nn.Module): module that computes embeddings for mesh vertices
mesh_name (str): mesh name
Return:
embeddings (torch.Tensor of size [N, D]): embeddings for selected mesh
vertices (N = number of selected vertices, D = embedding space dim)
geodists (torch.Tensor of size [N, N]): geodesic distances for the selected
mesh vertices (N = number of selected vertices)
"""
embeddings = embedder(mesh_name)
indices = sample_random_indices(
embeddings.shape[0], self.max_num_vertices, embeddings.device
)
mesh = create_mesh(mesh_name, embeddings.device)
geodists = mesh.geodists
if indices is not None:
embeddings = embeddings[indices]
geodists = geodists[torch.meshgrid(indices, indices)]
return embeddings, geodists
def _forward_one_pair(
self, embedder: nn.Module, mesh_name_1: str, mesh_name_2: str
) -> torch.Tensor:
"""
Do a forward pass with a selected pair of meshes
Args:
embedder (nn.Module): module that computes vertex embeddings for different meshes
mesh_name_1 (str): first mesh name
mesh_name_2 (str): second mesh name
Return:
Tensor containing the loss value
"""
embeddings_1, geodists_1 = self._get_embeddings_and_geodists_for_mesh(embedder, mesh_name_1)
embeddings_2, geodists_2 = self._get_embeddings_and_geodists_for_mesh(embedder, mesh_name_2)
sim_matrix_12 = embeddings_1.mm(embeddings_2.T)
c_12 = F.softmax(sim_matrix_12 / self.temperature, dim=1)
c_21 = F.softmax(sim_matrix_12.T / self.temperature, dim=1)
c_11 = c_12.mm(c_21)
c_22 = c_21.mm(c_12)
loss_cycle_11 = torch.norm(geodists_1 * c_11, p=self.norm_p)
loss_cycle_22 = torch.norm(geodists_2 * c_22, p=self.norm_p)
return loss_cycle_11 + loss_cycle_22
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
# pyre-unsafe
from typing import Any, Dict, List
import torch
from torch import nn
from torch.nn import functional as F
from detectron2.config import CfgNode
from detectron2.structures import Instances
from densepose.data.meshes.catalog import MeshCatalog
from densepose.modeling.cse.utils import normalize_embeddings, squared_euclidean_distance_matrix
from .embed_utils import PackedCseAnnotations
from .utils import BilinearInterpolationHelper
class EmbeddingLoss:
"""
Computes losses for estimated embeddings given annotated vertices.
Instances in a minibatch that correspond to the same mesh are grouped
together. For each group, loss is computed as cross-entropy for
unnormalized scores given ground truth mesh vertex ids.
Scores are based on squared distances between estimated vertex embeddings
and mesh vertex embeddings.
"""
def __init__(self, cfg: CfgNode):
"""
Initialize embedding loss from config
"""
self.embdist_gauss_sigma = cfg.MODEL.ROI_DENSEPOSE_HEAD.CSE.EMBEDDING_DIST_GAUSS_SIGMA
def __call__(
self,
proposals_with_gt: List[Instances],
densepose_predictor_outputs: Any,
packed_annotations: PackedCseAnnotations,
interpolator: BilinearInterpolationHelper,
embedder: nn.Module,
) -> Dict[int, torch.Tensor]:
"""
Produces losses for estimated embeddings given annotated vertices.
Embeddings for all the vertices of a mesh are computed by the embedder.
Embeddings for observed pixels are estimated by a predictor.
Losses are computed as cross-entropy for squared distances between
observed vertex embeddings and all mesh vertex embeddings given
ground truth vertex IDs.
Args:
proposals_with_gt (list of Instances): detections with associated
ground truth data; each item corresponds to instances detected
on 1 image; the number of items corresponds to the number of
images in a batch
densepose_predictor_outputs: an object of a dataclass that contains predictor
outputs with estimated values; assumed to have the following attributes:
* embedding - embedding estimates, tensor of shape [N, D, S, S], where
N = number of instances (= sum N_i, where N_i is the number of
instances on image i)
D = embedding space dimensionality (MODEL.ROI_DENSEPOSE_HEAD.CSE.EMBED_SIZE)
S = output size (width and height)
packed_annotations (PackedCseAnnotations): contains various data useful
for loss computation, each data is packed into a single tensor
interpolator (BilinearInterpolationHelper): bilinear interpolation helper
embedder (nn.Module): module that computes vertex embeddings for different meshes
Return:
dict(int -> tensor): losses for different mesh IDs
"""
losses = {}
for mesh_id_tensor in packed_annotations.vertex_mesh_ids_gt.unique():
mesh_id = mesh_id_tensor.item()
mesh_name = MeshCatalog.get_mesh_name(mesh_id)
# valid points are those that fall into estimated bbox
# and correspond to the current mesh
j_valid = interpolator.j_valid * ( # pyre-ignore[16]
packed_annotations.vertex_mesh_ids_gt == mesh_id
)
if not torch.any(j_valid):
continue
# extract estimated embeddings for valid points
# -> tensor [J, D]
vertex_embeddings_i = normalize_embeddings(
interpolator.extract_at_points(
densepose_predictor_outputs.embedding,
slice_fine_segm=slice(None),
w_ylo_xlo=interpolator.w_ylo_xlo[:, None], # pyre-ignore[16]
w_ylo_xhi=interpolator.w_ylo_xhi[:, None], # pyre-ignore[16]
w_yhi_xlo=interpolator.w_yhi_xlo[:, None], # pyre-ignore[16]
w_yhi_xhi=interpolator.w_yhi_xhi[:, None], # pyre-ignore[16]
)[j_valid, :]
)
# extract vertex ids for valid points
# -> tensor [J]
vertex_indices_i = packed_annotations.vertex_ids_gt[j_valid]
# embeddings for all mesh vertices
# -> tensor [K, D]
mesh_vertex_embeddings = embedder(mesh_name)
# unnormalized scores for valid points
# -> tensor [J, K]
scores = squared_euclidean_distance_matrix(
vertex_embeddings_i, mesh_vertex_embeddings
) / (-self.embdist_gauss_sigma)
losses[mesh_name] = F.cross_entropy(scores, vertex_indices_i, ignore_index=-1)
for mesh_name in embedder.mesh_names:
if mesh_name not in losses:
losses[mesh_name] = self.fake_value(
densepose_predictor_outputs, embedder, mesh_name
)
return losses
def fake_values(self, densepose_predictor_outputs: Any, embedder: nn.Module):
losses = {}
for mesh_name in embedder.mesh_names:
losses[mesh_name] = self.fake_value(densepose_predictor_outputs, embedder, mesh_name)
return losses
def fake_value(self, densepose_predictor_outputs: Any, embedder: nn.Module, mesh_name: str):
return densepose_predictor_outputs.embedding.sum() * 0 + embedder(mesh_name).sum() * 0
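# Illustrative sketch of the scoring scheme above, in isolation: unnormalized
# scores for each annotated point are negative squared distances to all K mesh
# vertex embeddings, scaled by a Gaussian sigma, and trained with cross entropy
# against the ground-truth vertex id. torch.cdist stands in for
# squared_euclidean_distance_matrix; the sigma value is an arbitrary assumption.
import torch
from torch.nn import functional as F

_J, _K, _D = 5, 100, 16        # annotated points, mesh vertices, embedding dim
_sigma = 0.01                  # stand-in for EMBEDDING_DIST_GAUSS_SIGMA
_point_embeds = F.normalize(torch.randn(_J, _D), dim=1)
_vertex_embeds = F.normalize(torch.randn(_K, _D), dim=1)
_vertex_ids_gt = torch.randint(0, _K, (_J,))

_scores = torch.cdist(_point_embeds, _vertex_embeds) ** 2 / (-_sigma)  # [J, K]
_embed_loss = F.cross_entropy(_scores, _vertex_ids_gt, ignore_index=-1)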
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
# pyre-unsafe
from dataclasses import dataclass
from typing import Any, Optional
import torch
from detectron2.structures import BoxMode, Instances
from .utils import AnnotationsAccumulator
@dataclass
class PackedCseAnnotations:
x_gt: torch.Tensor
y_gt: torch.Tensor
coarse_segm_gt: Optional[torch.Tensor]
vertex_mesh_ids_gt: torch.Tensor
vertex_ids_gt: torch.Tensor
bbox_xywh_gt: torch.Tensor
bbox_xywh_est: torch.Tensor
point_bbox_with_dp_indices: torch.Tensor
point_bbox_indices: torch.Tensor
bbox_indices: torch.Tensor
class CseAnnotationsAccumulator(AnnotationsAccumulator):
"""
Accumulates annotations by batches that correspond to objects detected on
individual images. Can pack them together into single tensors.
"""
def __init__(self):
self.x_gt = []
self.y_gt = []
self.s_gt = []
self.vertex_mesh_ids_gt = []
self.vertex_ids_gt = []
self.bbox_xywh_gt = []
self.bbox_xywh_est = []
self.point_bbox_with_dp_indices = []
self.point_bbox_indices = []
self.bbox_indices = []
self.nxt_bbox_with_dp_index = 0
self.nxt_bbox_index = 0
def accumulate(self, instances_one_image: Instances):
"""
Accumulate instances data for one image
Args:
instances_one_image (Instances): instances data to accumulate
"""
boxes_xywh_est = BoxMode.convert(
instances_one_image.proposal_boxes.tensor.clone(), BoxMode.XYXY_ABS, BoxMode.XYWH_ABS
)
boxes_xywh_gt = BoxMode.convert(
instances_one_image.gt_boxes.tensor.clone(), BoxMode.XYXY_ABS, BoxMode.XYWH_ABS
)
n_matches = len(boxes_xywh_gt)
assert n_matches == len(
boxes_xywh_est
), f"Got {len(boxes_xywh_est)} proposal boxes and {len(boxes_xywh_gt)} GT boxes"
if not n_matches:
# no matches between detections and GT
return
if (
not hasattr(instances_one_image, "gt_densepose")
or instances_one_image.gt_densepose is None
):
# no densepose GT for the detections, just increase the bbox index
self.nxt_bbox_index += n_matches
return
for box_xywh_est, box_xywh_gt, dp_gt in zip(
boxes_xywh_est, boxes_xywh_gt, instances_one_image.gt_densepose
):
if (dp_gt is not None) and (len(dp_gt.x) > 0):
# pyre-fixme[6]: For 1st argument expected `Tensor` but got `float`.
# pyre-fixme[6]: For 2nd argument expected `Tensor` but got `float`.
self._do_accumulate(box_xywh_gt, box_xywh_est, dp_gt)
self.nxt_bbox_index += 1
def _do_accumulate(self, box_xywh_gt: torch.Tensor, box_xywh_est: torch.Tensor, dp_gt: Any):
"""
Accumulate instances data for one image, given that the data is not empty
Args:
box_xywh_gt (tensor): GT bounding box
box_xywh_est (tensor): estimated bounding box
dp_gt: GT densepose data with the following attributes:
- x: normalized X coordinates
- y: normalized Y coordinates
- segm: tensor of size [S, S] with coarse segmentation
- vertex_ids: tensor of vertex IDs of the annotated points
- mesh_id: ID of the mesh the vertices belong to
"""
self.x_gt.append(dp_gt.x)
self.y_gt.append(dp_gt.y)
if hasattr(dp_gt, "segm"):
self.s_gt.append(dp_gt.segm.unsqueeze(0))
self.vertex_ids_gt.append(dp_gt.vertex_ids)
self.vertex_mesh_ids_gt.append(torch.full_like(dp_gt.vertex_ids, dp_gt.mesh_id))
self.bbox_xywh_gt.append(box_xywh_gt.view(-1, 4))
self.bbox_xywh_est.append(box_xywh_est.view(-1, 4))
self.point_bbox_with_dp_indices.append(
torch.full_like(dp_gt.vertex_ids, self.nxt_bbox_with_dp_index)
)
self.point_bbox_indices.append(torch.full_like(dp_gt.vertex_ids, self.nxt_bbox_index))
self.bbox_indices.append(self.nxt_bbox_index)
self.nxt_bbox_with_dp_index += 1
def pack(self) -> Optional[PackedCseAnnotations]:
"""
Pack data into tensors
"""
if not len(self.x_gt):
# TODO:
# returning proper empty annotations would require
# creating empty tensors of appropriate shape and
# type on an appropriate device;
# we return None so far to indicate empty annotations
return None
return PackedCseAnnotations(
x_gt=torch.cat(self.x_gt, 0),
y_gt=torch.cat(self.y_gt, 0),
vertex_mesh_ids_gt=torch.cat(self.vertex_mesh_ids_gt, 0),
vertex_ids_gt=torch.cat(self.vertex_ids_gt, 0),
# ignore segmentation annotations, if not all the instances contain those
coarse_segm_gt=(
torch.cat(self.s_gt, 0) if len(self.s_gt) == len(self.bbox_xywh_gt) else None
),
bbox_xywh_gt=torch.cat(self.bbox_xywh_gt, 0),
bbox_xywh_est=torch.cat(self.bbox_xywh_est, 0),
point_bbox_with_dp_indices=torch.cat(self.point_bbox_with_dp_indices, 0),
point_bbox_indices=torch.cat(self.point_bbox_indices, 0),
bbox_indices=torch.as_tensor(
self.bbox_indices, dtype=torch.long, device=self.x_gt[0].device
),
)
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
# pyre-unsafe
from dataclasses import dataclass
from typing import Any, Iterable, List, Optional
import torch
from torch.nn import functional as F
from detectron2.structures import Instances
@dataclass
class DataForMaskLoss:
"""
Contains mask GT and estimated data for proposals from multiple images:
"""
# tensor of size (K, H, W) containing GT labels
masks_gt: Optional[torch.Tensor] = None
# tensor of size (K, C, H, W) containing estimated scores
masks_est: Optional[torch.Tensor] = None
def extract_data_for_mask_loss_from_matches(
proposals_targets: Iterable[Instances], estimated_segm: torch.Tensor
) -> DataForMaskLoss:
"""
Extract data for mask loss from instances that contain matched GT and
estimated bounding boxes.
Args:
proposals_targets: Iterable[Instances]
matched GT and estimated results, each item in the iterable
corresponds to data in 1 image
estimated_segm: tensor(K, C, S, S) of float - raw unnormalized
segmentation scores, here S is the size to which GT masks are
to be resized
Return:
DataForMaskLoss with the following attributes (set only when GT masks are found):
masks_est: tensor(K, C, S, S) of float - class scores
masks_gt: tensor(K, S, S) of int64 - labels
"""
data = DataForMaskLoss()
masks_gt = []
offset = 0
assert estimated_segm.shape[2] == estimated_segm.shape[3], (
f"Expected estimated segmentation to have a square shape, "
f"but the actual shape is {estimated_segm.shape[2:]}"
)
mask_size = estimated_segm.shape[2]
num_proposals = sum(inst.proposal_boxes.tensor.size(0) for inst in proposals_targets)
num_estimated = estimated_segm.shape[0]
assert (
num_proposals == num_estimated
), "The number of proposals {} must be equal to the number of estimates {}".format(
num_proposals, num_estimated
)
for proposals_targets_per_image in proposals_targets:
n_i = proposals_targets_per_image.proposal_boxes.tensor.size(0)
if not n_i:
continue
gt_masks_per_image = proposals_targets_per_image.gt_masks.crop_and_resize(
proposals_targets_per_image.proposal_boxes.tensor, mask_size
).to(device=estimated_segm.device)
masks_gt.append(gt_masks_per_image)
offset += n_i
if masks_gt:
data.masks_est = estimated_segm
data.masks_gt = torch.cat(masks_gt, dim=0)
return data
class MaskLoss:
"""
Mask loss as cross-entropy for raw unnormalized scores given ground truth labels.
Mask ground truth labels are defined for the whole image and not only the
bounding box of interest. They are stored as objects that are assumed to implement
the `crop_and_resize` interface (e.g. BitMasks, PolygonMasks).
"""
def __call__(
self, proposals_with_gt: List[Instances], densepose_predictor_outputs: Any
) -> torch.Tensor:
"""
Computes segmentation loss as cross-entropy for raw unnormalized
scores given ground truth labels.
Args:
proposals_with_gt (list of Instances): detections with associated ground truth data
densepose_predictor_outputs: an object of a dataclass that contains predictor outputs
with estimated values; assumed to have the following attribute:
* coarse_segm (tensor of shape [N, D, S, S]): coarse segmentation estimates
as raw unnormalized scores
where N is the number of detections, S is the estimate size ( = width = height)
and D is the number of coarse segmentation channels.
Return:
Cross entropy for raw unnormalized scores for coarse segmentation given
ground truth labels from masks
"""
if not len(proposals_with_gt):
return self.fake_value(densepose_predictor_outputs)
# densepose outputs are computed for all images and all bounding boxes;
# i.e. if a batch has 4 images with (3, 1, 2, 1) proposals respectively,
# the outputs will have size(0) == 3+1+2+1 == 7
with torch.no_grad():
mask_loss_data = extract_data_for_mask_loss_from_matches(
proposals_with_gt, densepose_predictor_outputs.coarse_segm
)
if (mask_loss_data.masks_gt is None) or (mask_loss_data.masks_est is None):
return self.fake_value(densepose_predictor_outputs)
return F.cross_entropy(mask_loss_data.masks_est, mask_loss_data.masks_gt.long())
def fake_value(self, densepose_predictor_outputs: Any) -> torch.Tensor:
"""
Fake segmentation loss used when no suitable ground truth data
was found in a batch. The loss has a value 0 and is primarily used to
construct the computation graph, so that `DistributedDataParallel`
has similar graphs on all GPUs and can perform reduction properly.
Args:
densepose_predictor_outputs: DensePose predictor outputs, an object
of a dataclass that is assumed to have `coarse_segm`
attribute
Return:
Zero value loss with proper computation graph
"""
return densepose_predictor_outputs.coarse_segm.sum() * 0
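# Illustrative usage sketch (not part of the original module): MaskLoss is stateless,
# so it can be constructed once and called per batch. `proposals_with_gt` and
# `predictor_outputs` are placeholders for the inputs described in the docstrings above.
#
#   mask_loss = MaskLoss()
#   loss_value = mask_loss(proposals_with_gt, predictor_outputs)
#   # when the batch has no usable GT, the call degrades to
#   # mask_loss.fake_value(predictor_outputs): a zero with a valid computation graph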
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
# pyre-unsafe
from typing import Any, List
import torch
from detectron2.config import CfgNode
from detectron2.structures import Instances
from .mask import MaskLoss
from .segm import SegmentationLoss
class MaskOrSegmentationLoss:
"""
Mask or segmentation loss as cross-entropy for raw unnormalized scores
given ground truth labels. Ground truth labels are either defined by coarse
segmentation annotation, or by mask annotation, depending on the config
value MODEL.ROI_DENSEPOSE_HEAD.COARSE_SEGM_TRAINED_BY_MASKS
"""
def __init__(self, cfg: CfgNode):
"""
Initialize segmentation loss from configuration options
Args:
cfg (CfgNode): configuration options
"""
self.segm_trained_by_masks = cfg.MODEL.ROI_DENSEPOSE_HEAD.COARSE_SEGM_TRAINED_BY_MASKS
if self.segm_trained_by_masks:
self.mask_loss = MaskLoss()
self.segm_loss = SegmentationLoss(cfg)
def __call__(
self,
proposals_with_gt: List[Instances],
densepose_predictor_outputs: Any,
packed_annotations: Any,
) -> torch.Tensor:
"""
Compute segmentation loss as cross-entropy between aligned unnormalized
score estimates and ground truth; with ground truth given
either by masks, or by coarse segmentation annotations.
Args:
proposals_with_gt (list of Instances): detections with associated ground truth data
densepose_predictor_outputs: an object of a dataclass that contains predictor outputs
with estimated values; assumed to have the following attributes:
* coarse_segm - coarse segmentation estimates, tensor of shape [N, D, S, S]
packed_annotations: packed annotations for efficient loss computation
Return:
tensor: loss value as cross-entropy for raw unnormalized scores
given ground truth labels
"""
if self.segm_trained_by_masks:
return self.mask_loss(proposals_with_gt, densepose_predictor_outputs)
return self.segm_loss(proposals_with_gt, densepose_predictor_outputs, packed_annotations)
def fake_value(self, densepose_predictor_outputs: Any) -> torch.Tensor:
"""
Fake segmentation loss used when no suitable ground truth data
was found in a batch. The loss has a value 0 and is primarily used to
construct the computation graph, so that `DistributedDataParallel`
has similar graphs on all GPUs and can perform reduction properly.
Args:
densepose_predictor_outputs: DensePose predictor outputs, an object
of a dataclass that is assumed to have `coarse_segm`
attribute
Return:
Zero value loss with proper computation graph
"""
return densepose_predictor_outputs.coarse_segm.sum() * 0
# Copyright (c) Facebook, Inc. and its affiliates.
# pyre-unsafe
from detectron2.utils.registry import Registry
DENSEPOSE_LOSS_REGISTRY = Registry("DENSEPOSE_LOSS")
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
# pyre-unsafe
from typing import Any, List
import torch
from torch.nn import functional as F
from detectron2.config import CfgNode
from detectron2.structures import Instances
from .utils import resample_data
class SegmentationLoss:
"""
Segmentation loss as cross-entropy for raw unnormalized scores given ground truth
labels. Segmentation ground truth labels are defined for the bounding box of
interest at some fixed resolution [S, S], where
S = MODEL.ROI_DENSEPOSE_HEAD.HEATMAP_SIZE.
"""
def __init__(self, cfg: CfgNode):
"""
Initialize segmentation loss from configuration options
Args:
cfg (CfgNode): configuration options
"""
self.heatmap_size = cfg.MODEL.ROI_DENSEPOSE_HEAD.HEATMAP_SIZE
self.n_segm_chan = cfg.MODEL.ROI_DENSEPOSE_HEAD.NUM_COARSE_SEGM_CHANNELS
def __call__(
self,
proposals_with_gt: List[Instances],
densepose_predictor_outputs: Any,
packed_annotations: Any,
) -> torch.Tensor:
"""
Compute segmentation loss as cross-entropy on aligned segmentation
ground truth and estimated scores.
Args:
proposals_with_gt (list of Instances): detections with associated ground truth data
densepose_predictor_outputs: an object of a dataclass that contains predictor outputs
with estimated values; assumed to have the following attributes:
* coarse_segm - coarse segmentation estimates, tensor of shape [N, D, S, S]
packed_annotations: packed annotations for efficient loss computation;
the following attributes are used:
- coarse_segm_gt
- bbox_xywh_gt
- bbox_xywh_est
"""
if packed_annotations.coarse_segm_gt is None:
return self.fake_value(densepose_predictor_outputs)
coarse_segm_est = densepose_predictor_outputs.coarse_segm[packed_annotations.bbox_indices]
with torch.no_grad():
coarse_segm_gt = resample_data(
packed_annotations.coarse_segm_gt.unsqueeze(1),
packed_annotations.bbox_xywh_gt,
packed_annotations.bbox_xywh_est,
self.heatmap_size,
self.heatmap_size,
mode="nearest",
padding_mode="zeros",
).squeeze(1)
if self.n_segm_chan == 2:
coarse_segm_gt = coarse_segm_gt > 0
return F.cross_entropy(coarse_segm_est, coarse_segm_gt.long())
def fake_value(self, densepose_predictor_outputs: Any) -> torch.Tensor:
"""
Fake segmentation loss used when no suitable ground truth data
was found in a batch. The loss has a value 0 and is primarily used to
construct the computation graph, so that `DistributedDataParallel`
has similar graphs on all GPUs and can perform reduction properly.
Args:
densepose_predictor_outputs: DensePose predictor outputs, an object
of a dataclass that is assumed to have `coarse_segm`
attribute
Return:
Zero value loss with proper computation graph
"""
return densepose_predictor_outputs.coarse_segm.sum() * 0
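# Illustrative construction sketch (not part of the original module), assuming the
# standard DensePose config helper `add_densepose_config` is available:
#
#   from detectron2.config import get_cfg
#   from densepose import add_densepose_config
#
#   cfg = get_cfg()
#   add_densepose_config(cfg)
#   segm_loss = SegmentationLoss(cfg)  # uses HEATMAP_SIZE and NUM_COARSE_SEGM_CHANNELS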
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
# pyre-unsafe
from typing import Any, Dict, List
import torch
from torch import nn
from torch.nn import functional as F
from detectron2.config import CfgNode
from detectron2.structures import Instances
from densepose.data.meshes.catalog import MeshCatalog
from densepose.modeling.cse.utils import normalize_embeddings, squared_euclidean_distance_matrix
from densepose.structures.mesh import create_mesh
from .embed_utils import PackedCseAnnotations
from .utils import BilinearInterpolationHelper
class SoftEmbeddingLoss:
"""
Computes losses for estimated embeddings given annotated vertices.
Instances in a minibatch that correspond to the same mesh are grouped
together. For each group, loss is computed as cross-entropy for
unnormalized scores given ground truth mesh vertex ids.
Scores are based on:
1) squared distances between estimated vertex embeddings
and mesh vertex embeddings;
2) geodesic distances between vertices of a mesh
"""
def __init__(self, cfg: CfgNode):
"""
Initialize embedding loss from config
"""
self.embdist_gauss_sigma = cfg.MODEL.ROI_DENSEPOSE_HEAD.CSE.EMBEDDING_DIST_GAUSS_SIGMA
self.geodist_gauss_sigma = cfg.MODEL.ROI_DENSEPOSE_HEAD.CSE.GEODESIC_DIST_GAUSS_SIGMA
def __call__(
self,
proposals_with_gt: List[Instances],
densepose_predictor_outputs: Any,
packed_annotations: PackedCseAnnotations,
interpolator: BilinearInterpolationHelper,
embedder: nn.Module,
) -> Dict[str, torch.Tensor]:
"""
Produces losses for estimated embeddings given annotated vertices.
Embeddings for all the vertices of a mesh are computed by the embedder.
Embeddings for observed pixels are estimated by a predictor.
Losses are computed as cross-entropy for unnormalized scores given
ground truth vertex IDs. Scores are based on:
1) squared distances between estimated vertex embeddings
and mesh vertex embeddings;
2) geodesic distances between vertices of a mesh
Args:
proposals_with_gt (list of Instances): detections with associated
ground truth data; each item corresponds to instances detected
on 1 image; the number of items corresponds to the number of
images in a batch
densepose_predictor_outputs: an object of a dataclass that contains predictor
outputs with estimated values; assumed to have the following attributes:
* embedding - embedding estimates, tensor of shape [N, D, S, S], where
N = number of instances (= sum N_i, where N_i is the number of
instances on image i)
D = embedding space dimensionality (MODEL.ROI_DENSEPOSE_HEAD.CSE.EMBED_SIZE)
S = output size (width and height)
packed_annotations (PackedCseAnnotations): contains various data useful
for loss computation, each data is packed into a single tensor
interpolator (BilinearInterpolationHelper): bilinear interpolation helper
embedder (nn.Module): module that computes vertex embeddings for different meshes
Return:
dict(str -> tensor): losses for different mesh names
"""
losses = {}
for mesh_id_tensor in packed_annotations.vertex_mesh_ids_gt.unique():
mesh_id = mesh_id_tensor.item()
mesh_name = MeshCatalog.get_mesh_name(mesh_id)
# valid points are those that fall into estimated bbox
# and correspond to the current mesh
j_valid = interpolator.j_valid * ( # pyre-ignore[16]
packed_annotations.vertex_mesh_ids_gt == mesh_id
)
if not torch.any(j_valid):
continue
# extract estimated embeddings for valid points
# -> tensor [J, D]
vertex_embeddings_i = normalize_embeddings(
interpolator.extract_at_points(
densepose_predictor_outputs.embedding,
slice_fine_segm=slice(None),
w_ylo_xlo=interpolator.w_ylo_xlo[:, None], # pyre-ignore[16]
w_ylo_xhi=interpolator.w_ylo_xhi[:, None], # pyre-ignore[16]
w_yhi_xlo=interpolator.w_yhi_xlo[:, None], # pyre-ignore[16]
w_yhi_xhi=interpolator.w_yhi_xhi[:, None], # pyre-ignore[16]
)[j_valid, :]
)
# extract vertex ids for valid points
# -> tensor [J]
vertex_indices_i = packed_annotations.vertex_ids_gt[j_valid]
# embeddings for all mesh vertices
# -> tensor [K, D]
mesh_vertex_embeddings = embedder(mesh_name)
# softmax values of geodesic distances for GT mesh vertices
# -> tensor [J, K]
mesh = create_mesh(mesh_name, mesh_vertex_embeddings.device)
geodist_softmax_values = F.softmax(
mesh.geodists[vertex_indices_i] / (-self.geodist_gauss_sigma), dim=1
)
# logsoftmax values for valid points
# -> tensor [J, K]
embdist_logsoftmax_values = F.log_softmax(
squared_euclidean_distance_matrix(vertex_embeddings_i, mesh_vertex_embeddings)
/ (-self.embdist_gauss_sigma),
dim=1,
)
losses[mesh_name] = (-geodist_softmax_values * embdist_logsoftmax_values).sum(1).mean()
for mesh_name in embedder.mesh_names:
if mesh_name not in losses:
losses[mesh_name] = self.fake_value(
densepose_predictor_outputs, embedder, mesh_name
)
return losses
def fake_values(self, densepose_predictor_outputs: Any, embedder: nn.Module):
losses = {}
for mesh_name in embedder.mesh_names:
losses[mesh_name] = self.fake_value(densepose_predictor_outputs, embedder, mesh_name)
return losses
def fake_value(self, densepose_predictor_outputs: Any, embedder: nn.Module, mesh_name: str):
return densepose_predictor_outputs.embedding.sum() * 0 + embedder(mesh_name).sum() * 0
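# A minimal numeric sketch (not part of the original module) of the soft cross-entropy
# used in SoftEmbeddingLoss.__call__: targets are a softmax of negative geodesic
# distances, predictions are a log-softmax of negative squared embedding distances.
# All tensor sizes and sigma values below are toy assumptions.
def _example_soft_cross_entropy():  # pragma: no cover
    geodists = torch.rand(5, 7)  # [J, K]: geodesic distances from GT vertices to all K vertices
    embdists = torch.rand(5, 7)  # [J, K]: squared embedding distances for J annotated points
    geodist_gauss_sigma, embdist_gauss_sigma = 0.5, 0.01
    target = F.softmax(geodists / (-geodist_gauss_sigma), dim=1)
    logprob = F.log_softmax(embdists / (-embdist_gauss_sigma), dim=1)
    # per-point cross-entropy between the two distributions, averaged over points
    return (-target * logprob).sum(1).mean()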
# Copyright (c) Facebook, Inc. and its affiliates.
# pyre-unsafe
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple
import torch
from torch.nn import functional as F
from detectron2.structures import BoxMode, Instances
from densepose import DensePoseDataRelative
LossDict = Dict[str, torch.Tensor]
def _linear_interpolation_utilities(v_norm, v0_src, size_src, v0_dst, size_dst, size_z):
"""
Computes utility values for linear interpolation at points v.
The points are given as normalized offsets in the source interval
(v0_src, v0_src + size_src), more precisely:
v = v0_src + v_norm * size_src / 256.0
The computed utilities include lower points v_lo, upper points v_hi,
interpolation weights v_w and flags j_valid indicating whether the
points fall into the destination interval (v0_dst, v0_dst + size_dst).
Args:
v_norm (:obj: `torch.Tensor`): tensor of size N containing
normalized point offsets
v0_src (:obj: `torch.Tensor`): tensor of size N containing
left bounds of source intervals for normalized points
size_src (:obj: `torch.Tensor`): tensor of size N containing
source interval sizes for normalized points
v0_dst (:obj: `torch.Tensor`): tensor of size N containing
left bounds of destination intervals
size_dst (:obj: `torch.Tensor`): tensor of size N containing
destination interval sizes
size_z (int): interval size for data to be interpolated
Returns:
v_lo (:obj: `torch.Tensor`): int tensor of size N containing
indices of lower values used for interpolation, all values are
integers from [0, size_z - 1]
v_hi (:obj: `torch.Tensor`): int tensor of size N containing
indices of upper values used for interpolation, all values are
integers from [0, size_z - 1]
v_w (:obj: `torch.Tensor`): float tensor of size N containing
interpolation weights
j_valid (:obj: `torch.Tensor`): uint8 tensor of size N containing
0 for points outside the destination interval
(v0_dst, v0_dst + size_dst) and 1 otherwise
"""
v = v0_src + v_norm * size_src / 256.0
j_valid = (v - v0_dst >= 0) * (v - v0_dst < size_dst)
v_grid = (v - v0_dst) * size_z / size_dst
v_lo = v_grid.floor().long().clamp(min=0, max=size_z - 1)
v_hi = (v_lo + 1).clamp(max=size_z - 1)
v_grid = torch.min(v_hi.float(), v_grid)
v_w = v_grid - v_lo.float()
return v_lo, v_hi, v_w, j_valid
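# A minimal numeric sketch (not part of the original module) of the utility above:
# a point with normalized offset 128 inside a source interval (start 10, size 100)
# maps to absolute coordinate 10 + 128 * 100 / 256 = 60; relative to a destination
# interval (start 20, size 80) discretized into 112 bins it falls into bin 56 with
# zero fractional weight. All numbers are toy assumptions.
def _example_linear_interpolation_utilities():  # pragma: no cover
    v_lo, v_hi, v_w, j_valid = _linear_interpolation_utilities(
        v_norm=torch.tensor([128.0]),
        v0_src=torch.tensor([10.0]),
        size_src=torch.tensor([100.0]),
        v0_dst=torch.tensor([20.0]),
        size_dst=torch.tensor([80.0]),
        size_z=112,
    )
    # v_lo == 56, v_hi == 57, v_w == 0.0, j_valid == True
    return v_lo, v_hi, v_w, j_valid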
class BilinearInterpolationHelper:
"""
Args:
packed_annotations: object that contains packed annotations
j_valid (:obj: `torch.Tensor`): uint8 tensor of size M containing
0 for points to be discarded and 1 for points to be selected
y_lo (:obj: `torch.Tensor`): int tensor of indices of upper values
in z_est for each point
y_hi (:obj: `torch.Tensor`): int tensor of indices of lower values
in z_est for each point
x_lo (:obj: `torch.Tensor`): int tensor of indices of left values
in z_est for each point
x_hi (:obj: `torch.Tensor`): int tensor of indices of right values
in z_est for each point
w_ylo_xlo (:obj: `torch.Tensor`): float tensor of size M;
contains upper-left value weight for each point
w_ylo_xhi (:obj: `torch.Tensor`): float tensor of size M;
contains upper-right value weight for each point
w_yhi_xlo (:obj: `torch.Tensor`): float tensor of size M;
contains lower-left value weight for each point
w_yhi_xhi (:obj: `torch.Tensor`): float tensor of size M;
contains lower-right value weight for each point
"""
def __init__(
self,
packed_annotations: Any,
j_valid: torch.Tensor,
y_lo: torch.Tensor,
y_hi: torch.Tensor,
x_lo: torch.Tensor,
x_hi: torch.Tensor,
w_ylo_xlo: torch.Tensor,
w_ylo_xhi: torch.Tensor,
w_yhi_xlo: torch.Tensor,
w_yhi_xhi: torch.Tensor,
):
for k, v in locals().items():
if k != "self":
setattr(self, k, v)
@staticmethod
def from_matches(
packed_annotations: Any, densepose_outputs_size_hw: Tuple[int, int]
) -> "BilinearInterpolationHelper":
"""
Args:
packed_annotations: annotations packed into tensors, the following
attributes are required:
- bbox_xywh_gt
- bbox_xywh_est
- x_gt
- y_gt
- point_bbox_with_dp_indices
- point_bbox_indices
densepose_outputs_size_hw (tuple [int, int]): resolution of
DensePose predictor outputs (H, W)
Return:
An instance of `BilinearInterpolationHelper` used to perform
interpolation for the given annotation points and output resolution
"""
zh, zw = densepose_outputs_size_hw
x0_gt, y0_gt, w_gt, h_gt = packed_annotations.bbox_xywh_gt[
packed_annotations.point_bbox_with_dp_indices
].unbind(dim=1)
x0_est, y0_est, w_est, h_est = packed_annotations.bbox_xywh_est[
packed_annotations.point_bbox_with_dp_indices
].unbind(dim=1)
x_lo, x_hi, x_w, jx_valid = _linear_interpolation_utilities(
packed_annotations.x_gt, x0_gt, w_gt, x0_est, w_est, zw
)
y_lo, y_hi, y_w, jy_valid = _linear_interpolation_utilities(
packed_annotations.y_gt, y0_gt, h_gt, y0_est, h_est, zh
)
j_valid = jx_valid * jy_valid
w_ylo_xlo = (1.0 - x_w) * (1.0 - y_w)
w_ylo_xhi = x_w * (1.0 - y_w)
w_yhi_xlo = (1.0 - x_w) * y_w
w_yhi_xhi = x_w * y_w
return BilinearInterpolationHelper(
packed_annotations,
j_valid,
y_lo,
y_hi,
x_lo,
x_hi,
w_ylo_xlo, # pyre-ignore[6]
w_ylo_xhi,
# pyre-fixme[6]: Expected `Tensor` for 9th param but got `float`.
w_yhi_xlo,
w_yhi_xhi,
)
def extract_at_points(
self,
z_est,
slice_fine_segm=None,
w_ylo_xlo=None,
w_ylo_xhi=None,
w_yhi_xlo=None,
w_yhi_xhi=None,
):
"""
Extract estimated values of z_est at annotated point locations using
bilinear interpolation over top-left (y_lo, x_lo),
top-right (y_lo, x_hi), bottom-left (y_hi, x_lo) and bottom-right
(y_hi, x_hi) values in z_est with corresponding weights:
w_ylo_xlo, w_ylo_xhi, w_yhi_xlo and w_yhi_xhi.
Use slice_fine_segm to slice dim=1 in z_est.
"""
slice_fine_segm = (
self.packed_annotations.fine_segm_labels_gt
if slice_fine_segm is None
else slice_fine_segm
)
w_ylo_xlo = self.w_ylo_xlo if w_ylo_xlo is None else w_ylo_xlo
w_ylo_xhi = self.w_ylo_xhi if w_ylo_xhi is None else w_ylo_xhi
w_yhi_xlo = self.w_yhi_xlo if w_yhi_xlo is None else w_yhi_xlo
w_yhi_xhi = self.w_yhi_xhi if w_yhi_xhi is None else w_yhi_xhi
index_bbox = self.packed_annotations.point_bbox_indices
z_est_sampled = (
z_est[index_bbox, slice_fine_segm, self.y_lo, self.x_lo] * w_ylo_xlo
+ z_est[index_bbox, slice_fine_segm, self.y_lo, self.x_hi] * w_ylo_xhi
+ z_est[index_bbox, slice_fine_segm, self.y_hi, self.x_lo] * w_yhi_xlo
+ z_est[index_bbox, slice_fine_segm, self.y_hi, self.x_hi] * w_yhi_xhi
)
return z_est_sampled
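# A minimal sketch (not part of the original module) of extract_at_points, using a
# helper built directly from toy tensors; the SimpleNamespace stands in for packed
# annotations and all values below are assumptions for illustration only.
def _example_extract_at_points():  # pragma: no cover
    from types import SimpleNamespace

    n_points = 3
    packed = SimpleNamespace(
        fine_segm_labels_gt=torch.zeros(n_points, dtype=torch.long),
        point_bbox_indices=torch.zeros(n_points, dtype=torch.long),
    )
    helper = BilinearInterpolationHelper(
        packed,
        j_valid=torch.ones(n_points, dtype=torch.bool),
        y_lo=torch.zeros(n_points, dtype=torch.long),
        y_hi=torch.ones(n_points, dtype=torch.long),
        x_lo=torch.zeros(n_points, dtype=torch.long),
        x_hi=torch.ones(n_points, dtype=torch.long),
        w_ylo_xlo=torch.full((n_points,), 0.25),
        w_ylo_xhi=torch.full((n_points,), 0.25),
        w_yhi_xlo=torch.full((n_points,), 0.25),
        w_yhi_xhi=torch.full((n_points,), 0.25),
    )
    z_est = torch.rand(1, 2, 4, 4)  # (N, C, H, W) predictor scores for a single instance
    return helper.extract_at_points(z_est)  # tensor of shape [n_points]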
def resample_data(
z, bbox_xywh_src, bbox_xywh_dst, wout, hout, mode: str = "nearest", padding_mode: str = "zeros"
):
"""
Args:
z (:obj: `torch.Tensor`): tensor of size (N,C,H,W) with data to be
resampled
bbox_xywh_src (:obj: `torch.Tensor`): tensor of size (N,4) containing
source bounding boxes in format XYWH
bbox_xywh_dst (:obj: `torch.Tensor`): tensor of size (N,4) containing
destination bounding boxes in format XYWH
wout (int): output width
hout (int): output height
mode (str): interpolation mode passed to `grid_sample` (e.g. "nearest")
padding_mode (str): padding mode passed to `grid_sample` (e.g. "zeros")
Return:
zresampled (:obj: `torch.Tensor`): tensor of size (N, C, hout, wout)
with values of z resampled from the source boxes onto the destination boxes
"""
n = bbox_xywh_src.size(0)
assert n == bbox_xywh_dst.size(0), (
"The number of "
"source ROIs for resampling ({}) should be equal to the number "
"of destination ROIs ({})".format(bbox_xywh_src.size(0), bbox_xywh_dst.size(0))
)
x0src, y0src, wsrc, hsrc = bbox_xywh_src.unbind(dim=1)
x0dst, y0dst, wdst, hdst = bbox_xywh_dst.unbind(dim=1)
x0dst_norm = 2 * (x0dst - x0src) / wsrc - 1
y0dst_norm = 2 * (y0dst - y0src) / hsrc - 1
x1dst_norm = 2 * (x0dst + wdst - x0src) / wsrc - 1
y1dst_norm = 2 * (y0dst + hdst - y0src) / hsrc - 1
grid_w = torch.arange(wout, device=z.device, dtype=torch.float) / wout
grid_h = torch.arange(hout, device=z.device, dtype=torch.float) / hout
grid_w_expanded = grid_w[None, None, :].expand(n, hout, wout)
grid_h_expanded = grid_h[None, :, None].expand(n, hout, wout)
dx_expanded = (x1dst_norm - x0dst_norm)[:, None, None].expand(n, hout, wout)
dy_expanded = (y1dst_norm - y0dst_norm)[:, None, None].expand(n, hout, wout)
x0_expanded = x0dst_norm[:, None, None].expand(n, hout, wout)
y0_expanded = y0dst_norm[:, None, None].expand(n, hout, wout)
grid_x = grid_w_expanded * dx_expanded + x0_expanded
grid_y = grid_h_expanded * dy_expanded + y0_expanded
grid = torch.stack((grid_x, grid_y), dim=3)
# resample Z from (N, C, H, W) into (N, C, Hout, Wout)
zresampled = F.grid_sample(z, grid, mode=mode, padding_mode=padding_mode, align_corners=True)
return zresampled
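# A minimal sketch (not part of the original module): resampling a toy GT tensor from
# GT boxes onto matching estimated boxes at a fixed output resolution. All box values
# and sizes below are assumptions for illustration only.
def _example_resample_data():  # pragma: no cover
    z = torch.zeros(2, 1, 256, 256)  # (N, C, H, W) data defined over the source boxes
    bbox_xywh_src = torch.tensor(
        [[10.0, 20.0, 100.0, 150.0], [30.0, 40.0, 120.0, 160.0]]
    )  # GT boxes, XYWH
    bbox_xywh_dst = torch.tensor(
        [[12.0, 18.0, 110.0, 140.0], [28.0, 44.0, 118.0, 158.0]]
    )  # estimated boxes, XYWH
    out = resample_data(z, bbox_xywh_src, bbox_xywh_dst, wout=112, hout=112, mode="nearest")
    return out.shape  # torch.Size([2, 1, 112, 112])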
class AnnotationsAccumulator(ABC):
"""
Abstract class for an accumulator for annotations that can produce
dense annotations packed into tensors.
"""
@abstractmethod
def accumulate(self, instances_one_image: Instances):
"""
Accumulate instances data for one image
Args:
instances_one_image (Instances): instances data to accumulate
"""
pass
@abstractmethod
def pack(self) -> Any:
"""
Pack data into tensors
"""
pass
@dataclass
class PackedChartBasedAnnotations:
"""
Packed annotations for chart-based model training. The following attributes
are defined:
- fine_segm_labels_gt (tensor [K] of `int64`): GT fine segmentation point labels
- x_gt (tensor [K] of `float32`): GT normalized X point coordinates
- y_gt (tensor [K] of `float32`): GT normalized Y point coordinates
- u_gt (tensor [K] of `float32`): GT point U values
- v_gt (tensor [K] of `float32`): GT point V values
- coarse_segm_gt (tensor [N, S, S] of `float32`): GT segmentation for bounding boxes
- bbox_xywh_gt (tensor [N, 4] of `float32`): selected GT bounding boxes in
XYWH format
- bbox_xywh_est (tensor [N, 4] of `float32`): selected matching estimated
bounding boxes in XYWH format
- point_bbox_with_dp_indices (tensor [K] of `int64`): indices of bounding boxes
with DensePose annotations that correspond to the point data
- point_bbox_indices (tensor [K] of `int64`): indices of bounding boxes
(not necessarily the selected ones with DensePose data) that correspond
to the point data
- bbox_indices (tensor [N] of `int64`): global indices of selected bounding
boxes with DensePose annotations; these indices could be used to access
features that are computed for all bounding boxes, not only the ones with
DensePose annotations.
Here K is the total number of points and N is the total number of instances
with DensePose annotations.
"""
fine_segm_labels_gt: torch.Tensor
x_gt: torch.Tensor
y_gt: torch.Tensor
u_gt: torch.Tensor
v_gt: torch.Tensor
coarse_segm_gt: Optional[torch.Tensor]
bbox_xywh_gt: torch.Tensor
bbox_xywh_est: torch.Tensor
point_bbox_with_dp_indices: torch.Tensor
point_bbox_indices: torch.Tensor
bbox_indices: torch.Tensor
class ChartBasedAnnotationsAccumulator(AnnotationsAccumulator):
"""
Accumulates annotations by batches that correspond to objects detected on
individual images. Can pack them together into single tensors.
"""
def __init__(self):
self.i_gt = []
self.x_gt = []
self.y_gt = []
self.u_gt = []
self.v_gt = []
self.s_gt = []
self.bbox_xywh_gt = []
self.bbox_xywh_est = []
self.point_bbox_with_dp_indices = []
self.point_bbox_indices = []
self.bbox_indices = []
self.nxt_bbox_with_dp_index = 0
self.nxt_bbox_index = 0
def accumulate(self, instances_one_image: Instances):
"""
Accumulate instances data for one image
Args:
instances_one_image (Instances): instances data to accumulate
"""
boxes_xywh_est = BoxMode.convert(
instances_one_image.proposal_boxes.tensor.clone(), BoxMode.XYXY_ABS, BoxMode.XYWH_ABS
)
boxes_xywh_gt = BoxMode.convert(
instances_one_image.gt_boxes.tensor.clone(), BoxMode.XYXY_ABS, BoxMode.XYWH_ABS
)
n_matches = len(boxes_xywh_gt)
assert n_matches == len(
boxes_xywh_est
), f"Got {len(boxes_xywh_est)} proposal boxes and {len(boxes_xywh_gt)} GT boxes"
if not n_matches:
# no detection-to-GT matches for this image
return
if (
not hasattr(instances_one_image, "gt_densepose")
or instances_one_image.gt_densepose is None
):
# no densepose GT for the detections, just increase the bbox index
self.nxt_bbox_index += n_matches
return
for box_xywh_est, box_xywh_gt, dp_gt in zip(
boxes_xywh_est, boxes_xywh_gt, instances_one_image.gt_densepose
):
if (dp_gt is not None) and (len(dp_gt.x) > 0):
# pyre-fixme[6]: For 1st argument expected `Tensor` but got `float`.
# pyre-fixme[6]: For 2nd argument expected `Tensor` but got `float`.
self._do_accumulate(box_xywh_gt, box_xywh_est, dp_gt)
self.nxt_bbox_index += 1
def _do_accumulate(
self, box_xywh_gt: torch.Tensor, box_xywh_est: torch.Tensor, dp_gt: DensePoseDataRelative
):
"""
Accumulate instances data for one image, given that the data is not empty
Args:
box_xywh_gt (tensor): GT bounding box
box_xywh_est (tensor): estimated bounding box
dp_gt (DensePoseDataRelative): GT densepose data
"""
self.i_gt.append(dp_gt.i)
self.x_gt.append(dp_gt.x)
self.y_gt.append(dp_gt.y)
self.u_gt.append(dp_gt.u)
self.v_gt.append(dp_gt.v)
if hasattr(dp_gt, "segm"):
self.s_gt.append(dp_gt.segm.unsqueeze(0))
self.bbox_xywh_gt.append(box_xywh_gt.view(-1, 4))
self.bbox_xywh_est.append(box_xywh_est.view(-1, 4))
self.point_bbox_with_dp_indices.append(
torch.full_like(dp_gt.i, self.nxt_bbox_with_dp_index)
)
self.point_bbox_indices.append(torch.full_like(dp_gt.i, self.nxt_bbox_index))
self.bbox_indices.append(self.nxt_bbox_index)
self.nxt_bbox_with_dp_index += 1
def pack(self) -> Optional[PackedChartBasedAnnotations]:
"""
Pack data into tensors
"""
if not len(self.i_gt):
# TODO:
# returning proper empty annotations would require
# creating empty tensors of appropriate shape and
# type on an appropriate device;
# we return None so far to indicate empty annotations
return None
return PackedChartBasedAnnotations(
fine_segm_labels_gt=torch.cat(self.i_gt, 0).long(),
x_gt=torch.cat(self.x_gt, 0),
y_gt=torch.cat(self.y_gt, 0),
u_gt=torch.cat(self.u_gt, 0),
v_gt=torch.cat(self.v_gt, 0),
# ignore segmentation annotations, if not all the instances contain those
coarse_segm_gt=(
torch.cat(self.s_gt, 0) if len(self.s_gt) == len(self.bbox_xywh_gt) else None
),
bbox_xywh_gt=torch.cat(self.bbox_xywh_gt, 0),
bbox_xywh_est=torch.cat(self.bbox_xywh_est, 0),
point_bbox_with_dp_indices=torch.cat(self.point_bbox_with_dp_indices, 0).long(),
point_bbox_indices=torch.cat(self.point_bbox_indices, 0).long(),
bbox_indices=torch.as_tensor(
self.bbox_indices, dtype=torch.long, device=self.x_gt[0].device
).long(),
)
def extract_packed_annotations_from_matches(
proposals_with_targets: List[Instances], accumulator: AnnotationsAccumulator
) -> Any:
for proposals_targets_per_image in proposals_with_targets:
accumulator.accumulate(proposals_targets_per_image)
return accumulator.pack()
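# Illustrative usage sketch (not part of the original module): packing chart-based
# annotations for a batch and building the interpolation helper at the predictor
# output resolution. `proposals_with_gt` is a placeholder for matched Instances and
# the output size (112, 112) is a toy assumption.
#
#   accumulator = ChartBasedAnnotationsAccumulator()
#   packed = extract_packed_annotations_from_matches(proposals_with_gt, accumulator)
#   if packed is not None:
#       interpolator = BilinearInterpolationHelper.from_matches(packed, (112, 112))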
def sample_random_indices(
n_indices: int, n_samples: int, device: Optional[torch.device] = None
) -> Optional[torch.Tensor]:
"""
Samples `n_samples` random indices from range `[0..n_indices - 1]`.
If `n_indices` is smaller than `n_samples`, returns `None` meaning that all indices
are selected.
Args:
n_indices (int): total number of indices
n_samples (int): number of indices to sample
device (torch.device): the desired device of returned tensor
Return:
Tensor of selected vertex indices, or `None`, if all vertices are selected
"""
if (n_samples <= 0) or (n_indices <= n_samples):
return None
indices = torch.randperm(n_indices, device=device)[:n_samples]
return indices
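# A minimal sketch (not part of the original module): subsample at most 100 of 5000
# indices; when there are fewer indices than requested samples, None means "use all".
def _example_sample_random_indices():  # pragma: no cover
    subset = sample_random_indices(n_indices=5000, n_samples=100)
    assert subset is not None and subset.shape == (100,)
    assert sample_random_indices(n_indices=50, n_samples=100) is None
    return subset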
# Copyright (c) Facebook, Inc. and its affiliates.
# pyre-unsafe
from .chart import DensePoseChartPredictor
from .chart_confidence import DensePoseChartConfidencePredictorMixin
from .chart_with_confidence import DensePoseChartWithConfidencePredictor
from .cse import DensePoseEmbeddingPredictor
from .cse_confidence import DensePoseEmbeddingConfidencePredictorMixin
from .cse_with_confidence import DensePoseEmbeddingWithConfidencePredictor
from .registry import DENSEPOSE_PREDICTOR_REGISTRY
# Copyright (c) Facebook, Inc. and its affiliates.
# pyre-unsafe
import torch
from torch import nn
from detectron2.config import CfgNode
from detectron2.layers import ConvTranspose2d, interpolate
from ...structures import DensePoseChartPredictorOutput
from ..utils import initialize_module_params
from .registry import DENSEPOSE_PREDICTOR_REGISTRY
@DENSEPOSE_PREDICTOR_REGISTRY.register()
class DensePoseChartPredictor(nn.Module):
"""
Predictor (last layers of a DensePose model) that takes DensePose head outputs as an input
and produces 4 tensors which represent DensePose results for predefined body parts
(patches / charts):
* coarse segmentation, a tensor of shape [N, K, Hout, Wout]
* fine segmentation, a tensor of shape [N, C, Hout, Wout]
* U coordinates, a tensor of shape [N, C, Hout, Wout]
* V coordinates, a tensor of shape [N, C, Hout, Wout]
where
- N is the number of instances
- K is the number of coarse segmentation channels (
2 = foreground / background,
15 = one of 14 body parts / background)
- C is the number of fine segmentation channels (
24 fine body parts / background)
- Hout and Wout are height and width of predictions
"""
def __init__(self, cfg: CfgNode, input_channels: int):
"""
Initialize predictor using configuration options
Args:
cfg (CfgNode): configuration options
input_channels (int): input tensor size along the channel dimension
"""
super().__init__()
dim_in = input_channels
n_segm_chan = cfg.MODEL.ROI_DENSEPOSE_HEAD.NUM_COARSE_SEGM_CHANNELS
dim_out_patches = cfg.MODEL.ROI_DENSEPOSE_HEAD.NUM_PATCHES + 1
kernel_size = cfg.MODEL.ROI_DENSEPOSE_HEAD.DECONV_KERNEL
# coarse segmentation
self.ann_index_lowres = ConvTranspose2d(
dim_in, n_segm_chan, kernel_size, stride=2, padding=int(kernel_size / 2 - 1)
)
# fine segmentation
self.index_uv_lowres = ConvTranspose2d(
dim_in, dim_out_patches, kernel_size, stride=2, padding=int(kernel_size / 2 - 1)
)
# U
self.u_lowres = ConvTranspose2d(
dim_in, dim_out_patches, kernel_size, stride=2, padding=int(kernel_size / 2 - 1)
)
# V
self.v_lowres = ConvTranspose2d(
dim_in, dim_out_patches, kernel_size, stride=2, padding=int(kernel_size / 2 - 1)
)
self.scale_factor = cfg.MODEL.ROI_DENSEPOSE_HEAD.UP_SCALE
initialize_module_params(self)
def interp2d(self, tensor_nchw: torch.Tensor):
"""
Bilinear interpolation method to be used for upscaling
Args:
tensor_nchw (tensor): tensor of shape (N, C, H, W)
Return:
tensor of shape (N, C, Hout, Wout), where Hout and Wout are computed
by applying the scale factor to H and W
"""
return interpolate(
tensor_nchw, scale_factor=self.scale_factor, mode="bilinear", align_corners=False
)
def forward(self, head_outputs: torch.Tensor):
"""
Perform forward step on DensePose head outputs
Args:
head_outputs (tensor): DensePose head outputs, tensor of shape [N, D, H, W]
Return:
An instance of DensePoseChartPredictorOutput
"""
return DensePoseChartPredictorOutput(
coarse_segm=self.interp2d(self.ann_index_lowres(head_outputs)),
fine_segm=self.interp2d(self.index_uv_lowres(head_outputs)),
u=self.interp2d(self.u_lowres(head_outputs)),
v=self.interp2d(self.v_lowres(head_outputs)),
)
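# Illustrative usage sketch (not part of the original module), assuming the standard
# DensePose config helper `add_densepose_config` and a toy head-output tensor; the
# channel count 512 and spatial size 14 below are assumptions for illustration only.
#
#   from detectron2.config import get_cfg
#   from densepose import add_densepose_config
#
#   cfg = get_cfg()
#   add_densepose_config(cfg)
#   predictor = DensePoseChartPredictor(cfg, input_channels=512)
#   outputs = predictor(torch.rand(2, 512, 14, 14))
#   # outputs.coarse_segm, outputs.fine_segm, outputs.u, outputs.v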
# Copyright (c) Facebook, Inc. and its affiliates.
# pyre-unsafe
from typing import Any
import torch
from torch.nn import functional as F
from detectron2.config import CfgNode
from detectron2.layers import ConvTranspose2d
from ...structures import decorate_predictor_output_class_with_confidences
from ..confidence import DensePoseConfidenceModelConfig, DensePoseUVConfidenceType
from ..utils import initialize_module_params
class DensePoseChartConfidencePredictorMixin:
"""
Predictor contains the last layers of a DensePose model that take DensePose head
outputs as an input and produce model outputs. Confidence predictor mixin is used
to generate confidences for segmentation and UV tensors estimated by some
base predictor. Several assumptions need to hold for the base predictor:
1) the `forward` method must return SIUV tuple as the first result (
S = coarse segmentation, I = fine segmentation, U and V are intrinsic
chart coordinates)
2) `interp2d` method must be defined to perform bilinear interpolation;
the same method is typically used for SIUV and confidences
Confidence predictor mixin provides confidence estimates, as described in:
N. Neverova et al., Correlated Uncertainty for Learning Dense Correspondences
from Noisy Labels, NeurIPS 2019
A. Sanakoyeu et al., Transferring Dense Pose to Proximal Animal Classes, CVPR 2020
"""
def __init__(self, cfg: CfgNode, input_channels: int):
"""
Initialize confidence predictor using configuration options.
Args:
cfg (CfgNode): configuration options
input_channels (int): number of input channels
"""
# we rely on base predictor to call nn.Module.__init__
super().__init__(cfg, input_channels) # pyre-ignore[19]
self.confidence_model_cfg = DensePoseConfidenceModelConfig.from_cfg(cfg)
self._initialize_confidence_estimation_layers(cfg, input_channels)
self._registry = {}
initialize_module_params(self) # pyre-ignore[6]
def _initialize_confidence_estimation_layers(self, cfg: CfgNode, dim_in: int):
"""
Initialize confidence estimation layers based on configuration options
Args:
cfg (CfgNode): configuration options
dim_in (int): number of input channels
"""
dim_out_patches = cfg.MODEL.ROI_DENSEPOSE_HEAD.NUM_PATCHES + 1
kernel_size = cfg.MODEL.ROI_DENSEPOSE_HEAD.DECONV_KERNEL
if self.confidence_model_cfg.uv_confidence.enabled:
if self.confidence_model_cfg.uv_confidence.type == DensePoseUVConfidenceType.IID_ISO:
self.sigma_2_lowres = ConvTranspose2d( # pyre-ignore[16]
dim_in, dim_out_patches, kernel_size, stride=2, padding=int(kernel_size / 2 - 1)
)
elif (
self.confidence_model_cfg.uv_confidence.type
== DensePoseUVConfidenceType.INDEP_ANISO
):
self.sigma_2_lowres = ConvTranspose2d(
dim_in, dim_out_patches, kernel_size, stride=2, padding=int(kernel_size / 2 - 1)
)
self.kappa_u_lowres = ConvTranspose2d( # pyre-ignore[16]
dim_in, dim_out_patches, kernel_size, stride=2, padding=int(kernel_size / 2 - 1)
)
self.kappa_v_lowres = ConvTranspose2d( # pyre-ignore[16]
dim_in, dim_out_patches, kernel_size, stride=2, padding=int(kernel_size / 2 - 1)
)
else:
raise ValueError(
f"Unknown confidence model type: "
f"{self.confidence_model_cfg.confidence_model_type}"
)
if self.confidence_model_cfg.segm_confidence.enabled:
self.fine_segm_confidence_lowres = ConvTranspose2d( # pyre-ignore[16]
dim_in, 1, kernel_size, stride=2, padding=int(kernel_size / 2 - 1)
)
self.coarse_segm_confidence_lowres = ConvTranspose2d( # pyre-ignore[16]
dim_in, 1, kernel_size, stride=2, padding=int(kernel_size / 2 - 1)
)
def forward(self, head_outputs: torch.Tensor):
"""
Perform forward operation on head outputs used as inputs for the predictor.
Calls forward method from the base predictor and uses its outputs to compute
confidences.
Args:
head_outputs (Tensor): head outputs used as predictor inputs
Return:
An instance of outputs with confidences,
see `decorate_predictor_output_class_with_confidences`
"""
# assuming base class returns SIUV estimates in its first result
base_predictor_outputs = super().forward(head_outputs) # pyre-ignore[16]
# create output instance by extending base predictor outputs:
output = self._create_output_instance(base_predictor_outputs)
if self.confidence_model_cfg.uv_confidence.enabled:
if self.confidence_model_cfg.uv_confidence.type == DensePoseUVConfidenceType.IID_ISO:
# assuming base class defines interp2d method for bilinear interpolation
output.sigma_2 = self.interp2d(self.sigma_2_lowres(head_outputs)) # pyre-ignore[16]
elif (
self.confidence_model_cfg.uv_confidence.type
== DensePoseUVConfidenceType.INDEP_ANISO
):
# assuming base class defines interp2d method for bilinear interpolation
output.sigma_2 = self.interp2d(self.sigma_2_lowres(head_outputs))
output.kappa_u = self.interp2d(self.kappa_u_lowres(head_outputs)) # pyre-ignore[16]
output.kappa_v = self.interp2d(self.kappa_v_lowres(head_outputs)) # pyre-ignore[16]
else:
raise ValueError(
f"Unknown confidence model type: "
f"{self.confidence_model_cfg.confidence_model_type}"
)
if self.confidence_model_cfg.segm_confidence.enabled:
# base predictor outputs are assumed to have `fine_segm` and `coarse_segm` attributes
# base predictor is assumed to define `interp2d` method for bilinear interpolation
output.fine_segm_confidence = (
F.softplus(
self.interp2d(self.fine_segm_confidence_lowres(head_outputs)) # pyre-ignore[16]
)
+ self.confidence_model_cfg.segm_confidence.epsilon
)
output.fine_segm = base_predictor_outputs.fine_segm * torch.repeat_interleave(
output.fine_segm_confidence, base_predictor_outputs.fine_segm.shape[1], dim=1
)
output.coarse_segm_confidence = (
F.softplus(
self.interp2d(
self.coarse_segm_confidence_lowres(head_outputs) # pyre-ignore[16]
)
)
+ self.confidence_model_cfg.segm_confidence.epsilon
)
output.coarse_segm = base_predictor_outputs.coarse_segm * torch.repeat_interleave(
output.coarse_segm_confidence, base_predictor_outputs.coarse_segm.shape[1], dim=1
)
return output
def _create_output_instance(self, base_predictor_outputs: Any):
"""
Create an instance of predictor outputs by copying the outputs from the
base predictor and initializing confidence
Args:
base_predictor_outputs: an instance of base predictor outputs
(the outputs type is assumed to be a dataclass)
Return:
An instance of outputs with confidences
"""
PredictorOutput = decorate_predictor_output_class_with_confidences(
type(base_predictor_outputs) # pyre-ignore[6]
)
# base_predictor_outputs is assumed to be a dataclass
# reassign all the fields from base_predictor_outputs (no deep copy!), add new fields
output = PredictorOutput(
**base_predictor_outputs.__dict__,
coarse_segm_confidence=None,
fine_segm_confidence=None,
sigma_1=None,
sigma_2=None,
kappa_u=None,
kappa_v=None,
)
return output
# Copyright (c) Facebook, Inc. and its affiliates.
# pyre-unsafe
from . import DensePoseChartConfidencePredictorMixin, DensePoseChartPredictor
from .registry import DENSEPOSE_PREDICTOR_REGISTRY
@DENSEPOSE_PREDICTOR_REGISTRY.register()
class DensePoseChartWithConfidencePredictor(
DensePoseChartConfidencePredictorMixin, DensePoseChartPredictor
):
"""
Predictor that combines chart and chart confidence estimation
"""
pass