Commit b2b8e216 authored by Zhicheng Yan's avatar Zhicheng Yan Committed by Facebook GitHub Bot
Browse files

fix a bug at inference and code refactoring

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/97

Major changes
- Fix a bug within `inference()` function
- Refactor code to remove redundant code between `SetCriterion` and `FocalLossSetCriterion`.

Reviewed By: zhanghang1989

Differential Revision: D29481067

fbshipit-source-id: 64788f1ff331177db964eb36d380430799d1d2f2
parent e830629a
...@@ -382,11 +382,9 @@ class Detr(nn.Module): ...@@ -382,11 +382,9 @@ class Detr(nn.Module):
result = Instances(image_size) result = Instances(image_size)
boxes = box_cxcywh_to_xyxy(box_pred_per_image) boxes = box_cxcywh_to_xyxy(box_pred_per_image)
if self.use_focal_loss: if self.use_focal_loss:
boxes = torch.gather( boxes = torch.gather(boxes, 0, topk_boxes[i].unsqueeze(-1).repeat(1, 4))
boxes.unsqueeze(0), 1, topk_boxes.unsqueeze(-1).repeat(1, 1, 4)
).squeeze()
result.pred_boxes = Boxes(boxes)
result.pred_boxes = Boxes(boxes)
result.pred_boxes.scale(scale_x=image_size[1], scale_y=image_size[0]) result.pred_boxes.scale(scale_x=image_size[1], scale_y=image_size[0])
if self.mask_on: if self.mask_on:
mask = F.interpolate( mask = F.interpolate(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment