Commit 48647f79 authored by Kai Chen's avatar Kai Chen
Browse files

Merge branch 'master' into single-stage

parents e84af8ce 3f5df4f0
...@@ -16,14 +16,14 @@ class CocoDataset(CustomDataset): ...@@ -16,14 +16,14 @@ class CocoDataset(CustomDataset):
self.img_ids = self.coco.getImgIds() self.img_ids = self.coco.getImgIds()
img_infos = [] img_infos = []
for i in self.img_ids: for i in self.img_ids:
info = self.coco.loadImgs(i)[0] info = self.coco.loadImgs([i])[0]
info['filename'] = info['file_name'] info['filename'] = info['file_name']
img_infos.append(info) img_infos.append(info)
return img_infos return img_infos
def get_ann_info(self, idx): def get_ann_info(self, idx):
img_id = self.img_infos[idx]['id'] img_id = self.img_infos[idx]['id']
ann_ids = self.coco.getAnnIds(imgIds=img_id) ann_ids = self.coco.getAnnIds(imgIds=[img_id])
ann_info = self.coco.loadAnns(ann_ids) ann_info = self.coco.loadAnns(ann_ids)
return self._parse_ann_info(ann_info) return self._parse_ann_info(ann_info)
......
...@@ -25,13 +25,13 @@ def build_dataloader(dataset, ...@@ -25,13 +25,13 @@ def build_dataloader(dataset,
batch_size = imgs_per_gpu batch_size = imgs_per_gpu
num_workers = workers_per_gpu num_workers = workers_per_gpu
else: else:
sampler = GroupSampler(dataset, imgs_per_gpu) if not kwargs.get('shuffle', True):
sampler = None
else:
sampler = GroupSampler(dataset, imgs_per_gpu)
batch_size = num_gpus * imgs_per_gpu batch_size = num_gpus * imgs_per_gpu
num_workers = num_gpus * workers_per_gpu num_workers = num_gpus * workers_per_gpu
if not kwargs.get('shuffle', True):
sampler = None
data_loader = DataLoader( data_loader = DataLoader(
dataset, dataset,
batch_size=batch_size, batch_size=batch_size,
......
...@@ -11,7 +11,7 @@ class ConvFCBBoxHead(BBoxHead): ...@@ -11,7 +11,7 @@ class ConvFCBBoxHead(BBoxHead):
/-> cls convs -> cls fcs -> cls /-> cls convs -> cls fcs -> cls
shared convs -> shared fcs shared convs -> shared fcs
\-> reg convs -> reg fcs -> reg \-> reg convs -> reg fcs -> reg
""" """ # noqa: W605
def __init__(self, def __init__(self,
num_shared_convs=0, num_shared_convs=0,
......
...@@ -65,6 +65,9 @@ class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin, ...@@ -65,6 +65,9 @@ class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin,
if self.with_bbox: if self.with_bbox:
self.bbox_roi_extractor.init_weights() self.bbox_roi_extractor.init_weights()
self.bbox_head.init_weights() self.bbox_head.init_weights()
if self.with_mask:
self.mask_roi_extractor.init_weights()
self.mask_head.init_weights()
def extract_feat(self, img): def extract_feat(self, img):
x = self.backbone(img) x = self.backbone(img)
......
...@@ -30,7 +30,7 @@ class RPNHead(nn.Module): ...@@ -30,7 +30,7 @@ class RPNHead(nn.Module):
target_stds (Iterable): Std values of regression targets. target_stds (Iterable): Std values of regression targets.
use_sigmoid_cls (bool): Whether to use sigmoid loss for classification. use_sigmoid_cls (bool): Whether to use sigmoid loss for classification.
(softmax by default) (softmax by default)
""" """ # noqa: W605
def __init__(self, def __init__(self,
in_channels, in_channels,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment