Commit 4182170a authored by zhe chen's avatar zhe chen
Browse files

Add frozen_stages kwarg to InternImage backbone (#283)

parent fe6cdd2e
...@@ -6,4 +6,4 @@ segmentation/convertor/ ...@@ -6,4 +6,4 @@ segmentation/convertor/
checkpoint_dir/ checkpoint_dir/
demo/ demo/
pretrained/ pretrained/
upload.py upload.py
\ No newline at end of file
...@@ -575,6 +575,7 @@ class InternImage(nn.Module): ...@@ -575,6 +575,7 @@ class InternImage(nn.Module):
center_feature_scale=False, # for InternImage-H/G center_feature_scale=False, # for InternImage-H/G
use_dcn_v4_op=False, use_dcn_v4_op=False,
out_indices=(0, 1, 2, 3), out_indices=(0, 1, 2, 3),
frozen_stages=-1,
init_cfg=None, init_cfg=None,
**kwargs): **kwargs):
super().__init__() super().__init__()
...@@ -588,6 +589,8 @@ class InternImage(nn.Module): ...@@ -588,6 +589,8 @@ class InternImage(nn.Module):
self.init_cfg = init_cfg self.init_cfg = init_cfg
self.out_indices = out_indices self.out_indices = out_indices
self.level2_post_norm_block_ids = level2_post_norm_block_ids self.level2_post_norm_block_ids = level2_post_norm_block_ids
self.frozen_stages = frozen_stages
logger = get_root_logger() logger = get_root_logger()
logger.info(f'using core type: {core_op}') logger.info(f'using core type: {core_op}')
logger.info(f'using activation layer: {act_layer}') logger.info(f'using activation layer: {act_layer}')
...@@ -642,6 +645,19 @@ class InternImage(nn.Module): ...@@ -642,6 +645,19 @@ class InternImage(nn.Module):
self.num_layers = len(depths) self.num_layers = len(depths)
self.apply(self._init_weights) self.apply(self._init_weights)
self.apply(self._init_deform_weights) self.apply(self._init_deform_weights)
self._freeze_stages()
def train(self, mode=True):
"""Convert the model into training mode while keep normalization layer frozen."""
super(InternImage, self).train(mode)
self._freeze_stages()
def _freeze_stages(self):
if self.frozen_stages >= 0:
for level in self.levels[:self.frozen_stages]:
level.eval()
for param in level.parameters():
param.requires_grad = False
def init_weights(self): def init_weights(self):
logger = get_root_logger() logger = get_root_logger()
......
...@@ -575,6 +575,7 @@ class InternImage(nn.Module): ...@@ -575,6 +575,7 @@ class InternImage(nn.Module):
center_feature_scale=False, # for InternImage-H/G center_feature_scale=False, # for InternImage-H/G
use_dcn_v4_op=False, use_dcn_v4_op=False,
out_indices=(0, 1, 2, 3), out_indices=(0, 1, 2, 3),
frozen_stages=-1,
init_cfg=None, init_cfg=None,
**kwargs): **kwargs):
super().__init__() super().__init__()
...@@ -588,6 +589,8 @@ class InternImage(nn.Module): ...@@ -588,6 +589,8 @@ class InternImage(nn.Module):
self.init_cfg = init_cfg self.init_cfg = init_cfg
self.out_indices = out_indices self.out_indices = out_indices
self.level2_post_norm_block_ids = level2_post_norm_block_ids self.level2_post_norm_block_ids = level2_post_norm_block_ids
self.frozen_stages = frozen_stages
logger = get_root_logger() logger = get_root_logger()
logger.info(f'using core type: {core_op}') logger.info(f'using core type: {core_op}')
logger.info(f'using activation layer: {act_layer}') logger.info(f'using activation layer: {act_layer}')
...@@ -642,6 +645,19 @@ class InternImage(nn.Module): ...@@ -642,6 +645,19 @@ class InternImage(nn.Module):
self.num_layers = len(depths) self.num_layers = len(depths)
self.apply(self._init_weights) self.apply(self._init_weights)
self.apply(self._init_deform_weights) self.apply(self._init_deform_weights)
self._freeze_stages()
def train(self, mode=True):
"""Convert the model into training mode while keep normalization layer frozen."""
super(InternImage, self).train(mode)
self._freeze_stages()
def _freeze_stages(self):
if self.frozen_stages >= 0:
for level in self.levels[:self.frozen_stages]:
level.eval()
for param in level.parameters():
param.requires_grad = False
def init_weights(self): def init_weights(self):
logger = get_root_logger() logger = get_root_logger()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment