"...resnet50_tensorflow.git" did not exist on "c02980688ee7ac9de16cdd4c9101ead1c84d03f6"
Unverified commit d3b1adf5, authored by Vitali Petsiuk, committed by GitHub

Removes duplicate computations in DETR post processing (#21592)

* Remove redundant computations, comb variable names

* Fix scores to cur_scores
parent d4ba6e1a
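The duplicated work removed here is the second softmax over the class logits: the old code ran cur_logits.softmax(-1).max(-1) once to build the keep mask and a second time to get the kept scores and labels. A standalone sketch of the old versus new filtering step, using toy tensors rather than the actual DetrImageProcessor code, looks like this:

import torch

# Toy stand-in for one image's DETR class logits: 5 queries, 4 classes,
# where the last class index is the "no object" (empty) label.
cur_logits = torch.randn(5, 4)
threshold = 0.5
empty_label = cur_logits.shape[-1] - 1

# Old pattern: softmax + max computed twice.
scores, labels = cur_logits.softmax(-1).max(-1)
keep = labels.ne(empty_label) & (scores > threshold)
old_scores, old_labels = cur_logits.softmax(-1).max(-1)  # duplicate computation
old_scores, old_labels = old_scores[keep], old_labels[keep]

# New pattern: compute once and reuse for both the mask and the outputs.
cur_scores, cur_labels = cur_logits.softmax(-1).max(-1)
keep = cur_labels.ne(empty_label) & (cur_scores > threshold)
cur_scores, cur_labels = cur_scores[keep], cur_labels[keep]

# Both paths keep exactly the same detections.
assert torch.equal(cur_scores, old_scores) and torch.equal(cur_labels, old_labels)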
@@ -1311,6 +1311,7 @@ class DetrImageProcessor(BaseImageProcessor):
             FutureWarning,
         )
         out_logits, raw_masks = outputs.logits, outputs.pred_masks
+        empty_label = out_logits.shape[-1] - 1
         preds = []

         def to_tuple(tup):
@@ -1320,16 +1321,15 @@ class DetrImageProcessor(BaseImageProcessor):
         for cur_logits, cur_masks, size in zip(out_logits, raw_masks, target_sizes):
             # we filter empty queries and detection below threshold
-            scores, labels = cur_logits.softmax(-1).max(-1)
-            keep = labels.ne(outputs.logits.shape[-1] - 1) & (scores > threshold)
-            cur_scores, cur_classes = cur_logits.softmax(-1).max(-1)
+            cur_scores, cur_labels = cur_logits.softmax(-1).max(-1)
+            keep = cur_labels.ne(empty_label) & (cur_scores > threshold)
             cur_scores = cur_scores[keep]
-            cur_classes = cur_classes[keep]
+            cur_labels = cur_labels[keep]
             cur_masks = cur_masks[keep]
             cur_masks = nn.functional.interpolate(cur_masks[:, None], to_tuple(size), mode="bilinear").squeeze(1)
             cur_masks = (cur_masks.sigmoid() > mask_threshold) * 1

-            predictions = {"scores": cur_scores, "labels": cur_classes, "masks": cur_masks}
+            predictions = {"scores": cur_scores, "labels": cur_labels, "masks": cur_masks}
             preds.append(predictions)
         return preds
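The mask handling around the changed lines is untouched: raw mask logits are upsampled to the target size with bilinear interpolation and then binarized against mask_threshold. A minimal, self-contained sketch of that step with made-up shapes (not taken from the diff) is:

import torch
from torch import nn

cur_masks = torch.randn(3, 32, 32)   # raw mask logits for 3 kept queries
size = (200, 300)                    # (height, width) of the target image
mask_threshold = 0.5

# Add a channel dim, resize bilinearly to the target size, drop the channel dim.
cur_masks = nn.functional.interpolate(cur_masks[:, None], size, mode="bilinear").squeeze(1)
# Map logits to probabilities and binarize to a 0/1 mask.
cur_masks = (cur_masks.sigmoid() > mask_threshold) * 1
print(cur_masks.shape)               # torch.Size([3, 200, 300])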
@@ -1423,6 +1423,7 @@ class DetrImageProcessor(BaseImageProcessor):
             raise ValueError(
                 "Make sure that you pass in as many target sizes as the batch dimension of the logits and masks"
             )
+        empty_label = out_logits.shape[-1] - 1
         preds = []

         def to_tuple(tup):
@@ -1434,24 +1435,23 @@ class DetrImageProcessor(BaseImageProcessor):
             out_logits, raw_masks, raw_boxes, processed_sizes, target_sizes
         ):
             # we filter empty queries and detection below threshold
-            scores, labels = cur_logits.softmax(-1).max(-1)
-            keep = labels.ne(outputs.logits.shape[-1] - 1) & (scores > threshold)
-            cur_scores, cur_classes = cur_logits.softmax(-1).max(-1)
+            cur_scores, cur_labels = cur_logits.softmax(-1).max(-1)
+            keep = cur_labels.ne(empty_label) & (cur_scores > threshold)
             cur_scores = cur_scores[keep]
-            cur_classes = cur_classes[keep]
+            cur_labels = cur_labels[keep]
             cur_masks = cur_masks[keep]
             cur_masks = nn.functional.interpolate(cur_masks[:, None], to_tuple(size), mode="bilinear").squeeze(1)
             cur_boxes = center_to_corners_format(cur_boxes[keep])

             h, w = cur_masks.shape[-2:]
-            if len(cur_boxes) != len(cur_classes):
+            if len(cur_boxes) != len(cur_labels):
                 raise ValueError("Not as many boxes as there are classes")

             # It may be that we have several predicted masks for the same stuff class.
             # In the following, we track the list of masks ids for each stuff class (they are merged later on)
             cur_masks = cur_masks.flatten(1)
             stuff_equiv_classes = defaultdict(lambda: [])
-            for k, label in enumerate(cur_classes):
+            for k, label in enumerate(cur_labels):
                 if not is_thing_map[label.item()]:
                     stuff_equiv_classes[label.item()].append(k)
@@ -1491,28 +1491,28 @@ class DetrImageProcessor(BaseImageProcessor):
                 return area, seg_img

             area, seg_img = get_ids_area(cur_masks, cur_scores, dedup=True)
-            if cur_classes.numel() > 0:
+            if cur_labels.numel() > 0:
                 # We know filter empty masks as long as we find some
                 while True:
                     filtered_small = torch.as_tensor(
-                        [area[i] <= 4 for i, c in enumerate(cur_classes)], dtype=torch.bool, device=keep.device
+                        [area[i] <= 4 for i, c in enumerate(cur_labels)], dtype=torch.bool, device=keep.device
                     )
                     if filtered_small.any().item():
                         cur_scores = cur_scores[~filtered_small]
-                        cur_classes = cur_classes[~filtered_small]
+                        cur_labels = cur_labels[~filtered_small]
                         cur_masks = cur_masks[~filtered_small]
                         area, seg_img = get_ids_area(cur_masks, cur_scores)
                     else:
                         break

             else:
-                cur_classes = torch.ones(1, dtype=torch.long, device=cur_classes.device)
+                cur_labels = torch.ones(1, dtype=torch.long, device=cur_labels.device)

             segments_info = []
             for i, a in enumerate(area):
-                cat = cur_classes[i].item()
+                cat = cur_labels[i].item()
                 segments_info.append({"id": i, "isthing": is_thing_map[cat], "category_id": cat, "area": a})
-            del cur_classes
+            del cur_labels

             with io.BytesIO() as out:
                 seg_img.save(out, format="PNG")
...
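A rough smoke test for the touched methods, assuming the facebook/detr-resnet-50-panoptic checkpoint and the (deprecated) post_process_segmentation signature from the transformers version this diff targets, could look like the following sketch; it is not part of the commit:

import torch
import requests
from PIL import Image
from transformers import DetrForSegmentation, DetrImageProcessor

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50-panoptic")
model = DetrForSegmentation.from_pretrained("facebook/detr-resnet-50-panoptic")

inputs = processor(images=image, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

# Exercise the rewritten filtering path; this emits a FutureWarning because
# the method is deprecated in favor of post_process_semantic_segmentation.
target_sizes = torch.tensor([image.size[::-1]])  # (height, width)
preds = processor.post_process_segmentation(outputs, target_sizes)
print(preds[0]["scores"].shape, preds[0]["labels"])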