Commit 4169abc1 authored by Haricharan Lakshman's avatar Haricharan Lakshman Committed by Facebook GitHub Bot
Browse files

Freeze matched bn layers

Summary:
Convert the batchnorm layers that match the specified regular expressions to FrozenBatchNorm2d.

If module is an instance of batchnorm and it matches the reg exps, returns a new FrozenBatchNorm2d module.

Otherwise, in-place converts the matching batchnorm child modules to FrozenBatchNorm2d
and returns the main module.

Reviewed By: ppwwyyxx

Differential Revision: D29286500

fbshipit-source-id: 3a20f5eeff59ddff50c42fe297eedf0ce2b909bc
parent 77ef0db7
......@@ -3,6 +3,9 @@
import re
import logging
import torch.nn as nn
from detectron2.layers import FrozenBatchNorm2d
logger = logging.getLogger(__name__)
......@@ -33,3 +36,61 @@ def set_requires_grad(model, reg_exps, value):
value, matched_parameter_names))
logger.info("Unmatched layers: {}".format(unmatched_parameter_names))
return matched_parameter_names, unmatched_parameter_names
def _freeze_matched_bn(module, name, reg_exps, matched_names, unmatched_names):
"""
Recursive function to freeze bn layers that match specified regular expressions.
"""
res = module
# Base case: current module is a leaf node
if len(list(module.children())) == 0:
if isinstance(module, nn.modules.batchnorm._BatchNorm):
matched = False
for frozen_layers_regex in reg_exps:
if re.match(frozen_layers_regex, name):
matched = True
matched_names.append(name)
# Convert to frozen batch norm
res = FrozenBatchNorm2d.convert_frozen_batchnorm(module)
if not matched:
unmatched_names.append(name)
return res
# Recursion: current module has children
for child_name, child in module.named_children():
_name = name + "." + child_name if name != "" else child_name
new_child = _freeze_matched_bn(
child, _name, reg_exps, matched_names, unmatched_names
)
if new_child is not child:
res.add_module(child_name, new_child)
return res
def freeze_matched_bn(module, reg_exps):
"""
Convert matching batchnorm layers in module into FrozenBatchNorm2d.
Args:
module: nn.Module
reg_exps: list of regular expressions to match
Returns:
If module is an instance of batchnorm and it matches the reg exps,
returns a new FrozenBatchNorm2d module.
Otherwise, in-place converts the matching batchnorm child modules to FrozenBatchNorm2d
and returns the main module.
"""
matched_names = []
unmatched_names = []
res = _freeze_matched_bn(module, "", reg_exps, matched_names, unmatched_names)
logger.info("Matched BN layers are frozen: {}".format(matched_names))
logger.info("Unmatched BN layers: {}".format(unmatched_names))
return res
......@@ -31,6 +31,7 @@ from d2go.data.utils import (
from d2go.export.d2_meta_arch import patch_d2_meta_arch
from d2go.modeling import kmeans_anchors, model_ema
from d2go.modeling.model_freezing_utils import (
freeze_matched_bn,
set_requires_grad,
)
from d2go.modeling.quantization import (
......@@ -249,6 +250,7 @@ class Detectron2GoRunner(BaseRunner):
if cfg.MODEL.FROZEN_LAYER_REG_EXP:
set_requires_grad(model, cfg.MODEL.FROZEN_LAYER_REG_EXP, False)
model = freeze_matched_bn(model, cfg.MODEL.FROZEN_LAYER_REG_EXP)
if cfg.QUANTIZATION.QAT.ENABLED:
# Disable fake_quant and observer so that the model will be trained normally
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment