Commit d9f4f254 authored by Jon Crall, committed by Kai Chen

FIX: Pass device to grid_anchors and valid_flags (#1478)

* Pass device to grid_anchors and valid_flags

* fix yapf formatting
parent f53de2be
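
The fix threads a device argument from the network outputs down to the anchor generators: the heads read cls_scores[0].device and forward it, so anchors and valid flags are created directly on the device the feature maps live on instead of relying on the hard-coded 'cuda' default. Below is a minimal sketch of the idea behind a grid_anchors that takes an explicit device, assuming a precomputed base_anchors tensor; it is a simplified illustration, not the actual mmdetection AnchorGenerator.

import torch


def grid_anchors(base_anchors, featmap_size, stride, device='cuda'):
    # Tile the base anchors over the feature-map grid, allocating every
    # intermediate tensor on `device` so nothing silently lands on CUDA.
    base_anchors = base_anchors.to(device)
    feat_h, feat_w = featmap_size
    shift_x = torch.arange(0, feat_w, device=device).float() * stride
    shift_y = torch.arange(0, feat_h, device=device).float() * stride
    shift_xx = shift_x.repeat(feat_h)                         # x varies fastest
    shift_yy = shift_y.view(-1, 1).repeat(1, feat_w).view(-1)
    shifts = torch.stack([shift_xx, shift_yy, shift_xx, shift_yy], dim=-1)
    # (K, 1, 4) + (1, A, 4) -> (K, A, 4) -> (K * A, 4)
    return (shifts[:, None, :] + base_anchors[None, :, :]).reshape(-1, 4)


# Callers derive the device from the network outputs rather than assuming CUDA:
#     device = cls_scores[0].device
#     anchors = grid_anchors(base_anchors, featmap_sizes[i], stride, device=device)

Keeping 'cuda' as the default preserves the old behaviour for existing callers, while CPU-only inference and tests can now pass an explicit CPU device.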
@@ -91,12 +91,13 @@ class AnchorHead(nn.Module):
     def forward(self, feats):
         return multi_apply(self.forward_single, feats)
 
-    def get_anchors(self, featmap_sizes, img_metas):
+    def get_anchors(self, featmap_sizes, img_metas, device='cuda'):
         """Get anchors according to feature map sizes.
 
         Args:
             featmap_sizes (list[tuple]): Multi-level feature map sizes.
             img_metas (list[dict]): Image meta info.
+            device (torch.device | str): device for returned tensors
 
         Returns:
             tuple: anchors of each image, valid flags of each image
@@ -109,7 +110,7 @@ class AnchorHead(nn.Module):
         multi_level_anchors = []
         for i in range(num_levels):
             anchors = self.anchor_generators[i].grid_anchors(
-                featmap_sizes[i], self.anchor_strides[i])
+                featmap_sizes[i], self.anchor_strides[i], device=device)
             multi_level_anchors.append(anchors)
         anchor_list = [multi_level_anchors for _ in range(num_imgs)]
@@ -124,7 +125,8 @@ class AnchorHead(nn.Module):
                 valid_feat_h = min(int(np.ceil(h / anchor_stride)), feat_h)
                 valid_feat_w = min(int(np.ceil(w / anchor_stride)), feat_w)
                 flags = self.anchor_generators[i].valid_flags(
-                    (feat_h, feat_w), (valid_feat_h, valid_feat_w))
+                    (feat_h, feat_w), (valid_feat_h, valid_feat_w),
+                    device=device)
                 multi_level_flags.append(flags)
             valid_flag_list.append(multi_level_flags)
@@ -162,8 +164,10 @@ class AnchorHead(nn.Module):
         featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
         assert len(featmap_sizes) == len(self.anchor_generators)
 
+        device = cls_scores[0].device
+
         anchor_list, valid_flag_list = self.get_anchors(
-            featmap_sizes, img_metas)
+            featmap_sizes, img_metas, device=device)
         label_channels = self.cls_out_channels if self.use_sigmoid_cls else 1
         cls_reg_targets = anchor_target(
             anchor_list,
@@ -201,10 +205,12 @@ class AnchorHead(nn.Module):
         assert len(cls_scores) == len(bbox_preds)
         num_levels = len(cls_scores)
 
+        device = cls_scores[0].device
         mlvl_anchors = [
-            self.anchor_generators[i].grid_anchors(cls_scores[i].size()[-2:],
-                                                   self.anchor_strides[i])
-            for i in range(num_levels)
+            self.anchor_generators[i].grid_anchors(
+                cls_scores[i].size()[-2:],
+                self.anchor_strides[i],
+                device=device) for i in range(num_levels)
         ]
         result_list = []
         for img_id in range(len(img_metas)):

@@ -210,12 +210,14 @@ class GuidedAnchorHead(AnchorHead):
     def forward(self, feats):
         return multi_apply(self.forward_single, feats)
 
-    def get_sampled_approxs(self, featmap_sizes, img_metas, cfg):
+    def get_sampled_approxs(self, featmap_sizes, img_metas, cfg,
+                            device='cuda'):
         """Get sampled approxs and inside flags according to feature map sizes.
 
         Args:
             featmap_sizes (list[tuple]): Multi-level feature map sizes.
             img_metas (list[dict]): Image meta info.
+            device (torch.device | str): device for returned tensors
 
         Returns:
             tuple: approxes of each image, inside flags of each image
@@ -228,7 +230,7 @@ class GuidedAnchorHead(AnchorHead):
         multi_level_approxs = []
         for i in range(num_levels):
             approxs = self.approx_generators[i].grid_anchors(
-                featmap_sizes[i], self.anchor_strides[i])
+                featmap_sizes[i], self.anchor_strides[i], device=device)
             multi_level_approxs.append(approxs)
         approxs_list = [multi_level_approxs for _ in range(num_imgs)]
@@ -245,7 +247,8 @@ class GuidedAnchorHead(AnchorHead):
                 valid_feat_h = min(int(np.ceil(h / anchor_stride)), feat_h)
                 valid_feat_w = min(int(np.ceil(w / anchor_stride)), feat_w)
                 flags = self.approx_generators[i].valid_flags(
-                    (feat_h, feat_w), (valid_feat_h, valid_feat_w))
+                    (feat_h, feat_w), (valid_feat_h, valid_feat_w),
+                    device=device)
                 inside_flags_list = []
                 for i in range(self.approxs_per_octave):
                     split_valid_flags = flags[i::self.approxs_per_octave]
@@ -267,7 +270,8 @@ class GuidedAnchorHead(AnchorHead):
                     shape_preds,
                     loc_preds,
                     img_metas,
-                    use_loc_filter=False):
+                    use_loc_filter=False,
+                    device='cuda'):
         """Get squares according to feature map sizes and guided
         anchors.
@@ -277,6 +281,7 @@ class GuidedAnchorHead(AnchorHead):
             loc_preds (list[tensor]): Multi-level location predictions.
             img_metas (list[dict]): Image meta info.
             use_loc_filter (bool): Use loc filter or not.
+            device (torch.device | str): device for returned tensors
 
         Returns:
             tuple: square approxs of each image, guided anchors of each image,
@@ -290,7 +295,7 @@ class GuidedAnchorHead(AnchorHead):
         multi_level_squares = []
         for i in range(num_levels):
             squares = self.square_generators[i].grid_anchors(
-                featmap_sizes[i], self.anchor_strides[i])
+                featmap_sizes[i], self.anchor_strides[i], device=device)
             multi_level_squares.append(squares)
         squares_list = [multi_level_squares for _ in range(num_imgs)]
@@ -404,6 +409,8 @@ class GuidedAnchorHead(AnchorHead):
         featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
         assert len(featmap_sizes) == len(self.approx_generators)
 
+        device = cls_scores[0].device
+
         # get loc targets
         loc_targets, loc_weights, loc_avg_factor = ga_loc_target(
             gt_bboxes,
@@ -415,10 +422,10 @@
         # get sampled approxes
         approxs_list, inside_flag_list = self.get_sampled_approxs(
-            featmap_sizes, img_metas, cfg)
+            featmap_sizes, img_metas, cfg, device=device)
         # get squares and guided anchors
         squares_list, guided_anchors_list, _ = self.get_anchors(
-            featmap_sizes, shape_preds, loc_preds, img_metas)
+            featmap_sizes, shape_preds, loc_preds, img_metas, device=device)
         # get shape targets
         sampling = False if not hasattr(cfg, 'ga_sampler') else True
@@ -515,13 +522,15 @@ class GuidedAnchorHead(AnchorHead):
            loc_preds)
         num_levels = len(cls_scores)
         featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
+        device = cls_scores[0].device
         # get guided anchors
         _, guided_anchors, loc_masks = self.get_anchors(
             featmap_sizes,
             shape_preds,
             loc_preds,
             img_metas,
-            use_loc_filter=not self.training)
+            use_loc_filter=not self.training,
+            device=device)
         result_list = []
         for img_id in range(len(img_metas)):
             cls_score_list = [
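
The valid_flags path follows the same pattern as grid_anchors. As a simplified sketch of how such a flag computation can respect the requested device (num_base_anchors is an illustrative parameter, not taken from the diff; this is not the exact mmdetection code):

import torch


def valid_flags(featmap_size, valid_size, num_base_anchors, device='cuda'):
    # Mark which grid cells fall inside the unpadded (valid) region of the
    # image; every tensor is created on `device`, mirroring the fix above.
    feat_h, feat_w = featmap_size
    valid_h, valid_w = valid_size
    valid_x = torch.zeros(feat_w, dtype=torch.bool, device=device)
    valid_y = torch.zeros(feat_h, dtype=torch.bool, device=device)
    valid_x[:valid_w] = True
    valid_y[:valid_h] = True
    valid_xx = valid_x.repeat(feat_h)                         # row-major grid
    valid_yy = valid_y.view(-1, 1).repeat(1, feat_w).view(-1)
    valid = valid_xx & valid_yy
    # Expand to one flag per anchor at each location, matching grid_anchors.
    return valid[:, None].expand(-1, num_base_anchors).reshape(-1)

With both helpers accepting a device, the heads above only need to read cls_scores[0].device and forward it, which is exactly what this commit does.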