Unverified Commit 90071fe4 authored by NielsRogge, committed by GitHub

Improve DETR models (#19644)

* Improve DETR models

* Fix Deformable DETR loss and matcher

* Fixup

* Fix integration tests

* Improve variable names

* Apply suggestion

* Fix copies

* Fix DeformableDetrLoss

* Make Conditional DETR copy from Deformable DETR

* Copy from deformable detr's hungarian matcher

* Fix bug
parent 072dfdae
@@ -23,7 +23,7 @@ The abstract from the paper is the following:
Tips:
- - One can use the [`AutoFeatureExtractor`] API to prepare images (and optional targets) for the model. This will instantiate a [`DetrFeatureExtractor`] behind the scenes.
+ - One can use [`DeformableDetrFeatureExtractor`] to prepare images (and optional targets) for the model.
- Training Deformable DETR is equivalent to training the original [DETR](detr) model. Demo notebooks can be found [here](https://github.com/NielsRogge/Transformers-Tutorials/tree/master/DETR).
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/deformable_detr_architecture.png"
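As a minimal usage sketch of the updated tip (the `SenseTime/deformable-detr` checkpoint and COCO sample image are assumptions for illustration, not part of this diff):

```python
from PIL import Image
import requests
from transformers import AutoFeatureExtractor

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

# AutoFeatureExtractor resolves to DeformableDetrFeatureExtractor for this checkpoint
feature_extractor = AutoFeatureExtractor.from_pretrained("SenseTime/deformable-detr")
encoding = feature_extractor(images=image, return_tensors="pt")  # pixel_values + pixel_mask
print(encoding["pixel_values"].shape)  # e.g. torch.Size([1, 3, 800, 1066])
```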
@@ -110,6 +110,8 @@ class ConditionalDetrConfig(PretrainedConfig):
Relative weight of the generalized IoU loss in the object detection loss.
eos_coefficient (`float`, *optional*, defaults to 0.1):
Relative classification weight of the 'no-object' class in the object detection loss.
+ focal_alpha (`float`, *optional*, defaults to 0.25):
+ Alpha parameter in the focal loss.
Examples:
@@ -44,8 +44,8 @@ def center_to_corners_format(x):
Converts a PyTorch tensor of bounding boxes of center format (center_x, center_y, width, height) to corners format
(x_0, y_0, x_1, y_1).
"""
- x_c, y_c, w, h = x.unbind(-1)
- b = [(x_c - 0.5 * w), (y_c - 0.5 * h), (x_c + 0.5 * w), (y_c + 0.5 * h)]
+ center_x, center_y, width, height = x.unbind(-1)
+ b = [(center_x - 0.5 * width), (center_y - 0.5 * height), (center_x + 0.5 * width), (center_y + 0.5 * height)]
return torch.stack(b, dim=-1)
@@ -114,6 +114,8 @@ class DeformableDetrConfig(PretrainedConfig):
with_box_refine (`bool`, *optional*, defaults to `False`):
Whether to apply iterative bounding box refinement, where each decoder layer refines the bounding boxes
based on the predictions from the previous layer.
+ focal_alpha (`float`, *optional*, defaults to 0.25):
+ Alpha parameter in the focal loss.
Examples:
@@ -174,6 +176,7 @@ class DeformableDetrConfig(PretrainedConfig):
bbox_loss_coefficient=5,
giou_loss_coefficient=2,
eos_coefficient=0.1,
+ focal_alpha=0.25,
**kwargs
):
self.num_queries = num_queries
@@ -216,6 +219,7 @@ class DeformableDetrConfig(PretrainedConfig):
self.bbox_loss_coefficient = bbox_loss_coefficient
self.giou_loss_coefficient = giou_loss_coefficient
self.eos_coefficient = eos_coefficient
+ self.focal_alpha = focal_alpha
super().__init__(is_encoder_decoder=is_encoder_decoder, **kwargs)
@property
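A quick sketch of how the new parameter surfaces to users, built only from classes visible in this diff (no pretrained weights needed):

```python
from transformers import DeformableDetrConfig, DeformableDetrForObjectDetection

# focal_alpha weights the positive class in the focal loss; 0.25 is the default
config = DeformableDetrConfig(focal_alpha=0.25)
model = DeformableDetrForObjectDetection(config)
print(model.config.focal_alpha)  # 0.25
```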
@@ -44,8 +44,8 @@ def center_to_corners_format(x):
Converts a PyTorch tensor of bounding boxes of center format (center_x, center_y, width, height) to corners format
(x_0, y_0, x_1, y_1).
"""
- x_c, y_c, w, h = x.unbind(-1)
- b = [(x_c - 0.5 * w), (y_c - 0.5 * h), (x_c + 0.5 * w), (y_c + 0.5 * h)]
+ center_x, center_y, width, height = x.unbind(-1)
+ b = [(center_x - 0.5 * width), (center_y - 0.5 * height), (center_x + 0.5 * width), (center_y + 0.5 * height)]
return torch.stack(b, dim=-1)
@@ -35,7 +35,6 @@ from ...file_utils import (
is_scipy_available,
is_timm_available,
is_torch_cuda_available,
- is_vision_available,
replace_return_docstrings,
requires_backends,
)
@@ -111,9 +110,6 @@ class MultiScaleDeformableAttentionFunction(Function):
if is_scipy_available():
from scipy.optimize import linear_sum_assignment
- if is_vision_available():
- from transformers.models.detr.feature_extraction_detr import center_to_corners_format
if is_timm_available():
from timm import create_model
@@ -1952,7 +1948,7 @@ class DeformableDetrForObjectDetection(DeformableDetrPreTrainedModel):
criterion = DeformableDetrLoss(
matcher=matcher,
num_classes=self.config.num_labels,
- eos_coef=self.config.eos_coefficient,
+ focal_alpha=self.config.focal_alpha,
losses=losses,
)
criterion.to(self.device)
@@ -2065,46 +2061,38 @@ def sigmoid_focal_loss(inputs, targets, num_boxes, alpha: float = 0.25, gamma: f
return loss.mean(1).sum() / num_boxes
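For context, a self-contained sketch of what `sigmoid_focal_loss` computes, following the standard formulation of Lin et al. (2017); the helper's full body is not shown in this hunk, so treat this as an illustration rather than the exact source:

```python
import torch
from torch import nn

def sigmoid_focal_loss_sketch(inputs, targets, num_boxes, alpha=0.25, gamma=2.0):
    # Binary focal loss on logits: down-weights well-classified examples
    prob = inputs.sigmoid()
    ce_loss = nn.functional.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
    p_t = prob * targets + (1 - prob) * (1 - targets)
    loss = ce_loss * ((1 - p_t) ** gamma)
    if alpha >= 0:
        alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
        loss = alpha_t * loss
    # same normalization as the return line above: mean over queries, summed, / num_boxes
    return loss.mean(1).sum() / num_boxes
```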
# taken from https://github.com/facebookresearch/detr/blob/master/models/detr.py
class DeformableDetrLoss(nn.Module):
"""
- This class computes the losses for DeformableDetrForObjectDetection. The process happens in two steps: 1) we
+ This class computes the losses for `DeformableDetrForObjectDetection`. The process happens in two steps: 1) we
compute hungarian assignment between ground truth boxes and the outputs of the model 2) we supervise each pair of
- matched ground-truth / prediction (supervise class and box)
+ matched ground-truth / prediction (supervise class and box).
+ Args:
+ matcher (`DeformableDetrHungarianMatcher`):
+ Module able to compute a matching between targets and proposals.
+ num_classes (`int`):
+ Number of object categories, omitting the special no-object category.
+ focal_alpha (`float`):
+ Alpha parameter in focal loss.
+ losses (`List[str]`):
+ List of all the losses to be applied. See `get_loss` for a list of all available losses.
"""
- def __init__(self, matcher, num_classes, eos_coef, losses, focal_alpha=0.25):
- """
- Create the criterion.
- A note on the num_classes parameter (copied from original repo in detr.py): "the naming of the `num_classes`
- parameter of the criterion is somewhat misleading. it indeed corresponds to `max_obj_id + 1`, where max_obj_id
- is the maximum id for a class in your dataset. For example, COCO has a max_obj_id of 90, so we pass
- `num_classes` to be 91. As another example, for a dataset that has a single class with id 1, you should pass
- `num_classes` to be 2 (max_obj_id + 1). For more details on this, check the following discussion
- https://github.com/facebookresearch/detr/issues/108#issuecomment-650269223"
- Parameters:
- matcher: module able to compute a matching between targets and proposals.
- num_classes: number of object categories, omitting the special no-object category.
- eos_coef: relative classification weight applied to the no-object category.
- losses: list of all the losses to be applied. See get_loss for list of available losses.
- focal_alpha: alpha in Focal Loss.
- """
+ def __init__(self, matcher, num_classes, focal_alpha, losses):
super().__init__()
self.matcher = matcher
self.num_classes = num_classes
- self.losses = losses
self.focal_alpha = focal_alpha
+ self.losses = losses
- def loss_labels(self, outputs, targets, indices, num_boxes, log=True):
- """Classification loss (NLL)
- targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes]
+ # removed logging parameter, which was part of the original implementation
+ def loss_labels(self, outputs, targets, indices, num_boxes):
+ """
+ Classification loss (Binary focal loss) targets dicts must contain the key "class_labels" containing a tensor
+ of dim [nb_target_boxes]
"""
if "logits" not in outputs:
- raise ValueError("No logits were found in the outputs")
+ raise KeyError("No logits were found in the outputs")
source_logits = outputs["logits"]
idx = self._get_source_permutation_idx(indices)
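For readers following the loss refactor, a hypothetical example of the `targets` format this criterion consumes: one dict per image, with the `class_labels` and normalized `(center_x, center_y, width, height)` `boxes` keys used throughout this diff:

```python
import torch

targets = [
    {
        "class_labels": torch.tensor([17, 3]),             # category ids
        "boxes": torch.tensor([[0.50, 0.40, 0.20, 0.30],   # normalized cxcywh
                               [0.25, 0.60, 0.10, 0.15]]),
    }
]
```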
@@ -2132,6 +2120,7 @@ class DeformableDetrLoss(nn.Module):
return losses
@torch.no_grad()
+ # Copied from transformers.models.detr.modeling_detr.DetrLoss.loss_cardinality
def loss_cardinality(self, outputs, targets, indices, num_boxes):
"""
Compute the cardinality error, i.e. the absolute error in the number of predicted non-empty boxes.
@@ -2147,6 +2136,7 @@ class DeformableDetrLoss(nn.Module):
losses = {"cardinality_error": card_err}
return losses
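A small sketch of the cardinality diagnostic computed above; it is a monitoring signal only (wrapped in `torch.no_grad()`), and the shapes here are illustrative:

```python
import torch
from torch import nn

logits = torch.randn(2, 300, 92)            # [batch, queries, num_classes + 1]
target_lengths = torch.tensor([4.0, 7.0])   # ground-truth boxes per image
# count predictions whose argmax is not the last ("no object") class
card_pred = (logits.argmax(-1) != logits.shape[-1] - 1).sum(1)
card_err = nn.functional.l1_loss(card_pred.float(), target_lengths)
```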
+ # Copied from transformers.models.detr.modeling_detr.DetrLoss.loss_boxes
def loss_boxes(self, outputs, targets, indices, num_boxes):
"""
Compute the losses related to the bounding boxes, the L1 regression loss and the GIoU loss.
@@ -2155,8 +2145,7 @@ class DeformableDetrLoss(nn.Module):
are expected in format (center_x, center_y, w, h), normalized by the image size.
"""
if "pred_boxes" not in outputs:
raise ValueError("No predicted boxes found in outputs")
raise KeyError("No predicted boxes found in outputs")
idx = self._get_source_permutation_idx(indices)
source_boxes = outputs["pred_boxes"][idx]
target_boxes = torch.cat([t["boxes"][i] for t, (_, i) in zip(targets, indices)], dim=0)
@@ -2172,12 +2161,14 @@ class DeformableDetrLoss(nn.Module):
losses["loss_giou"] = loss_giou.sum() / num_boxes
return losses
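A toy check of the two box losses on a single matched pair, assuming the module-level `generalized_box_iou` and `center_to_corners_format` helpers that appear later in this diff:

```python
import torch
from torch import nn

source_boxes = torch.tensor([[0.50, 0.50, 0.20, 0.40]])   # predicted, normalized cxcywh
target_boxes = torch.tensor([[0.50, 0.50, 0.25, 0.40]])   # ground truth, normalized cxcywh
num_boxes = 1

loss_bbox = nn.functional.l1_loss(source_boxes, target_boxes, reduction="none").sum() / num_boxes
loss_giou = (1 - torch.diag(generalized_box_iou(
    center_to_corners_format(source_boxes), center_to_corners_format(target_boxes)
))).sum() / num_boxes
```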
+ # Copied from transformers.models.detr.modeling_detr.DetrLoss._get_source_permutation_idx
def _get_source_permutation_idx(self, indices):
# permute predictions following indices
batch_idx = torch.cat([torch.full_like(source, i) for i, (source, _) in enumerate(indices)])
source_idx = torch.cat([source for (source, _) in indices])
return batch_idx, source_idx
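To make the index gymnastics concrete, a hypothetical matcher output and what this helper returns for it:

```python
import torch

# one (prediction indices, target indices) pair per image in the batch
indices = [(torch.tensor([4, 9]), torch.tensor([1, 0])),
           (torch.tensor([2]), torch.tensor([0]))]

batch_idx = torch.cat([torch.full_like(source, i) for i, (source, _) in enumerate(indices)])
source_idx = torch.cat([source for (source, _) in indices])
print(batch_idx.tolist(), source_idx.tolist())  # [0, 0, 1] [4, 9, 2]
```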
+ # Copied from transformers.models.detr.modeling_detr.DetrLoss._get_target_permutation_idx
def _get_target_permutation_idx(self, indices):
# permute targets following indices
batch_idx = torch.cat([torch.full_like(target, i) for i, (_, target) in enumerate(indices)])
@@ -2192,17 +2183,18 @@ class DeformableDetrLoss(nn.Module):
}
if loss not in loss_map:
raise ValueError(f"Loss {loss} not supported")
return loss_map[loss](outputs, targets, indices, num_boxes)
def forward(self, outputs, targets):
"""
This performs the loss computation.
- Parameters:
- outputs: dict of tensors, see the output specification of the model for the format
- targets: list of dicts, such that len(targets) == batch_size.
- The expected keys in each dict depends on the losses applied, see each loss' doc
+ Args:
+ outputs (`dict`, *optional*):
+ Dictionary of tensors, see the output specification of the model for the format.
+ targets (`List[dict]`, *optional*):
+ List of dicts, such that `len(targets) == batch_size`. The expected keys in each dict depends on the
+ losses applied, see each loss' doc.
"""
outputs_without_aux = {k: v for k, v in outputs.items() if k != "auxiliary_outputs"}
@@ -2272,7 +2264,6 @@ class DeformableDetrMLPPredictionHead(nn.Module):
return x
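For orientation, a minimal sketch of the MLP prediction-head pattern this class implements (input_dim, then hidden_dim repeated, then output_dim, with ReLU between all but the last layer); the class name and exact body here are illustrative, not copied from the file:

```python
import torch
from torch import nn

class MLPPredictionHeadSketch(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
        super().__init__()
        dims = [input_dim] + [hidden_dim] * (num_layers - 1)
        self.layers = nn.ModuleList(
            nn.Linear(in_d, out_d) for in_d, out_d in zip(dims, dims[1:] + [output_dim])
        )

    def forward(self, x):
        for i, layer in enumerate(self.layers):
            x = layer(x) if i == len(self.layers) - 1 else nn.functional.relu(layer(x))
        return x

head = MLPPredictionHeadSketch(256, 256, 4, 3)  # e.g. box regression: d_model -> 4 coords
```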
- # Copied from transformers.models.detr.modeling_detr.DetrHungarianMatcher
class DeformableDetrHungarianMatcher(nn.Module):
"""
This class computes an assignment between the targets and the predictions of the network.
@@ -2324,17 +2315,19 @@ class DeformableDetrHungarianMatcher(nn.Module):
batch_size, num_queries = outputs["logits"].shape[:2]
# We flatten to compute the cost matrices in a batch
out_prob = outputs["logits"].flatten(0, 1).softmax(-1) # [batch_size * num_queries, num_classes]
out_prob = outputs["logits"].flatten(0, 1).sigmoid() # [batch_size * num_queries, num_classes]
out_bbox = outputs["pred_boxes"].flatten(0, 1) # [batch_size * num_queries, 4]
# Also concat the target labels and boxes
target_ids = torch.cat([v["class_labels"] for v in targets])
target_bbox = torch.cat([v["boxes"] for v in targets])
- # Compute the classification cost. Contrary to the loss, we don't use the NLL,
- # but approximate it in 1 - proba[target class].
- # The 1 is a constant that doesn't change the matching, it can be ommitted.
- class_cost = -out_prob[:, target_ids]
+ # Compute the classification cost.
+ alpha = 0.25
+ gamma = 2.0
+ neg_cost_class = (1 - alpha) * (out_prob**gamma) * (-(1 - out_prob + 1e-8).log())
+ pos_cost_class = alpha * ((1 - out_prob) ** gamma) * (-(out_prob + 1e-8).log())
+ class_cost = pos_cost_class[:, target_ids] - neg_cost_class[:, target_ids]
# Compute the L1 cost between boxes
bbox_cost = torch.cdist(out_bbox, target_bbox, p=1)
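The new cost mirrors the focal loss term-for-term; a quick toy run of the added lines (shapes hypothetical) shows the resulting cost matrix:

```python
import torch

out_prob = torch.rand(6, 91)           # sigmoid probabilities, [num_queries, num_classes]
target_ids = torch.tensor([3, 17])     # classes of the ground-truth boxes

alpha, gamma = 0.25, 2.0
neg_cost = (1 - alpha) * (out_prob**gamma) * (-(1 - out_prob + 1e-8).log())
pos_cost = alpha * ((1 - out_prob) ** gamma) * (-(out_prob + 1e-8).log())
class_cost = pos_cost[:, target_ids] - neg_cost[:, target_ids]
print(class_cost.shape)  # torch.Size([6, 2]): one cost per (query, target) pair
```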
@@ -2419,6 +2412,17 @@ def generalized_box_iou(boxes1, boxes2):
return iou - (area - union) / area
+ # Copied from transformers.models.detr.modeling_detr.center_to_corners_format
+ def center_to_corners_format(x):
+ """
+ Converts a PyTorch tensor of bounding boxes of center format (center_x, center_y, width, height) to corners format
+ (x_0, y_0, x_1, y_1).
+ """
+ center_x, center_y, width, height = x.unbind(-1)
+ b = [(center_x - 0.5 * width), (center_y - 0.5 * height), (center_x + 0.5 * width), (center_y + 0.5 * height)]
+ return torch.stack(b, dim=-1)
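A one-line sanity check of the conversion, with values chosen by hand:

```python
import torch

boxes_center = torch.tensor([[0.5, 0.5, 0.2, 0.4]])   # (cx, cy, w, h)
print(center_to_corners_format(boxes_center))         # tensor([[0.4000, 0.3000, 0.6000, 0.7000]])
```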
# Copied from transformers.models.detr.modeling_detr._max_by_axis
def _max_by_axis(the_list):
# type: (List[List[int]]) -> List[int]
@@ -44,8 +44,8 @@ def center_to_corners_format(x):
Converts a PyTorch tensor of bounding boxes of center format (center_x, center_y, width, height) to corners format
(x_0, y_0, x_1, y_1).
"""
- x_c, y_c, w, h = x.unbind(-1)
- b = [(x_c - 0.5 * w), (y_c - 0.5 * h), (x_c + 0.5 * w), (y_c + 0.5 * h)]
+ center_x, center_y, width, height = x.unbind(-1)
+ b = [(center_x - 0.5 * width), (center_y - 0.5 * height), (center_x + 0.5 * width), (center_y + 0.5 * height)]
return torch.stack(b, dim=-1)
@@ -33,7 +33,6 @@ from ...utils import (
add_start_docstrings_to_model_forward,
is_scipy_available,
is_timm_available,
- is_vision_available,
logging,
replace_return_docstrings,
requires_backends,
@@ -44,9 +43,6 @@ from .configuration_detr import DetrConfig
if is_scipy_available():
from scipy.optimize import linear_sum_assignment
- if is_vision_available():
- from .feature_extraction_detr import center_to_corners_format
if is_timm_available():
from timm import create_model
@@ -1964,16 +1960,16 @@ class DetrLoss(nn.Module):
"""
if "logits" not in outputs:
raise KeyError("No logits were found in the outputs")
- src_logits = outputs["logits"]
+ source_logits = outputs["logits"]
- idx = self._get_src_permutation_idx(indices)
+ idx = self._get_source_permutation_idx(indices)
target_classes_o = torch.cat([t["class_labels"][J] for t, (_, J) in zip(targets, indices)])
target_classes = torch.full(
- src_logits.shape[:2], self.num_classes, dtype=torch.int64, device=src_logits.device
+ source_logits.shape[:2], self.num_classes, dtype=torch.int64, device=source_logits.device
)
target_classes[idx] = target_classes_o
- loss_ce = nn.functional.cross_entropy(src_logits.transpose(1, 2), target_classes, self.empty_weight)
+ loss_ce = nn.functional.cross_entropy(source_logits.transpose(1, 2), target_classes, self.empty_weight)
losses = {"loss_ce": loss_ce}
return losses
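Unlike the Deformable DETR focal variant above, plain DETR keeps a weighted cross-entropy; a sketch of the "no object" down-weighting, with illustrative shapes and the 0.1 default `eos_coefficient`:

```python
import torch
from torch import nn

num_classes = 91
empty_weight = torch.ones(num_classes + 1)
empty_weight[-1] = 0.1                      # eos_coefficient for the "no object" class

logits = torch.randn(2, 100, num_classes + 1)          # [batch, queries, classes + 1]
target_classes = torch.full((2, 100), num_classes)     # default everything to "no object"
loss_ce = nn.functional.cross_entropy(logits.transpose(1, 2), target_classes, empty_weight)
```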
@@ -2003,17 +1999,17 @@ class DetrLoss(nn.Module):
"""
if "pred_boxes" not in outputs:
raise KeyError("No predicted boxes found in outputs")
- idx = self._get_src_permutation_idx(indices)
- src_boxes = outputs["pred_boxes"][idx]
+ idx = self._get_source_permutation_idx(indices)
+ source_boxes = outputs["pred_boxes"][idx]
target_boxes = torch.cat([t["boxes"][i] for t, (_, i) in zip(targets, indices)], dim=0)
- loss_bbox = nn.functional.l1_loss(src_boxes, target_boxes, reduction="none")
+ loss_bbox = nn.functional.l1_loss(source_boxes, target_boxes, reduction="none")
losses = {}
losses["loss_bbox"] = loss_bbox.sum() / num_boxes
loss_giou = 1 - torch.diag(
- generalized_box_iou(center_to_corners_format(src_boxes), center_to_corners_format(target_boxes))
+ generalized_box_iou(center_to_corners_format(source_boxes), center_to_corners_format(target_boxes))
)
losses["loss_giou"] = loss_giou.sum() / num_boxes
return losses
@@ -2027,41 +2023,41 @@ class DetrLoss(nn.Module):
if "pred_masks" not in outputs:
raise KeyError("No predicted masks found in outputs")
- src_idx = self._get_src_permutation_idx(indices)
- tgt_idx = self._get_tgt_permutation_idx(indices)
- src_masks = outputs["pred_masks"]
- src_masks = src_masks[src_idx]
+ source_idx = self._get_source_permutation_idx(indices)
+ target_idx = self._get_target_permutation_idx(indices)
+ source_masks = outputs["pred_masks"]
+ source_masks = source_masks[source_idx]
masks = [t["masks"] for t in targets]
# TODO use valid to mask invalid areas due to padding in loss
target_masks, valid = nested_tensor_from_tensor_list(masks).decompose()
- target_masks = target_masks.to(src_masks)
- target_masks = target_masks[tgt_idx]
+ target_masks = target_masks.to(source_masks)
+ target_masks = target_masks[target_idx]
# upsample predictions to the target size
- src_masks = nn.functional.interpolate(
- src_masks[:, None], size=target_masks.shape[-2:], mode="bilinear", align_corners=False
+ source_masks = nn.functional.interpolate(
+ source_masks[:, None], size=target_masks.shape[-2:], mode="bilinear", align_corners=False
)
- src_masks = src_masks[:, 0].flatten(1)
+ source_masks = source_masks[:, 0].flatten(1)
target_masks = target_masks.flatten(1)
- target_masks = target_masks.view(src_masks.shape)
+ target_masks = target_masks.view(source_masks.shape)
losses = {
- "loss_mask": sigmoid_focal_loss(src_masks, target_masks, num_boxes),
- "loss_dice": dice_loss(src_masks, target_masks, num_boxes),
+ "loss_mask": sigmoid_focal_loss(source_masks, target_masks, num_boxes),
+ "loss_dice": dice_loss(source_masks, target_masks, num_boxes),
}
return losses
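The dice term referenced here pairs with the focal mask loss; a sketch of the usual formulation, assuming logits flattened to `[num_masks, H*W]` (the actual `dice_loss` helper is outside this hunk):

```python
import torch

def dice_loss_sketch(inputs, targets, num_boxes):
    # soft dice: 1 - 2|X ∩ Y| / (|X| + |Y|), with +1 smoothing in both terms
    inputs = inputs.sigmoid()
    numerator = 2 * (inputs * targets).sum(1)
    denominator = inputs.sum(-1) + targets.sum(-1)
    loss = 1 - (numerator + 1) / (denominator + 1)
    return loss.sum() / num_boxes
```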
- def _get_src_permutation_idx(self, indices):
+ def _get_source_permutation_idx(self, indices):
# permute predictions following indices
- batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
- src_idx = torch.cat([src for (src, _) in indices])
- return batch_idx, src_idx
+ batch_idx = torch.cat([torch.full_like(source, i) for i, (source, _) in enumerate(indices)])
+ source_idx = torch.cat([source for (source, _) in indices])
+ return batch_idx, source_idx
- def _get_tgt_permutation_idx(self, indices):
+ def _get_target_permutation_idx(self, indices):
# permute targets following indices
- batch_idx = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)])
- tgt_idx = torch.cat([tgt for (_, tgt) in indices])
- return batch_idx, tgt_idx
+ batch_idx = torch.cat([torch.full_like(target, i) for i, (_, target) in enumerate(indices)])
+ target_idx = torch.cat([target for (_, target) in indices])
+ return batch_idx, target_idx
def get_loss(self, loss, outputs, targets, indices, num_boxes):
loss_map = {
@@ -2082,7 +2078,7 @@ class DetrLoss(nn.Module):
outputs (`dict`, *optional*):
Dictionary of tensors, see the output specification of the model for the format.
targets (`List[dict]`, *optional*):
- List of dicts, such that len(targets) == batch_size. The expected keys in each dict depends on the
+ List of dicts, such that `len(targets) == batch_size`. The expected keys in each dict depends on the
losses applied, see each loss' doc.
"""
outputs_without_aux = {k: v for k, v in outputs.items() if k != "auxiliary_outputs"}
@@ -2288,6 +2284,17 @@ def generalized_box_iou(boxes1, boxes2):
return iou - (area - union) / area
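For reference, the quantity returned on the line above is the generalized IoU of Rezatofighi et al. (2019); a self-contained pairwise sketch for corner-format boxes (illustrative, not the file's exact body):

```python
import torch

def generalized_box_iou_sketch(boxes1, boxes2):
    # boxes: [N, 4] / [M, 4] in (x0, y0, x1, y1); assumes x1 >= x0 and y1 >= y0
    area1 = (boxes1[:, 2] - boxes1[:, 0]) * (boxes1[:, 3] - boxes1[:, 1])
    area2 = (boxes2[:, 2] - boxes2[:, 0]) * (boxes2[:, 3] - boxes2[:, 1])
    lt = torch.max(boxes1[:, None, :2], boxes2[:, :2])      # intersection top-left
    rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:])      # intersection bottom-right
    wh = (rb - lt).clamp(min=0)
    inter = wh[..., 0] * wh[..., 1]
    union = area1[:, None] + area2 - inter
    iou = inter / union
    lt_enc = torch.min(boxes1[:, None, :2], boxes2[:, :2])  # smallest enclosing box
    rb_enc = torch.max(boxes1[:, None, 2:], boxes2[:, 2:])
    wh_enc = (rb_enc - lt_enc).clamp(min=0)
    area_enc = wh_enc[..., 0] * wh_enc[..., 1]
    return iou - (area_enc - union) / area_enc
```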
+ # Copied from transformers.models.detr.feature_extraction_detr.center_to_corners_format
+ def center_to_corners_format(x):
+ """
+ Converts a PyTorch tensor of bounding boxes of center format (center_x, center_y, width, height) to corners format
+ (x_0, y_0, x_1, y_1).
+ """
+ center_x, center_y, width, height = x.unbind(-1)
+ b = [(center_x - 0.5 * width), (center_y - 0.5 * height), (center_x + 0.5 * width), (center_y + 0.5 * height)]
+ return torch.stack(b, dim=-1)
# below: taken from https://github.com/facebookresearch/detr/blob/master/util/misc.py#L306
@@ -42,8 +42,8 @@ def center_to_corners_format(x):
Converts a PyTorch tensor of bounding boxes of center format (center_x, center_y, width, height) to corners format
(x_0, y_0, x_1, y_1).
"""
- x_c, y_c, w, h = x.unbind(-1)
- b = [(x_c - 0.5 * w), (y_c - 0.5 * h), (x_c + 0.5 * w), (y_c + 0.5 * h)]
+ center_x, center_y, width, height = x.unbind(-1)
+ b = [(center_x - 0.5 * width), (center_y - 0.5 * height), (center_x + 0.5 * width), (center_y + 0.5 * height)]
return torch.stack(b, dim=-1)
@@ -959,16 +959,16 @@ class YolosLoss(nn.Module):
"""
if "logits" not in outputs:
raise KeyError("No logits were found in the outputs")
- src_logits = outputs["logits"]
+ source_logits = outputs["logits"]
- idx = self._get_src_permutation_idx(indices)
+ idx = self._get_source_permutation_idx(indices)
target_classes_o = torch.cat([t["class_labels"][J] for t, (_, J) in zip(targets, indices)])
target_classes = torch.full(
- src_logits.shape[:2], self.num_classes, dtype=torch.int64, device=src_logits.device
+ source_logits.shape[:2], self.num_classes, dtype=torch.int64, device=source_logits.device
)
target_classes[idx] = target_classes_o
- loss_ce = nn.functional.cross_entropy(src_logits.transpose(1, 2), target_classes, self.empty_weight)
+ loss_ce = nn.functional.cross_entropy(source_logits.transpose(1, 2), target_classes, self.empty_weight)
losses = {"loss_ce": loss_ce}
return losses
@@ -998,17 +998,17 @@ class YolosLoss(nn.Module):
"""
if "pred_boxes" not in outputs:
raise KeyError("No predicted boxes found in outputs")
- idx = self._get_src_permutation_idx(indices)
- src_boxes = outputs["pred_boxes"][idx]
+ idx = self._get_source_permutation_idx(indices)
+ source_boxes = outputs["pred_boxes"][idx]
target_boxes = torch.cat([t["boxes"][i] for t, (_, i) in zip(targets, indices)], dim=0)
- loss_bbox = nn.functional.l1_loss(src_boxes, target_boxes, reduction="none")
+ loss_bbox = nn.functional.l1_loss(source_boxes, target_boxes, reduction="none")
losses = {}
losses["loss_bbox"] = loss_bbox.sum() / num_boxes
loss_giou = 1 - torch.diag(
- generalized_box_iou(center_to_corners_format(src_boxes), center_to_corners_format(target_boxes))
+ generalized_box_iou(center_to_corners_format(source_boxes), center_to_corners_format(target_boxes))
)
losses["loss_giou"] = loss_giou.sum() / num_boxes
return losses
@@ -1022,41 +1022,41 @@ class YolosLoss(nn.Module):
if "pred_masks" not in outputs:
raise KeyError("No predicted masks found in outputs")
- src_idx = self._get_src_permutation_idx(indices)
- tgt_idx = self._get_tgt_permutation_idx(indices)
- src_masks = outputs["pred_masks"]
- src_masks = src_masks[src_idx]
+ source_idx = self._get_source_permutation_idx(indices)
+ target_idx = self._get_target_permutation_idx(indices)
+ source_masks = outputs["pred_masks"]
+ source_masks = source_masks[source_idx]
masks = [t["masks"] for t in targets]
# TODO use valid to mask invalid areas due to padding in loss
target_masks, valid = nested_tensor_from_tensor_list(masks).decompose()
- target_masks = target_masks.to(src_masks)
- target_masks = target_masks[tgt_idx]
+ target_masks = target_masks.to(source_masks)
+ target_masks = target_masks[target_idx]
# upsample predictions to the target size
- src_masks = nn.functional.interpolate(
- src_masks[:, None], size=target_masks.shape[-2:], mode="bilinear", align_corners=False
+ source_masks = nn.functional.interpolate(
+ source_masks[:, None], size=target_masks.shape[-2:], mode="bilinear", align_corners=False
)
- src_masks = src_masks[:, 0].flatten(1)
+ source_masks = source_masks[:, 0].flatten(1)
target_masks = target_masks.flatten(1)
- target_masks = target_masks.view(src_masks.shape)
+ target_masks = target_masks.view(source_masks.shape)
losses = {
- "loss_mask": sigmoid_focal_loss(src_masks, target_masks, num_boxes),
- "loss_dice": dice_loss(src_masks, target_masks, num_boxes),
+ "loss_mask": sigmoid_focal_loss(source_masks, target_masks, num_boxes),
+ "loss_dice": dice_loss(source_masks, target_masks, num_boxes),
}
return losses
- def _get_src_permutation_idx(self, indices):
+ def _get_source_permutation_idx(self, indices):
# permute predictions following indices
- batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
- src_idx = torch.cat([src for (src, _) in indices])
- return batch_idx, src_idx
+ batch_idx = torch.cat([torch.full_like(source, i) for i, (source, _) in enumerate(indices)])
+ source_idx = torch.cat([source for (source, _) in indices])
+ return batch_idx, source_idx
- def _get_tgt_permutation_idx(self, indices):
+ def _get_target_permutation_idx(self, indices):
# permute targets following indices
- batch_idx = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)])
- tgt_idx = torch.cat([tgt for (_, tgt) in indices])
- return batch_idx, tgt_idx
+ batch_idx = torch.cat([torch.full_like(target, i) for i, (_, target) in enumerate(indices)])
+ target_idx = torch.cat([target for (_, target) in indices])
+ return batch_idx, target_idx
def get_loss(self, loss, outputs, targets, indices, num_boxes):
loss_map = {
@@ -1077,7 +1077,7 @@ class YolosLoss(nn.Module):
outputs (`dict`, *optional*):
Dictionary of tensors, see the output specification of the model for the format.
targets (`List[dict]`, *optional*):
- List of dicts, such that len(targets) == batch_size. The expected keys in each dict depends on the
+ List of dicts, such that `len(targets) == batch_size`. The expected keys in each dict depends on the
losses applied, see each loss' doc.
"""
outputs_without_aux = {k: v for k, v in outputs.items() if k != "auxiliary_outputs"}
@@ -12,7 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Testing suite for the PyTorch CONDITIONAL_DETR model. """
""" Testing suite for the PyTorch Conditional DETR model. """
import inspect
@@ -213,19 +213,19 @@ class ConditionalDetrModelTest(ModelTesterMixin, GenerationTesterMixin, unittest
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_conditional_detr_object_detection_head_model(*config_and_inputs)
@unittest.skip(reason="CONDITIONAL_DETR does not use inputs_embeds")
@unittest.skip(reason="Conditional DETR does not use inputs_embeds")
def test_inputs_embeds(self):
pass
@unittest.skip(reason="CONDITIONAL_DETR does not have a get_input_embeddings method")
@unittest.skip(reason="Conditional DETR does not have a get_input_embeddings method")
def test_model_common_attributes(self):
pass
@unittest.skip(reason="CONDITIONAL_DETR is not a generative model")
@unittest.skip(reason="Conditional DETR is not a generative model")
def test_generate_without_input_ids(self):
pass
@unittest.skip(reason="CONDITIONAL_DETR does not use token embeddings")
@unittest.skip(reason="Conditional DETR does not use token embeddings")
def test_resize_tokens_embeddings(self):
pass
@@ -474,7 +474,7 @@ class ConditionalDetrModelIntegrationTests(unittest.TestCase):
expected_shape = torch.Size((1, 300, 256))
self.assertEqual(outputs.last_hidden_state.shape, expected_shape)
expected_slice = torch.tensor(
[[0.0616, -0.5146, -0.4032], [-0.7629, -0.4934, -1.7153], [-0.4768, -0.6403, -0.7826]]
[[0.4222, 0.7471, 0.8760], [0.6395, -0.2729, 0.7127], [-0.3090, 0.7642, 0.9529]]
).to(torch_device)
self.assertTrue(torch.allclose(outputs.last_hidden_state[0, :3, :3], expected_slice, atol=1e-4))
@@ -495,48 +495,13 @@ class ConditionalDetrModelIntegrationTests(unittest.TestCase):
expected_shape_logits = torch.Size((1, model.config.num_queries, model.config.num_labels))
self.assertEqual(outputs.logits.shape, expected_shape_logits)
expected_slice_logits = torch.tensor(
- [[-19.1194, -0.0893, -11.0154], [-17.3640, -1.8035, -14.0219], [-20.0461, -0.5837, -11.1060]]
+ [[-10.4372, -5.7558, -8.6764], [-10.5410, -5.8704, -8.0590], [-10.6827, -6.3469, -8.3923]]
).to(torch_device)
self.assertTrue(torch.allclose(outputs.logits[0, :3, :3], expected_slice_logits, atol=1e-4))
expected_shape_boxes = torch.Size((1, model.config.num_queries, 4))
self.assertEqual(outputs.pred_boxes.shape, expected_shape_boxes)
expected_slice_boxes = torch.tensor(
- [[0.4433, 0.5302, 0.8853], [0.5494, 0.2517, 0.0529], [0.4998, 0.5360, 0.9956]]
+ [[0.7733, 0.6576, 0.4496], [0.5171, 0.1184, 0.9094], [0.8846, 0.5647, 0.2486]]
).to(torch_device)
self.assertTrue(torch.allclose(outputs.pred_boxes[0, :3, :3], expected_slice_boxes, atol=1e-4))
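To reproduce the updated object-detection numbers locally, a sketch along these lines should work; the `microsoft/conditional-detr-resnet-50` checkpoint name is an assumption inferred from the test context, and device handling is simplified:

```python
import torch
from PIL import Image
import requests
from transformers import AutoFeatureExtractor, ConditionalDetrForObjectDetection

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

feature_extractor = AutoFeatureExtractor.from_pretrained("microsoft/conditional-detr-resnet-50")
model = ConditionalDetrForObjectDetection.from_pretrained("microsoft/conditional-detr-resnet-50")

with torch.no_grad():
    outputs = model(**feature_extractor(images=image, return_tensors="pt"))
print(outputs.logits.shape, outputs.pred_boxes.shape)
```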
- def test_inference_panoptic_segmentation_head(self):
- model = ConditionalDetrForSegmentation.from_pretrained("microsoft/conditional-detr-resnet-50-panoptic").to(
- torch_device
- )
- feature_extractor = self.default_feature_extractor
- image = prepare_img()
- encoding = feature_extractor(images=image, return_tensors="pt").to(torch_device)
- pixel_values = encoding["pixel_values"].to(torch_device)
- pixel_mask = encoding["pixel_mask"].to(torch_device)
- with torch.no_grad():
- outputs = model(pixel_values, pixel_mask)
- expected_shape_logits = torch.Size((1, model.config.num_queries, model.config.num_labels))
- self.assertEqual(outputs.logits.shape, expected_shape_logits)
- expected_slice_logits = torch.tensor(
- [[-18.1565, -1.7568, -13.5029], [-16.8888, -1.4138, -14.1028], [-17.5709, -2.5080, -11.8654]]
- ).to(torch_device)
- self.assertTrue(torch.allclose(outputs.logits[0, :3, :3], expected_slice_logits, atol=1e-4))
- expected_shape_boxes = torch.Size((1, model.config.num_queries, 4))
- self.assertEqual(outputs.pred_boxes.shape, expected_shape_boxes)
- expected_slice_boxes = torch.tensor(
- [[0.5344, 0.1789, 0.9285], [0.4420, 0.0572, 0.0875], [0.6630, 0.6887, 0.1017]]
- ).to(torch_device)
- self.assertTrue(torch.allclose(outputs.pred_boxes[0, :3, :3], expected_slice_boxes, atol=1e-4))
- expected_shape_masks = torch.Size((1, model.config.num_queries, 200, 267))
- self.assertEqual(outputs.pred_masks.shape, expected_shape_masks)
- expected_slice_masks = torch.tensor(
- [[-7.7558, -10.8788, -11.9797], [-11.8881, -16.4329, -17.7451], [-14.7316, -19.7383, -20.3004]]
- ).to(torch_device)
- self.assertTrue(torch.allclose(outputs.pred_masks[0, 0, :3, :3], expected_slice_masks, atol=1e-3))