Commit 5c03d593 authored by TengQi Ye's avatar TengQi Ye Committed by Francisco Massa
Browse files

Fix rpn memory leak and dataType errors. (#1657)

parent 333af7aa
...@@ -78,8 +78,8 @@ class BalancedPositiveNegativeSampler(object): ...@@ -78,8 +78,8 @@ class BalancedPositiveNegativeSampler(object):
matched_idxs_per_image, dtype=torch.uint8 matched_idxs_per_image, dtype=torch.uint8
) )
pos_idx_per_image_mask[pos_idx_per_image] = torch.tensor(1) pos_idx_per_image_mask[pos_idx_per_image] = torch.tensor(1, dtype=torch.uint8)
neg_idx_per_image_mask[neg_idx_per_image] = torch.tensor(1) neg_idx_per_image_mask[neg_idx_per_image] = torch.tensor(1, dtype=torch.uint8)
pos_idx.append(pos_idx_per_image_mask) pos_idx.append(pos_idx_per_image_mask)
neg_idx.append(neg_idx_per_image_mask) neg_idx.append(neg_idx_per_image_mask)
......
...@@ -163,6 +163,8 @@ class AnchorGenerator(nn.Module): ...@@ -163,6 +163,8 @@ class AnchorGenerator(nn.Module):
anchors_in_image.append(anchors_per_feature_map) anchors_in_image.append(anchors_per_feature_map)
anchors.append(anchors_in_image) anchors.append(anchors_in_image)
anchors = [torch.cat(anchors_per_image) for anchors_per_image in anchors] anchors = [torch.cat(anchors_per_image) for anchors_per_image in anchors]
# Clear the cache in case that memory leaks.
self._cache.clear()
return anchors return anchors
...@@ -333,11 +335,11 @@ class RegionProposalNetwork(torch.nn.Module): ...@@ -333,11 +335,11 @@ class RegionProposalNetwork(torch.nn.Module):
# Background (negative examples) # Background (negative examples)
bg_indices = matched_idxs == self.proposal_matcher.BELOW_LOW_THRESHOLD bg_indices = matched_idxs == self.proposal_matcher.BELOW_LOW_THRESHOLD
labels_per_image[bg_indices] = torch.tensor(0) labels_per_image[bg_indices] = torch.tensor(0.0)
# discard indices that are between thresholds # discard indices that are between thresholds
inds_to_discard = matched_idxs == self.proposal_matcher.BETWEEN_THRESHOLDS inds_to_discard = matched_idxs == self.proposal_matcher.BETWEEN_THRESHOLDS
labels_per_image[inds_to_discard] = torch.tensor(-1) labels_per_image[inds_to_discard] = torch.tensor(-1.0)
labels.append(labels_per_image) labels.append(labels_per_image)
matched_gt_boxes.append(matched_gt_boxes_per_image) matched_gt_boxes.append(matched_gt_boxes_per_image)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment