Commit 4182170a authored by zhe chen's avatar zhe chen
Browse files

Add frozen_stages kwarg to InternImage backbone (#283)

parent fe6cdd2e
......@@ -6,4 +6,4 @@ segmentation/convertor/
checkpoint_dir/
demo/
pretrained/
upload.py
\ No newline at end of file
upload.py
......@@ -575,6 +575,7 @@ class InternImage(nn.Module):
center_feature_scale=False, # for InternImage-H/G
use_dcn_v4_op=False,
out_indices=(0, 1, 2, 3),
frozen_stages=-1,
init_cfg=None,
**kwargs):
super().__init__()
......@@ -588,6 +589,8 @@ class InternImage(nn.Module):
self.init_cfg = init_cfg
self.out_indices = out_indices
self.level2_post_norm_block_ids = level2_post_norm_block_ids
self.frozen_stages = frozen_stages
logger = get_root_logger()
logger.info(f'using core type: {core_op}')
logger.info(f'using activation layer: {act_layer}')
......@@ -642,6 +645,19 @@ class InternImage(nn.Module):
self.num_layers = len(depths)
self.apply(self._init_weights)
self.apply(self._init_deform_weights)
self._freeze_stages()
def train(self, mode=True):
"""Convert the model into training mode while keep normalization layer frozen."""
super(InternImage, self).train(mode)
self._freeze_stages()
def _freeze_stages(self):
if self.frozen_stages >= 0:
for level in self.levels[:self.frozen_stages]:
level.eval()
for param in level.parameters():
param.requires_grad = False
def init_weights(self):
logger = get_root_logger()
......
......@@ -575,6 +575,7 @@ class InternImage(nn.Module):
center_feature_scale=False, # for InternImage-H/G
use_dcn_v4_op=False,
out_indices=(0, 1, 2, 3),
frozen_stages=-1,
init_cfg=None,
**kwargs):
super().__init__()
......@@ -588,6 +589,8 @@ class InternImage(nn.Module):
self.init_cfg = init_cfg
self.out_indices = out_indices
self.level2_post_norm_block_ids = level2_post_norm_block_ids
self.frozen_stages = frozen_stages
logger = get_root_logger()
logger.info(f'using core type: {core_op}')
logger.info(f'using activation layer: {act_layer}')
......@@ -642,6 +645,19 @@ class InternImage(nn.Module):
self.num_layers = len(depths)
self.apply(self._init_weights)
self.apply(self._init_deform_weights)
self._freeze_stages()
def train(self, mode=True):
"""Convert the model into training mode while keep normalization layer frozen."""
super(InternImage, self).train(mode)
self._freeze_stages()
def _freeze_stages(self):
if self.frozen_stages >= 0:
for level in self.levels[:self.frozen_stages]:
level.eval()
for param in level.parameters():
param.requires_grad = False
def init_weights(self):
logger = get_root_logger()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment