"docker/install/vscode:/vscode.git/clone" did not exist on "cf9ba90fd47471b56efaa66128de78667b6f17b9"
Commit bb6ef3b3 authored by Kai Chen's avatar Kai Chen
Browse files

Merge branch 'dev' into mask-debug

parents 98b20b9b 5266dea0
...@@ -17,6 +17,18 @@ class BaseDetector(nn.Module): ...@@ -17,6 +17,18 @@ class BaseDetector(nn.Module):
def __init__(self): def __init__(self):
super(BaseDetector, self).__init__() super(BaseDetector, self).__init__()
@property
def with_neck(self):
return hasattr(self, 'neck') and self.neck is not None
@property
def with_bbox(self):
return hasattr(self, 'bbox_head') and self.bbox_head is not None
@property
def with_mask(self):
return hasattr(self, 'mask_head') and self.mask_head is not None
@abstractmethod @abstractmethod
def extract_feat(self, imgs): def extract_feat(self, imgs):
pass pass
......
...@@ -26,13 +26,13 @@ class RPN(BaseDetector, RPNTestMixin): ...@@ -26,13 +26,13 @@ class RPN(BaseDetector, RPNTestMixin):
def init_weights(self, pretrained=None): def init_weights(self, pretrained=None):
super(RPN, self).init_weights(pretrained) super(RPN, self).init_weights(pretrained)
self.backbone.init_weights(pretrained=pretrained) self.backbone.init_weights(pretrained=pretrained)
if self.neck is not None: if self.with_neck:
self.neck.init_weights() self.neck.init_weights()
self.rpn_head.init_weights() self.rpn_head.init_weights()
def extract_feat(self, img): def extract_feat(self, img):
x = self.backbone(img) x = self.backbone(img)
if self.neck is not None: if self.with_neck:
x = self.neck(x) x = self.neck(x)
return x return x
......
...@@ -25,23 +25,19 @@ class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin, ...@@ -25,23 +25,19 @@ class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin,
self.backbone = builder.build_backbone(backbone) self.backbone = builder.build_backbone(backbone)
if neck is not None: if neck is not None:
self.with_neck = True
self.neck = builder.build_neck(neck) self.neck = builder.build_neck(neck)
else: else:
raise NotImplementedError raise NotImplementedError
self.with_rpn = True if rpn_head is not None else False if rpn_head is not None:
if self.with_rpn:
self.rpn_head = builder.build_rpn_head(rpn_head) self.rpn_head = builder.build_rpn_head(rpn_head)
self.with_bbox = True if bbox_head is not None else False if bbox_head is not None:
if self.with_bbox:
self.bbox_roi_extractor = builder.build_roi_extractor( self.bbox_roi_extractor = builder.build_roi_extractor(
bbox_roi_extractor) bbox_roi_extractor)
self.bbox_head = builder.build_bbox_head(bbox_head) self.bbox_head = builder.build_bbox_head(bbox_head)
self.with_mask = True if mask_head is not None else False if mask_head is not None:
if self.with_mask:
self.mask_roi_extractor = builder.build_roi_extractor( self.mask_roi_extractor = builder.build_roi_extractor(
mask_roi_extractor) mask_roi_extractor)
self.mask_head = builder.build_mask_head(mask_head) self.mask_head = builder.build_mask_head(mask_head)
...@@ -51,6 +47,10 @@ class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin, ...@@ -51,6 +47,10 @@ class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin,
self.init_weights(pretrained=pretrained) self.init_weights(pretrained=pretrained)
@property
def with_rpn(self):
return hasattr(self, 'rpn_head') and self.rpn_head is not None
def init_weights(self, pretrained=None): def init_weights(self, pretrained=None):
super(TwoStageDetector, self).init_weights(pretrained) super(TwoStageDetector, self).init_weights(pretrained)
self.backbone.init_weights(pretrained=pretrained) self.backbone.init_weights(pretrained=pretrained)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment