Commit 5c03d593 authored by TengQi Ye's avatar TengQi Ye Committed by Francisco Massa
Browse files

Fix rpn memory leak and dataType errors. (#1657)

parent 333af7aa
......@@ -78,8 +78,8 @@ class BalancedPositiveNegativeSampler(object):
matched_idxs_per_image, dtype=torch.uint8
)
pos_idx_per_image_mask[pos_idx_per_image] = torch.tensor(1)
neg_idx_per_image_mask[neg_idx_per_image] = torch.tensor(1)
pos_idx_per_image_mask[pos_idx_per_image] = torch.tensor(1, dtype=torch.uint8)
neg_idx_per_image_mask[neg_idx_per_image] = torch.tensor(1, dtype=torch.uint8)
pos_idx.append(pos_idx_per_image_mask)
neg_idx.append(neg_idx_per_image_mask)
......
......@@ -163,6 +163,8 @@ class AnchorGenerator(nn.Module):
anchors_in_image.append(anchors_per_feature_map)
anchors.append(anchors_in_image)
anchors = [torch.cat(anchors_per_image) for anchors_per_image in anchors]
# Clear the cache in case that memory leaks.
self._cache.clear()
return anchors
......@@ -333,11 +335,11 @@ class RegionProposalNetwork(torch.nn.Module):
# Background (negative examples)
bg_indices = matched_idxs == self.proposal_matcher.BELOW_LOW_THRESHOLD
labels_per_image[bg_indices] = torch.tensor(0)
labels_per_image[bg_indices] = torch.tensor(0.0)
# discard indices that are between thresholds
inds_to_discard = matched_idxs == self.proposal_matcher.BETWEEN_THRESHOLDS
labels_per_image[inds_to_discard] = torch.tensor(-1)
labels_per_image[inds_to_discard] = torch.tensor(-1.0)
labels.append(labels_per_image)
matched_gt_boxes.append(matched_gt_boxes_per_image)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment