"docs/vscode:/vscode.git/clone" did not exist on "3cb437f5e988ea9d8d8368e2a6453d939d9fbb4b"
Commit f0ecb9d5 authored by yhcao6's avatar yhcao6
Browse files

update center of base anchor to be half of stride

parent 2de84ef8
......@@ -51,7 +51,7 @@ data = dict(
workers_per_gpu=2,
train=dict(
type='RepeatDataset',
times=20,
times=10,
dataset=dict(
type=dataset_type,
ann_file=[
......@@ -113,7 +113,7 @@ lr_config = dict(
warmup='linear',
warmup_iters=500,
warmup_ratio=1.0 / 3,
step=[8, 11])
step=[16, 20])
checkpoint_config = dict(interval=1)
# yapf:disable
log_config = dict(
......@@ -124,7 +124,7 @@ log_config = dict(
])
# yapf:enable
# runtime settings
total_epochs = 12
total_epochs = 24
dist_params = dict(backend='nccl')
log_level = 'INFO'
work_dir = './work_dirs/ssd300_voc'
......
......@@ -51,7 +51,7 @@ data = dict(
workers_per_gpu=2,
train=dict(
type='RepeatDataset',
times=20,
times=10,
dataset=dict(
type=dataset_type,
ann_file=[
......@@ -113,7 +113,7 @@ lr_config = dict(
warmup='linear',
warmup_iters=500,
warmup_ratio=1.0 / 3,
step=[8, 11])
step=[16, 20])
checkpoint_config = dict(interval=1)
# yapf:disable
log_config = dict(
......@@ -124,7 +124,7 @@ log_config = dict(
])
# yapf:enable
# runtime settings
total_epochs = 12
total_epochs = 24
dist_params = dict(backend='nccl')
log_level = 'INFO'
work_dir = './work_dirs/ssd512_voc'
......
......@@ -51,7 +51,7 @@ data = dict(
workers_per_gpu=3,
train=dict(
type='RepeatDataset',
times=10,
times=5,
dataset=dict(
type=dataset_type,
ann_file=data_root + 'annotations/instances_train2017.json',
......@@ -110,7 +110,7 @@ lr_config = dict(
warmup='linear',
warmup_iters=500,
warmup_ratio=1.0 / 3,
step=[8, 11])
step=[16, 22])
checkpoint_config = dict(interval=1)
# yapf:disable
log_config = dict(
......@@ -121,7 +121,7 @@ log_config = dict(
])
# yapf:enable
# runtime settings
total_epochs = 12
total_epochs = 24
dist_params = dict(backend='nccl')
log_level = 'INFO'
work_dir = './work_dirs/ssd300_coco'
......
......@@ -51,7 +51,7 @@ data = dict(
workers_per_gpu=3,
train=dict(
type='RepeatDataset',
times=10,
times=5,
dataset=dict(
type=dataset_type,
ann_file=data_root + 'annotations/instances_train2017.json',
......@@ -110,7 +110,7 @@ lr_config = dict(
warmup='linear',
warmup_iters=500,
warmup_ratio=1.0 / 3,
step=[8, 11])
step=[16, 22])
checkpoint_config = dict(interval=1)
# yapf:disable
log_config = dict(
......@@ -121,7 +121,7 @@ log_config = dict(
])
# yapf:enable
# runtime settings
total_epochs = 12
total_epochs = 24
dist_params = dict(backend='nccl')
log_level = 'INFO'
work_dir = './work_dirs/ssd512_coco'
......
......@@ -3,11 +3,12 @@ import torch
class AnchorGenerator(object):
def __init__(self, base_size, scales, ratios, scale_major=True):
def __init__(self, base_size, scales, ratios, scale_major=True, ctr=None):
self.base_size = base_size
self.scales = torch.Tensor(scales)
self.ratios = torch.Tensor(ratios)
self.scale_major = scale_major
self.ctr = ctr
self.base_anchors = self.gen_base_anchors()
@property
......@@ -15,13 +16,13 @@ class AnchorGenerator(object):
return self.base_anchors.size(0)
def gen_base_anchors(self):
base_anchor = torch.Tensor(
[0, 0, self.base_size - 1, self.base_size - 1])
w = base_anchor[2] - base_anchor[0] + 1
h = base_anchor[3] - base_anchor[1] + 1
x_ctr = base_anchor[0] + 0.5 * (w - 1)
y_ctr = base_anchor[1] + 0.5 * (h - 1)
w = self.base_size
h = self.base_size
if self.ctr is None:
x_ctr = 0.5 * (w - 1)
y_ctr = 0.5 * (h - 1)
else:
x_ctr, y_ctr = self.ctr
h_ratios = torch.sqrt(self.ratios)
w_ratios = 1 / h_ratios
......
......@@ -72,12 +72,14 @@ class SSDHead(nn.Module):
self.anchor_strides = anchor_strides
for k in range(len(anchor_strides)):
base_size = min_sizes[k]
stride = anchor_strides[k]
ctr = ((stride - 1) / 2., (stride - 1) / 2.)
scales = [1., np.sqrt(max_sizes[k] / min_sizes[k])]
ratios = [1.]
for r in anchor_ratios[k]:
ratios += [1 / r, r] # 4 or 6 ratio
anchor_generator = AnchorGenerator(
base_size, scales, ratios, scale_major=False)
base_size, scales, ratios, scale_major=False, ctr=ctr)
indices = list(range(len(ratios)))
indices.insert(1, len(indices))
anchor_generator.base_anchors = torch.index_select(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment