"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "9479052dded67788932ebce370e53d69412ea7d1"
Commit aea87f6c authored by Zhicheng Yan, committed by Facebook GitHub Bot

fix two-stage DF-DETR

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/106

# 2-stage DF-DETR

DF-DETR supports 2-stage detection. In the 1st stage, we detect class-agnostic boxes using the feature pyramid (a.k.a. `memory` in the code) computed by the encoder.
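For intuition, here is a minimal, self-contained sketch of that first stage (toy shapes; `memory`, `class_embed`, and the proposal tensor are illustrative stand-ins for the real modules):

```python
import torch
import torch.nn as nn

# Toy shapes for illustration
bs, K, C, topk = 2, 1000, 256, 300
memory = torch.randn(bs, K, C)        # flattened encoder feature pyramid ("memory")
proposals = torch.rand(bs, K, 4)      # one candidate box per position (cx, cy, w, h)

class_embed = nn.Linear(C, 1)         # 1 output channel: P(foreground)
scores = class_embed(memory)[..., 0]  # (bs, K) class-agnostic foreground logits
topk_idx = torch.topk(scores, topk, dim=1)[1]
topk_boxes = torch.gather(proposals, 1, topk_idx.unsqueeze(-1).repeat(1, 1, 4))
```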

The current implementation has a few flaws:
- In `setcriterion.py`, when computing the loss for the encoder's 1st-stage predictions, `num_boxes` should be reduced across GPUs and clamped to be a positive integer to avoid a divide-by-zero bug. The current implementation produces a divide-by-zero NaN when `num_boxes` is zero (e.g., when the cropped input image contains no box annotations).
- In `gen_encoder_output_proposals()`, the code manually fills in `float("inf")` at invalid spatial positions outside the actual image size. However, there is no guarantee that those positions won't be selected as top-scoring positions, and `float("inf")` can easily cause the affected parameters to be updated to NaN. The fix replaces `float("inf")` with a large finite negative constant; see the sketch after this list.
- The encoder's `class_embed` should have 1 output channel rather than `num_classes` channels, because we only need to predict the probability that a proposal is a foreground box.

This diff fixes the issues above.
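The first two fixes reduce to two small numerical-stability patterns. A minimal standalone sketch (the diff's actual helpers are `_reduce_num_boxes` and the `NEG_INF` constant; the toy `invalid` mask here is illustrative):

```python
import torch
import torch.distributed as dist

NEG_INF = -10000.0  # large-but-finite stand-in for float("-inf")

def reduce_num_boxes(targets, device):
    # Average the number of target boxes across workers and clamp to >= 1,
    # so the loss normalizer can never be zero (no divide-by-zero NaN).
    num_boxes = sum(len(t["labels"]) for t in targets)
    num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=device)
    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(num_boxes)
        num_boxes = num_boxes / dist.get_world_size()
    return torch.clamp(num_boxes, min=1).item()

# Mask invalid proposal scores with a finite NEG_INF instead of float("inf"):
# the positions still lose every top-k comparison, but gradients stay finite.
logits = torch.randn(2, 1000, 1)
invalid = torch.zeros(2, 1000, 1, dtype=torch.bool)
logits = logits.masked_fill(invalid, NEG_INF)
```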

# Gradient blocking in decoder

Currently, the gradient of the reference points is blocked at each decoding layer to improve numerical stability during training.
This diff adds an option `MODEL.DETR.DECODER_BLOCK_GRAD`. When it is False, we do NOT block the gradient; empirically, we find this leads to better box AP. A minimal sketch of the refinement step follows.
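This sketch uses a hypothetical free function for clarity; the real logic lives in `DeformableTransformerDecoder.forward` behind the `block_grad` flag:

```python
import torch

def refine_reference_points(ref_unact, delta, block_grad=True):
    # One decoding layer's refinement of (logit-space) reference points.
    new_ref = ref_unact + delta
    if block_grad:
        # DECODER_BLOCK_GRAD=True: stop gradient flow through reference points.
        new_ref = new_ref.detach()
    return new_ref

ref = torch.zeros(2, 300, 4, requires_grad=True)
out = refine_reference_points(ref, torch.randn(2, 300, 4), block_grad=False)
print(out.requires_grad)  # True: gradients flow through all decoding layers
```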

Reviewed By: zhanghang1989

Differential Revision: D30325396

fbshipit-source-id: 7d7add1e05888adda6e46cc6886117170daa22d4
parent 7677f3ec
@@ -33,7 +33,7 @@ def add_detr_config(cfg):
     cfg.MODEL.DETR.NO_OBJECT_WEIGHT = 0.1
     cfg.MODEL.DETR.WITH_BOX_REFINE = False
     cfg.MODEL.DETR.TWO_STAGE = False
+    cfg.MODEL.DETR.DECODER_BLOCK_GRAD = True
     # TRANSFORMER
     cfg.MODEL.DETR.NHEADS = 8
     cfg.MODEL.DETR.DROPOUT = 0.1
@@ -72,6 +72,7 @@ class DeformableDETR(nn.Module):
         self.num_queries = num_queries
         self.transformer = transformer
         hidden_dim = transformer.d_model
+        # We will use sigmoid activation and focal loss
         self.class_embed = nn.Linear(hidden_dim, num_classes)
         self.bbox_embed = MLP(hidden_dim, hidden_dim, 4, 3)
         self.num_feature_levels = num_feature_levels
@@ -116,19 +117,12 @@ class DeformableDETR(nn.Module):
         prior_prob = 0.01
         bias_value = -math.log((1 - prior_prob) / prior_prob)
         self.class_embed.bias.data = torch.ones(num_classes) * bias_value
-        nn.init.constant_(self.bbox_embed.layers[-1].weight.data, 0)
-        nn.init.constant_(self.bbox_embed.layers[-1].bias.data, 0)
         for proj in self.input_proj:
             nn.init.xavier_uniform_(proj[0].weight, gain=1)
             nn.init.constant_(proj[0].bias, 0)
-        # if two-stage, the last class_embed and bbox_embed is for region proposal generation
-        num_pred = (
-            (transformer.decoder.num_layers + 1)
-            if two_stage
-            else transformer.decoder.num_layers
-        )
+        num_pred = transformer.decoder.num_layers
         if with_box_refine:
             self.class_embed = _get_clones(self.class_embed, num_pred)
             self.bbox_embed = _get_clones(self.bbox_embed, num_pred)
@@ -145,10 +139,16 @@ class DeformableDETR(nn.Module):
             self.transformer.decoder.bbox_embed = None
         if two_stage:
             # hack implementation for two-stage
-            self.transformer.decoder.class_embed = self.class_embed
+            # We only predict foreground/background at the output of encoder
+            class_embed = nn.Linear(hidden_dim, 1)
+            class_embed.bias.data = torch.ones(1) * bias_value
+            self.transformer.encoder.class_embed = class_embed
             for box_embed in self.bbox_embed:
                 nn.init.constant_(box_embed.layers[-1].bias.data[2:], 0.0)
+            self.transformer.encoder.bbox_embed = MLP(hidden_dim, hidden_dim, 4, 3)

     def forward(self, samples: NestedTensor):
         """The forward expects a NestedTensor, which consists of:
            - samples.tensor: batched images, of shape [batch_size x 3 x H x W]
@@ -16,6 +16,11 @@ from torch import nn
 from torch.nn.init import xavier_uniform_, constant_, normal_

 from ..modules import MSDeformAttn
+from ..util.misc import inverse_sigmoid
+
+# we do not use float("-inf") to avoid potential NaN during training
+NEG_INF = -10000.0


 class DeformableTransformer(nn.Module):
@@ -34,6 +39,7 @@ class DeformableTransformer(nn.Module):
         enc_n_points=4,
         two_stage=False,
         two_stage_num_proposals=300,
+        decoder_block_grad=True,
     ):
         super().__init__()
@@ -63,7 +69,10 @@ class DeformableTransformer(nn.Module):
             dec_n_points,
         )
         self.decoder = DeformableTransformerDecoder(
-            decoder_layer, num_decoder_layers, return_intermediate_dec
+            decoder_layer,
+            num_decoder_layers,
+            return_intermediate_dec,
+            decoder_block_grad,
         )

         self.level_embed = nn.Parameter(torch.Tensor(num_feature_levels, d_model))
@@ -81,13 +90,13 @@ class DeformableTransformer(nn.Module):
     def _reset_parameters(self):
         for p in self.parameters():
             if p.dim() > 1:
-                nn.init.xavier_uniform_(p)
+                xavier_uniform_(p)
         for m in self.modules():
             if isinstance(m, MSDeformAttn):
                 m._reset_parameters()
         if not self.two_stage:
-            xavier_uniform_(self.reference_points.weight.data, gain=1.0)
-            constant_(self.reference_points.bias.data, 0.0)
+            xavier_uniform_(self.reference_points.weight, gain=1.0)
+            constant_(self.reference_points.bias, 0.0)
         normal_(self.level_embed)

     def get_proposal_pos_embed(self, proposals):
@@ -163,14 +172,8 @@ class DeformableTransformer(nn.Module):
         output_proposals_valid = (
             (output_proposals > 0.01) & (output_proposals < 0.99)
         ).all(-1, keepdim=True)
-        # inverse sigmoid
-        output_proposals = torch.log(output_proposals / (1 - output_proposals))
-        output_proposals = output_proposals.masked_fill(
-            memory_padding_mask.unsqueeze(-1), float("inf")
-        )
-        output_proposals = output_proposals.masked_fill(
-            ~output_proposals_valid, float("inf")
-        )
+        output_proposals = inverse_sigmoid(output_proposals)
         # memory: shape (bs, K, C)
         output_memory = memory
         # memory_padding_mask: shape (bs, K)
@@ -179,7 +182,7 @@ class DeformableTransformer(nn.Module):
         )
         output_memory = output_memory.masked_fill(~output_proposals_valid, float(0))
         output_memory = self.enc_output_norm(self.enc_output(output_memory))
-        return output_memory, output_proposals
+        return output_memory, output_proposals, output_proposals_valid

     def get_valid_ratio(self, mask):
         _, H, W = mask.shape
@@ -226,7 +229,7 @@ class DeformableTransformer(nn.Module):
         src_flatten = torch.cat(src_flatten, 1)
         # mask_flatten shape: (bs, K)
         mask_flatten = torch.cat(mask_flatten, 1)
-        # mask_flatten shape: (bs, K, c)
+        # lvl_pos_embed_flatten shape: (bs, K, c)
         lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1)
         # spatial_shapes shape: (num_levels, 2)
         spatial_shapes = torch.as_tensor(
@@ -238,7 +241,6 @@ class DeformableTransformer(nn.Module):
         )
         # valid_ratios shape: (bs, num_levels, 2)
         valid_ratios = torch.stack([self.get_valid_ratio(m) for m in masks], 1)
-
         # encoder
         # memory shape (bs, K, C) where K = \sum_l H_l * w_l
         memory = self.encoder(
@@ -255,29 +257,34 @@ class DeformableTransformer(nn.Module):
         if self.two_stage:
             # output_memory shape (bs, K, C)
             # output_proposals shape (bs, K, 4)
-            output_memory, output_proposals = self.gen_encoder_output_proposals(
-                memory, mask_flatten, spatial_shapes
-            )
+            # output_proposals_valid shape (bs, K, 1)
+            (
+                output_memory,
+                output_proposals,
+                output_proposals_valid,
+            ) = self.gen_encoder_output_proposals(memory, mask_flatten, spatial_shapes)

             # hack implementation for two-stage Deformable DETR
-            # shape (bs, K, num_classes)
-            enc_outputs_class = self.decoder.class_embed[self.decoder.num_layers](
-                output_memory
-            )
+            # shape (bs, K, 1)
+            enc_outputs_class = self.encoder.class_embed(output_memory)
+            # fill in NEG_INF foreground logit at invalid positions so that we will
+            # never pick top-scored proposals at those positions
+            enc_outputs_class = enc_outputs_class.masked_fill(
+                mask_flatten.unsqueeze(-1), NEG_INF
+            )
+            enc_outputs_class = enc_outputs_class.masked_fill(
+                ~output_proposals_valid, NEG_INF
+            )
             # shape (bs, K, 4)
             enc_outputs_coord_unact = (
-                self.decoder.bbox_embed[self.decoder.num_layers](output_memory)
-                + output_proposals
+                self.encoder.bbox_embed(output_memory) + output_proposals
             )

             topk = self.two_stage_num_proposals
             # topk_proposals: indices of top items. Shape (bs, top_k)
-            # TODO (zyan3): use a standalone class_embed layer with 2 output channels?
             topk_proposals = torch.topk(enc_outputs_class[..., 0], topk, dim=1)[1]
             # topk_coords_unact shape (bs, top_k, 4)
             topk_coords_unact = torch.gather(
                 enc_outputs_coord_unact, 1, topk_proposals.unsqueeze(-1).repeat(1, 1, 4)
             )
             topk_coords_unact = topk_coords_unact.detach()
-            init_reference_out = reference_points_unact = topk_coords_unact
+            init_reference_out = topk_coords_unact
             # shape (bs, top_k, C=512)
             pos_trans_out = self.pos_trans_norm(
                 self.pos_trans(self.get_proposal_pos_embed(topk_coords_unact))
@@ -292,17 +299,15 @@ class DeformableTransformer(nn.Module):
             query_embed = query_embed.unsqueeze(0).expand(bs, -1, -1)
             # tgt shape: (batch_size, num_queries, c)
             tgt = tgt.unsqueeze(0).expand(bs, -1, -1)
-            # init_reference_out shape: (batch_size, num_queries, 2), value \in (0, 1)
+            # init_reference_out shape: (batch_size, num_queries, 2)
             init_reference_out = self.reference_points(query_embed)
-            # block gradient backpropagation here to stabilize optimization
-            reference_points_unact = init_reference_out.detach()

         # decoder
         # hs shape: (num_layers, batch_size, num_queries, c)
         # inter_references shape: (num_layers, batch_size, num_queries, num_levels, 2)
         hs, inter_references = self.decoder(
             tgt,
-            reference_points_unact,
+            init_reference_out,
             memory,
             spatial_shapes,
             level_start_index,
@@ -546,14 +551,16 @@ class DeformableTransformerDecoderLayer(nn.Module):

 class DeformableTransformerDecoder(nn.Module):
-    def __init__(self, decoder_layer, num_layers, return_intermediate=False):
+    def __init__(
+        self, decoder_layer, num_layers, return_intermediate=False, block_grad=True
+    ):
         super().__init__()
         self.layers = _get_clones(decoder_layer, num_layers)
         self.num_layers = num_layers
         self.return_intermediate = return_intermediate
         # hack implementation for iterative bounding box refinement and two-stage Deformable DETR
         self.bbox_embed = None
-        self.class_embed = None
+        self.block_grad = block_grad

     def forward(
         self,
@@ -619,7 +626,9 @@ class DeformableTransformerDecoder(nn.Module):
                         tmp[..., :2] + reference_points_unact
                     )
                     # block gradient backpropagation here to stabilize optimization
-                    new_reference_points_unact = new_reference_points_unact.detach()
+                    if self.block_grad:
+                        new_reference_points_unact = new_reference_points_unact.detach()
                     reference_points_unact = new_reference_points_unact
                 else:
                     new_reference_points_unact = reference_points_unact
@@ -15,6 +15,16 @@ from ..util.misc import (
 from .segmentation import dice_loss, sigmoid_focal_loss


+def _reduce_num_boxes(targets, device):
+    # Compute the average number of target boxes across all nodes, for normalization purposes
+    num_boxes = sum(len(t["labels"]) for t in targets)
+    num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=device)
+    if is_dist_avail_and_initialized():
+        torch.distributed.all_reduce(num_boxes)
+    num_boxes = torch.clamp(num_boxes / get_world_size(), min=1).item()
+    return num_boxes
+
+
 class SetCriterion(nn.Module):
     """This class computes the loss for DETR.
     The process happens in two steps:
@@ -78,6 +88,50 @@ class SetCriterion(nn.Module):
         losses["class_error"] = 100 - accuracy(src_logits[idx], target_classes_o)[0]
         return losses

+    def forground_background_loss_labels(
+        self, outputs, targets, indices, num_boxes, log=True
+    ):
+        assert "pred_logits" in outputs
+        # shape (batch_size, num_queries, 1)
+        src_logits = outputs["pred_logits"]
+        batch_size, num_queries = src_logits.shape[:2]
+        assert src_logits.shape[2] == 1, f"expect 1 class {src_logits.shape[2]}"
+
+        idx = self._get_src_permutation_idx(indices)
+        target_classes_o = torch.cat(
+            [t["labels"][J] for t, (_, J) in zip(targets, indices)]
+        )
+        target_classes = torch.full(
+            src_logits.shape[:2],
+            1,
+            dtype=torch.int64,
+            device=src_logits.device,
+        )
+        target_classes[idx] = target_classes_o
+
+        target_classes_onehot = torch.zeros(
+            [src_logits.shape[0], src_logits.shape[1], 2],
+            dtype=src_logits.dtype,
+            layout=src_logits.layout,
+            device=src_logits.device,
+        )
+        target_classes_onehot.scatter_(2, target_classes.unsqueeze(-1), 1)
+        target_classes_onehot = target_classes_onehot[:, :, :-1]
+
+        loss_ce = (
+            sigmoid_focal_loss(
+                src_logits,
+                target_classes_onehot,
+                num_boxes,
+                alpha=self.focal_alpha,
+                gamma=2,
+            )
+            * src_logits.shape[1]
+        )
+        return {"loss_ce": loss_ce}
+
     @torch.no_grad()
     def loss_cardinality(self, outputs, targets, indices, num_boxes):
         """Compute the cardinality error, ie the absolute error in the number of predicted non-empty boxes
@@ -181,21 +235,24 @@ class SetCriterion(nn.Module):
         assert loss in loss_map, f"do you really want to compute {loss} loss?"
         return loss_map[loss](outputs, targets, indices, num_boxes, **kwargs)

+    def get_foreground_background_loss(
+        self, loss, outputs, targets, indices, num_boxes, **kwargs
+    ):
+        loss_map = {
+            "labels": self.forground_background_loss_labels,
+            "cardinality": self.loss_cardinality,
+            "boxes": self.loss_boxes,
+        }
+        assert loss in loss_map, f"do you really want to compute {loss} loss?"
+        return loss_map[loss](outputs, targets, indices, num_boxes, **kwargs)
+
     def _forward(self, outputs, outputs_without_aux, targets):
         # Retrieve the matching between the outputs of the last layer and the targets
         # A list where each item is [row_indices, col_indices]
         indices = self.matcher(outputs_without_aux, targets)
-        # Compute the average number of target boxes accross all nodes, for normalization purposes
-        num_boxes = sum(len(t["labels"]) for t in targets)
-        num_boxes = torch.as_tensor(
-            [num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device
-        )
-        if is_dist_avail_and_initialized():
-            torch.distributed.all_reduce(num_boxes)
-        num_boxes = torch.clamp(num_boxes / get_world_size(), min=1).item()
+        num_boxes = _reduce_num_boxes(targets, next(iter(outputs.values())).device)

         # Compute all the requested losses
         losses = {}
         for loss in self.losses:
@@ -257,6 +314,7 @@ class FocalLossSetCriterion(SetCriterion):
         targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes]
         """
         assert "pred_logits" in outputs
+        # shape (batch_size, num_queries, num_classes)
         src_logits = outputs["pred_logits"]

         idx = self._get_src_permutation_idx(indices)
@@ -313,8 +371,7 @@ class FocalLossSetCriterion(SetCriterion):
         losses = self._forward(outputs, outputs_without_aux, targets)

         if "enc_outputs" in outputs:
-            num_boxes = sum(len(t["labels"]) for t in targets)
+            num_boxes = _reduce_num_boxes(targets, next(iter(outputs.values())).device)
             enc_outputs = outputs["enc_outputs"]
             bin_targets = copy.deepcopy(targets)
             for bt in bin_targets:
@@ -328,7 +385,7 @@ class FocalLossSetCriterion(SetCriterion):
                 if loss == "labels":
                     # Logging is enabled only for the last layer
                     kwargs["log"] = False
-                l_dict = self.get_loss(
+                l_dict = self.get_foreground_background_loss(
                     loss, enc_outputs, bin_targets, indices, num_boxes, **kwargs
                 )
                 l_dict = {k + "_enc": v for k, v in l_dict.items()}
@@ -50,8 +50,8 @@ def generalized_box_iou(boxes1, boxes2):
     """
     # degenerate boxes gives inf / nan results
     # so do an early check
-    assert (boxes1[:, 2:] >= boxes1[:, :2]).all(), f"incorrect boxes, boxes1 {boxes1}"
-    assert (boxes2[:, 2:] >= boxes2[:, :2]).all(), f"incorrect boxes, boxes1 {boxes2}"
+    assert (boxes1[:, 2:] > boxes1[:, :2]).all(), f"incorrect boxes, boxes1 {boxes1}"
+    assert (boxes2[:, 2:] > boxes2[:, :2]).all(), f"incorrect boxes, boxes2 {boxes2}"
     iou, union = box_iou(boxes1, boxes2)

     lt = torch.min(boxes1[:, None, :2], boxes2[:, :2])