Commit 2017c81e authored by Kai Chen's avatar Kai Chen
Browse files

Merge branch 'master' into pytorch-1.0

parents c4408812 6594f862
......@@ -7,19 +7,33 @@ from .sampling_result import SamplingResult
class BaseSampler(metaclass=ABCMeta):
def __init__(self):
def __init__(self,
num,
pos_fraction,
neg_pos_ub=-1,
add_gt_as_proposals=True,
**kwargs):
self.num = num
self.pos_fraction = pos_fraction
self.neg_pos_ub = neg_pos_ub
self.add_gt_as_proposals = add_gt_as_proposals
self.pos_sampler = self
self.neg_sampler = self
@abstractmethod
def _sample_pos(self, assign_result, num_expected):
def _sample_pos(self, assign_result, num_expected, **kwargs):
pass
@abstractmethod
def _sample_neg(self, assign_result, num_expected):
def _sample_neg(self, assign_result, num_expected, **kwargs):
pass
def sample(self, assign_result, bboxes, gt_bboxes, gt_labels=None):
def sample(self,
assign_result,
bboxes,
gt_bboxes,
gt_labels=None,
**kwargs):
"""Sample positive and negative bboxes.
This is a simple implementation of bbox sampling given candidates,
......@@ -44,8 +58,8 @@ class BaseSampler(metaclass=ABCMeta):
gt_flags = torch.cat([gt_ones, gt_flags])
num_expected_pos = int(self.num * self.pos_fraction)
pos_inds = self.pos_sampler._sample_pos(assign_result,
num_expected_pos)
pos_inds = self.pos_sampler._sample_pos(
assign_result, num_expected_pos, bboxes=bboxes, **kwargs)
# We found that sampled indices have duplicated items occasionally.
# (may be a bug of PyTorch)
pos_inds = pos_inds.unique()
......@@ -56,8 +70,8 @@ class BaseSampler(metaclass=ABCMeta):
neg_upper_bound = int(self.neg_pos_ub * _pos)
if num_expected_neg > neg_upper_bound:
num_expected_neg = neg_upper_bound
neg_inds = self.neg_sampler._sample_neg(assign_result,
num_expected_neg)
neg_inds = self.neg_sampler._sample_neg(
assign_result, num_expected_neg, bboxes=bboxes, **kwargs)
neg_inds = neg_inds.unique()
return SamplingResult(pos_inds, neg_inds, bboxes, gt_bboxes,
......
from .random_sampler import RandomSampler
from .base_sampler import BaseSampler
from ..assign_sampling import build_sampler
class CombinedSampler(RandomSampler):
class CombinedSampler(BaseSampler):
def __init__(self, num, pos_fraction, pos_sampler, neg_sampler, **kwargs):
super(CombinedSampler, self).__init__(num, pos_fraction, **kwargs)
default_args = dict(num=num, pos_fraction=pos_fraction)
default_args.update(kwargs)
self.pos_sampler = build_sampler(
pos_sampler, default_args=default_args)
self.neg_sampler = build_sampler(
neg_sampler, default_args=default_args)
def __init__(self, pos_sampler, neg_sampler, **kwargs):
super(CombinedSampler, self).__init__(**kwargs)
self.pos_sampler = build_sampler(pos_sampler, **kwargs)
self.neg_sampler = build_sampler(neg_sampler, **kwargs)
def _sample_pos(self, **kwargs):
raise NotImplementedError
def _sample_neg(self, **kwargs):
raise NotImplementedError
......@@ -6,7 +6,7 @@ from .random_sampler import RandomSampler
class InstanceBalancedPosSampler(RandomSampler):
def _sample_pos(self, assign_result, num_expected):
def _sample_pos(self, assign_result, num_expected, **kwargs):
pos_inds = torch.nonzero(assign_result.gt_inds > 0)
if pos_inds.numel() != 0:
pos_inds = pos_inds.squeeze(1)
......
......@@ -19,7 +19,7 @@ class IoUBalancedNegSampler(RandomSampler):
self.hard_thr = hard_thr
self.hard_fraction = hard_fraction
def _sample_neg(self, assign_result, num_expected):
def _sample_neg(self, assign_result, num_expected, **kwargs):
neg_inds = torch.nonzero(assign_result.gt_inds == 0)
if neg_inds.numel() != 0:
neg_inds = neg_inds.squeeze(1)
......
import torch
from .base_sampler import BaseSampler
from ..transforms import bbox2roi
class OHEMSampler(BaseSampler):
def __init__(self,
num,
pos_fraction,
context,
neg_pos_ub=-1,
add_gt_as_proposals=True,
**kwargs):
super(OHEMSampler, self).__init__(num, pos_fraction, neg_pos_ub,
add_gt_as_proposals)
self.bbox_roi_extractor = context.bbox_roi_extractor
self.bbox_head = context.bbox_head
def hard_mining(self, inds, num_expected, bboxes, labels, feats):
with torch.no_grad():
rois = bbox2roi([bboxes])
bbox_feats = self.bbox_roi_extractor(
feats[:self.bbox_roi_extractor.num_inputs], rois)
cls_score, _ = self.bbox_head(bbox_feats)
loss = self.bbox_head.loss(
cls_score=cls_score,
bbox_pred=None,
labels=labels,
label_weights=cls_score.new_ones(cls_score.size(0)),
bbox_targets=None,
bbox_weights=None,
reduce=False)['loss_cls']
_, topk_loss_inds = loss.topk(num_expected)
return inds[topk_loss_inds]
def _sample_pos(self,
assign_result,
num_expected,
bboxes=None,
feats=None,
**kwargs):
# Sample some hard positive samples
pos_inds = torch.nonzero(assign_result.gt_inds > 0)
if pos_inds.numel() != 0:
pos_inds = pos_inds.squeeze(1)
if pos_inds.numel() <= num_expected:
return pos_inds
else:
return self.hard_mining(pos_inds, num_expected, bboxes[pos_inds],
assign_result.labels[pos_inds], feats)
def _sample_neg(self,
assign_result,
num_expected,
bboxes=None,
feats=None,
**kwargs):
# Sample some hard negative samples
neg_inds = torch.nonzero(assign_result.gt_inds == 0)
if neg_inds.numel() != 0:
neg_inds = neg_inds.squeeze(1)
if len(neg_inds) <= num_expected:
return neg_inds
else:
return self.hard_mining(neg_inds, num_expected, bboxes[neg_inds],
assign_result.labels[neg_inds], feats)
......@@ -6,16 +6,16 @@ from .sampling_result import SamplingResult
class PseudoSampler(BaseSampler):
def __init__(self):
def __init__(self, **kwargs):
pass
def _sample_pos(self):
def _sample_pos(self, **kwargs):
raise NotImplementedError
def _sample_neg(self):
def _sample_neg(self, **kwargs):
raise NotImplementedError
def sample(self, assign_result, bboxes, gt_bboxes):
def sample(self, assign_result, bboxes, gt_bboxes, **kwargs):
pos_inds = torch.nonzero(
assign_result.gt_inds > 0).squeeze(-1).unique()
neg_inds = torch.nonzero(
......
......@@ -10,12 +10,10 @@ class RandomSampler(BaseSampler):
num,
pos_fraction,
neg_pos_ub=-1,
add_gt_as_proposals=True):
super(RandomSampler, self).__init__()
self.num = num
self.pos_fraction = pos_fraction
self.neg_pos_ub = neg_pos_ub
self.add_gt_as_proposals = add_gt_as_proposals
add_gt_as_proposals=True,
**kwargs):
super(RandomSampler, self).__init__(num, pos_fraction, neg_pos_ub,
add_gt_as_proposals)
@staticmethod
def random_choice(gallery, num):
......@@ -34,7 +32,7 @@ class RandomSampler(BaseSampler):
rand_inds = torch.from_numpy(rand_inds).long().to(gallery.device)
return gallery[rand_inds]
def _sample_pos(self, assign_result, num_expected):
def _sample_pos(self, assign_result, num_expected, **kwargs):
"""Randomly sample some positive samples."""
pos_inds = torch.nonzero(assign_result.gt_inds > 0)
if pos_inds.numel() != 0:
......@@ -44,7 +42,7 @@ class RandomSampler(BaseSampler):
else:
return self.random_choice(pos_inds, num_expected)
def _sample_neg(self, assign_result, num_expected):
def _sample_neg(self, assign_result, num_expected, **kwargs):
"""Randomly sample some negative samples."""
neg_inds = torch.nonzero(assign_result.gt_inds == 0)
if neg_inds.numel() != 0:
......
......@@ -2,7 +2,7 @@ from .class_names import (voc_classes, imagenet_det_classes,
imagenet_vid_classes, coco_classes, dataset_aliases,
get_classes)
from .coco_utils import coco_eval, fast_eval_recall, results2json
from .eval_hooks import (DistEvalHook, CocoDistEvalRecallHook,
from .eval_hooks import (DistEvalHook, DistEvalmAPHook, CocoDistEvalRecallHook,
CocoDistEvalmAPHook)
from .mean_ap import average_precision, eval_map, print_map_summary
from .recall import (eval_recalls, print_recall_summary, plot_num_recall,
......@@ -11,7 +11,7 @@ from .recall import (eval_recalls, print_recall_summary, plot_num_recall,
__all__ = [
'voc_classes', 'imagenet_det_classes', 'imagenet_vid_classes',
'coco_classes', 'dataset_aliases', 'get_classes', 'coco_eval',
'fast_eval_recall', 'results2json', 'DistEvalHook',
'fast_eval_recall', 'results2json', 'DistEvalHook', 'DistEvalmAPHook',
'CocoDistEvalRecallHook', 'CocoDistEvalmAPHook', 'average_precision',
'eval_map', 'print_map_summary', 'eval_recalls', 'print_recall_summary',
'plot_num_recall', 'plot_iou_recall'
......
......@@ -63,18 +63,18 @@ def imagenet_vid_classes():
def coco_classes():
return [
'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train',
'truck', 'boat', 'traffic light', 'fire hydrant', 'stop sign',
'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep',
'truck', 'boat', 'traffic_light', 'fire_hydrant', 'stop_sign',
'parking_meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep',
'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella',
'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard',
'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard',
'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup', 'fork',
'sports_ball', 'kite', 'baseball_bat', 'baseball_glove', 'skateboard',
'surfboard', 'tennis_racket', 'bottle', 'wine_glass', 'cup', 'fork',
'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange',
'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair',
'couch', 'potted plant', 'bed', 'dining table', 'toilet', 'tv',
'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave',
'broccoli', 'carrot', 'hot_dog', 'pizza', 'donut', 'cake', 'chair',
'couch', 'potted_plant', 'bed', 'dining_table', 'toilet', 'tv',
'laptop', 'mouse', 'remote', 'keyboard', 'cell_phone', 'microwave',
'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase',
'scissors', 'teddy bear', 'hair drier', 'toothbrush'
'scissors', 'teddy_bear', 'hair_drier', 'toothbrush'
]
......
......@@ -12,6 +12,7 @@ from pycocotools.cocoeval import COCOeval
from torch.utils.data import Dataset
from .coco_utils import results2json, fast_eval_recall
from .mean_ap import eval_map
from mmdet import datasets
......@@ -102,6 +103,44 @@ class DistEvalHook(Hook):
raise NotImplementedError
class DistEvalmAPHook(DistEvalHook):
def evaluate(self, runner, results):
gt_bboxes = []
gt_labels = []
gt_ignore = [] if self.dataset.with_crowd else None
for i in range(len(self.dataset)):
ann = self.dataset.get_ann_info(i)
bboxes = ann['bboxes']
labels = ann['labels']
if gt_ignore is not None:
ignore = np.concatenate([
np.zeros(bboxes.shape[0], dtype=np.bool),
np.ones(ann['bboxes_ignore'].shape[0], dtype=np.bool)
])
gt_ignore.append(ignore)
bboxes = np.vstack([bboxes, ann['bboxes_ignore']])
labels = np.concatenate([labels, ann['labels_ignore']])
gt_bboxes.append(bboxes)
gt_labels.append(labels)
# If the dataset is VOC2007, then use 11 points mAP evaluation.
if hasattr(self.dataset, 'year') and self.dataset.year == 2007:
ds_name = 'voc07'
else:
ds_name = self.dataset.CLASSES
mean_ap, eval_results = eval_map(
results,
gt_bboxes,
gt_labels,
gt_ignore=gt_ignore,
scale_ranges=None,
iou_thr=0.5,
dataset=ds_name,
print_summary=True)
runner.log_buffer.output['mAP'] = mean_ap
runner.log_buffer.ready = True
class CocoDistEvalRecallHook(DistEvalHook):
def __init__(self,
......
......@@ -10,11 +10,15 @@ def weighted_nll_loss(pred, label, weight, avg_factor=None):
return torch.sum(raw * weight)[None] / avg_factor
def weighted_cross_entropy(pred, label, weight, avg_factor=None):
def weighted_cross_entropy(pred, label, weight, avg_factor=None,
reduce=True):
if avg_factor is None:
avg_factor = max(torch.sum(weight > 0).float().item(), 1.)
raw = F.cross_entropy(pred, label, reduction='none')
return torch.sum(raw * weight)[None] / avg_factor
if reduce:
return torch.sum(raw * weight)[None] / avg_factor
else:
return raw * weight / avg_factor
def weighted_binary_cross_entropy(pred, label, weight, avg_factor=None):
......
from .custom import CustomDataset
from .xml_style import XMLDataset
from .coco import CocoDataset
from .voc import VOCDataset
from .loader import GroupSampler, DistributedGroupSampler, build_dataloader
from .utils import to_tensor, random_scale, show_ann, get_dataset
from .concat_dataset import ConcatDataset
from .repeat_dataset import RepeatDataset
__all__ = [
'CustomDataset', 'CocoDataset', 'GroupSampler', 'DistributedGroupSampler',
'build_dataloader', 'to_tensor', 'random_scale', 'show_ann',
'get_dataset', 'ConcatDataset', 'RepeatDataset',
'CustomDataset', 'XMLDataset', 'CocoDataset', 'VOCDataset', 'GroupSampler',
'DistributedGroupSampler', 'build_dataloader', 'to_tensor', 'random_scale',
'show_ann', 'get_dataset', 'ConcatDataset', 'RepeatDataset'
]
......@@ -6,6 +6,21 @@ from .custom import CustomDataset
class CocoDataset(CustomDataset):
CLASSES = ('person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
'train', 'truck', 'boat', 'traffic_light', 'fire_hydrant',
'stop_sign', 'parking_meter', 'bench', 'bird', 'cat', 'dog',
'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe',
'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee',
'skis', 'snowboard', 'sports_ball', 'kite', 'baseball_bat',
'baseball_glove', 'skateboard', 'surfboard', 'tennis_racket',
'bottle', 'wine_glass', 'cup', 'fork', 'knife', 'spoon', 'bowl',
'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot',
'hot_dog', 'pizza', 'donut', 'cake', 'chair', 'couch',
'potted_plant', 'bed', 'dining_table', 'toilet', 'tv', 'laptop',
'mouse', 'remote', 'keyboard', 'cell_phone', 'microwave',
'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock',
'vase', 'scissors', 'teddy_bear', 'hair_drier', 'toothbrush')
def load_annotations(self, ann_file):
self.coco = COCO(ann_file)
self.cat_ids = self.coco.getCatIds()
......
......@@ -3,16 +3,18 @@ from torch.utils.data.dataset import ConcatDataset as _ConcatDataset
class ConcatDataset(_ConcatDataset):
"""
Same as torch.utils.data.dataset.ConcatDataset, but
"""A wrapper of concatenated dataset.
Same as :obj:`torch.utils.data.dataset.ConcatDataset`, but
concat the group flag for image aspect ratio.
Args:
datasets (list[:obj:`Dataset`]): A list of datasets.
"""
def __init__(self, datasets):
"""
flag: Images with aspect ratio greater than 1 will be set as group 1,
otherwise group 0.
"""
super(ConcatDataset, self).__init__(datasets)
self.CLASSES = datasets[0].CLASSES
if hasattr(datasets[0], 'flag'):
flags = []
for i in range(0, len(datasets)):
......
......@@ -32,6 +32,8 @@ class CustomDataset(Dataset):
The `ann` field is optional for testing.
"""
CLASSES = None
def __init__(self,
ann_file,
img_prefix,
......@@ -45,6 +47,8 @@ class CustomDataset(Dataset):
with_crowd=True,
with_label=True,
test_mode=False):
# prefix of images path
self.img_prefix = img_prefix
# load annotations (and proposals)
self.img_infos = self.load_annotations(ann_file)
if proposal_file is not None:
......@@ -58,8 +62,6 @@ class CustomDataset(Dataset):
if self.proposals is not None:
self.proposals = [self.proposals[i] for i in valid_inds]
# prefix of images path
self.img_prefix = img_prefix
# (long_edge, short_edge) or [(long1, short1), (long2, short2), ...]
self.img_scales = img_scale if isinstance(img_scale,
list) else [img_scale]
......
......@@ -6,12 +6,14 @@ class RepeatDataset(object):
def __init__(self, dataset, times):
self.dataset = dataset
self.times = times
self.CLASSES = dataset.CLASSES
if hasattr(self.dataset, 'flag'):
self.flag = np.tile(self.dataset.flag, times)
self._original_length = len(self.dataset)
self._ori_len = len(self.dataset)
def __getitem__(self, idx):
return self.dataset[idx % self._original_length]
return self.dataset[idx % self._ori_len]
def __len__(self):
return self.times * self._original_length
return self.times * self._ori_len
from .xml_style import XMLDataset
class VOCDataset(XMLDataset):
CLASSES = ('aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car',
'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse',
'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', 'train',
'tvmonitor')
def __init__(self, **kwargs):
super(VOCDataset, self).__init__(**kwargs)
if 'VOC2007' in self.img_prefix:
self.year = 2007
elif 'VOC2012' in self.img_prefix:
self.year = 2012
else:
raise ValueError('Cannot infer dataset year from img_prefix')
import os.path as osp
import xml.etree.ElementTree as ET
import mmcv
import numpy as np
from .custom import CustomDataset
class XMLDataset(CustomDataset):
def __init__(self, **kwargs):
super(XMLDataset, self).__init__(**kwargs)
self.cat2label = {cat: i + 1 for i, cat in enumerate(self.CLASSES)}
def load_annotations(self, ann_file):
img_infos = []
img_ids = mmcv.list_from_file(ann_file)
for img_id in img_ids:
filename = 'JPEGImages/{}.jpg'.format(img_id)
xml_path = osp.join(self.img_prefix, 'Annotations',
'{}.xml'.format(img_id))
tree = ET.parse(xml_path)
root = tree.getroot()
size = root.find('size')
width = int(size.find('width').text)
height = int(size.find('height').text)
img_infos.append(
dict(id=img_id, filename=filename, width=width, height=height))
return img_infos
def get_ann_info(self, idx):
img_id = self.img_infos[idx]['id']
xml_path = osp.join(self.img_prefix, 'Annotations',
'{}.xml'.format(img_id))
tree = ET.parse(xml_path)
root = tree.getroot()
bboxes = []
labels = []
bboxes_ignore = []
labels_ignore = []
for obj in root.findall('object'):
name = obj.find('name').text
label = self.cat2label[name]
difficult = int(obj.find('difficult').text)
bnd_box = obj.find('bndbox')
bbox = [
int(bnd_box.find('xmin').text),
int(bnd_box.find('ymin').text),
int(bnd_box.find('xmax').text),
int(bnd_box.find('ymax').text)
]
if difficult:
bboxes_ignore.append(bbox)
labels_ignore.append(label)
else:
bboxes.append(bbox)
labels.append(label)
if not bboxes:
bboxes = np.zeros((0, 4))
labels = np.zeros((0, ))
else:
bboxes = np.array(bboxes, ndmin=2) - 1
labels = np.array(labels)
if not bboxes_ignore:
bboxes_ignore = np.zeros((0, 4))
labels_ignore = np.zeros((0, ))
else:
bboxes_ignore = np.array(bboxes_ignore, ndmin=2) - 1
labels_ignore = np.array(labels_ignore)
ann = dict(
bboxes=bboxes.astype(np.float32),
labels=labels.astype(np.int64),
bboxes_ignore=bboxes_ignore.astype(np.float32),
labels_ignore=labels_ignore.astype(np.int64))
return ann
from .resnet import ResNet
from .resnext import ResNeXt
__all__ = ['ResNet']
__all__ = ['ResNet', 'ResNeXt']
......@@ -42,7 +42,7 @@ class BasicBlock(nn.Module):
assert not with_cp
def forward(self, x):
residual = x
identity = x
out = self.conv1(x)
out = self.bn1(out)
......@@ -52,9 +52,9 @@ class BasicBlock(nn.Module):
out = self.bn2(out)
if self.downsample is not None:
residual = self.downsample(x)
identity = self.downsample(x)
out += residual
out += identity
out = self.relu(out)
return out
......@@ -71,25 +71,31 @@ class Bottleneck(nn.Module):
downsample=None,
style='pytorch',
with_cp=False):
"""Bottleneck block.
"""Bottleneck block for ResNet.
If style is "pytorch", the stride-two layer is the 3x3 conv layer,
if it is "caffe", the stride-two layer is the first 1x1 conv layer.
"""
super(Bottleneck, self).__init__()
assert style in ['pytorch', 'caffe']
self.inplanes = inplanes
self.planes = planes
if style == 'pytorch':
conv1_stride = 1
conv2_stride = stride
self.conv1_stride = 1
self.conv2_stride = stride
else:
conv1_stride = stride
conv2_stride = 1
self.conv1_stride = stride
self.conv2_stride = 1
self.conv1 = nn.Conv2d(
inplanes, planes, kernel_size=1, stride=conv1_stride, bias=False)
inplanes,
planes,
kernel_size=1,
stride=self.conv1_stride,
bias=False)
self.conv2 = nn.Conv2d(
planes,
planes,
kernel_size=3,
stride=conv2_stride,
stride=self.conv2_stride,
padding=dilation,
dilation=dilation,
bias=False)
......@@ -108,7 +114,7 @@ class Bottleneck(nn.Module):
def forward(self, x):
def _inner_forward(x):
residual = x
identity = x
out = self.conv1(x)
out = self.bn1(out)
......@@ -122,9 +128,9 @@ class Bottleneck(nn.Module):
out = self.bn3(out)
if self.downsample is not None:
residual = self.downsample(x)
identity = self.downsample(x)
out += residual
out += identity
return out
......@@ -219,20 +225,24 @@ class ResNet(nn.Module):
super(ResNet, self).__init__()
if depth not in self.arch_settings:
raise KeyError('invalid depth {} for resnet'.format(depth))
self.depth = depth
self.num_stages = num_stages
assert num_stages >= 1 and num_stages <= 4
block, stage_blocks = self.arch_settings[depth]
stage_blocks = stage_blocks[:num_stages]
self.strides = strides
self.dilations = dilations
assert len(strides) == len(dilations) == num_stages
assert max(out_indices) < num_stages
self.out_indices = out_indices
assert max(out_indices) < num_stages
self.style = style
self.frozen_stages = frozen_stages
self.bn_eval = bn_eval
self.bn_frozen = bn_frozen
self.with_cp = with_cp
self.block, stage_blocks = self.arch_settings[depth]
self.stage_blocks = stage_blocks[:num_stages]
self.inplanes = 64
self.conv1 = nn.Conv2d(
3, 64, kernel_size=7, stride=2, padding=3, bias=False)
self.bn1 = nn.BatchNorm2d(64)
......@@ -240,12 +250,12 @@ class ResNet(nn.Module):
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
self.res_layers = []
for i, num_blocks in enumerate(stage_blocks):
for i, num_blocks in enumerate(self.stage_blocks):
stride = strides[i]
dilation = dilations[i]
planes = 64 * 2**i
res_layer = make_res_layer(
block,
self.block,
self.inplanes,
planes,
num_blocks,
......@@ -253,12 +263,13 @@ class ResNet(nn.Module):
dilation=dilation,
style=self.style,
with_cp=with_cp)
self.inplanes = planes * block.expansion
self.inplanes = planes * self.block.expansion
layer_name = 'layer{}'.format(i + 1)
self.add_module(layer_name, res_layer)
self.res_layers.append(layer_name)
self.feat_dim = block.expansion * 64 * 2**(len(stage_blocks) - 1)
self.feat_dim = self.block.expansion * 64 * 2**(
len(self.stage_blocks) - 1)
def init_weights(self, pretrained=None):
if isinstance(pretrained, str):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment