Unverified Commit 0ab4916a authored by Ming Yu's avatar Ming Yu Committed by GitHub
Browse files

[Model Compression] Add replace module (#4492)

parent aa1f71c8
......@@ -11,6 +11,7 @@ _logger = logging.getLogger(__name__)
replace_module = {
'BatchNorm2d': lambda module, masks: replace_batchnorm2d(module, masks),
'BatchNorm1d': lambda module, masks: replace_batchnorm1d(module, masks),
'InstanceNorm2d': lambda module, masks: replace_instancenorm2d(module, masks),
'Conv2d': lambda module, masks: replace_conv2d(module, masks),
'Linear': lambda module, masks: replace_linear(module, masks),
'MaxPool2d': lambda module, masks: no_replace(module, masks),
......@@ -43,6 +44,7 @@ replace_module = {
'Upsample': lambda module, masks: no_replace(module, masks),
'LayerNorm': lambda module, masks: replace_layernorm(module, masks),
'ConvTranspose2d': lambda module, masks: replace_convtranspose2d(module, masks),
'PixelShuffle': lambda module, masks: replace_pixelshuffle(module, masks),
'Flatten': lambda module, masks: no_replace(module, masks)
}
......@@ -280,6 +282,51 @@ def replace_batchnorm2d(norm, masks):
return new_norm
def replace_instancenorm2d(norm, masks):
"""
Parameters
----------
norm : torch.nn.InstanceNorm2d
The instancenorm module to be replace
masks : Tuple of the input masks, output masks and weight masks
Tuple of the masks, for example
([input_m1, input_m2], [output_m], {'weight':weight_m})
Returns
-------
torch.nn.InstanceNorm2d
The new instancenorm module
"""
in_masks, output_mask, _ = masks
assert isinstance(norm, nn.InstanceNorm2d)
in_mask = in_masks[0]
# N, C, H, W
_, remained_in = convert_to_coarse_mask(in_mask, 1)
_, remained_out = convert_to_coarse_mask(output_mask, 1)
if remained_in.size(0) != remained_out.size(0):
raise ShapeMisMatchError()
num_features = remained_in.size(0)
_logger.info("replace instancenorm2d with num_features: %d", num_features)
new_norm = torch.nn.InstanceNorm2d(num_features=num_features,
eps=norm.eps,
momentum=norm.momentum,
affine=norm.affine,
track_running_stats=norm.track_running_stats)
# assign weights
if norm.affine:
new_norm.weight.data = torch.index_select(norm.weight.data, 0, remained_in)
new_norm.bias.data = torch.index_select(norm.bias.data, 0, remained_in)
if norm.track_running_stats:
new_norm.running_mean.data = torch.index_select(
norm.running_mean.data, 0, remained_in)
new_norm.running_var.data = torch.index_select(
norm.running_var.data, 0, remained_in)
return new_norm
def replace_conv2d(conv, masks):
"""
Replace the original conv with a new one according to the infered
......@@ -544,3 +591,41 @@ def replace_layernorm(layernorm, masks):
new_shape.append(n_remained)
return nn.LayerNorm(tuple(new_shape), layernorm.eps, layernorm.elementwise_affine)
def replace_pixelshuffle(pixelshuffle, masks):
"""
Parameters
----------
norm : torch.nn.PixelShuffle
The pixelshuffle module to be replace
masks : Tuple of the input masks, output masks and weight masks
Tuple of the masks, for example
([input_m1, input_m2], [output_m], {'weight':weight_m})
Returns
-------
torch.nn.PixelShuffle
The new pixelshuffle module
"""
in_masks, output_mask, _ = masks
assert isinstance(pixelshuffle, torch.nn.PixelShuffle)
if len(in_masks) != 1:
raise InputsNumberError()
in_mask = in_masks[0]
# N, C, H, W
_, remained_in = convert_to_coarse_mask(in_mask, 1)
_, remained_out = convert_to_coarse_mask(output_mask, 1)
upscale_factor = pixelshuffle.upscale_factor
if remained_in.size(0) % (upscale_factor * upscale_factor):
_logger.debug("Shape mismatch, remained_in:%d upscale_factor:%d",
remained_in.size(0), remained_out.size(0))
raise ShapeMisMatchError()
if remained_out.size(0) * upscale_factor * upscale_factor != remained_in:
raise ShapeMisMatchError()
new_pixelshuffle = torch.nn.PixelShuffle(upscale_factor)
return new_pixelshuffle
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment