Unverified Commit fb5ef932 authored by Cjkkkk's avatar Cjkkkk Committed by GitHub
Browse files

fix model compression config validation (#2033)

parent 50e425f2
...@@ -149,13 +149,24 @@ class Compressor: ...@@ -149,13 +149,24 @@ class Compressor:
ret = None ret = None
for config in self.config_list: for config in self.config_list:
config = config.copy() config = config.copy()
config['op_types'] = self._expand_config_op_types(config) # expand config if key `default` is in config['op_types']
if layer.type not in config['op_types']: if 'op_types' in config and 'default' in config['op_types']:
expanded_op_types = []
for op_type in config['op_types']:
if op_type == 'default':
expanded_op_types.extend(default_layers.weighted_modules)
else:
expanded_op_types.append(op_type)
config['op_types'] = expanded_op_types
# check if condition is satisified
if 'op_types' in config and layer.type not in config['op_types']:
continue continue
if config.get('op_names') and layer.name not in config['op_names']: if 'op_names' in config and layer.name not in config['op_names']:
continue continue
ret = config ret = config
if ret is None or ret.get('exclude'): if ret is None or 'exclude' in ret:
return None return None
return ret return ret
...@@ -188,16 +199,6 @@ class Compressor: ...@@ -188,16 +199,6 @@ class Compressor:
""" """
raise NotImplementedError() raise NotImplementedError()
def _expand_config_op_types(self, config):
if config is None:
return []
expanded_op_types = []
for op_type in config.get('op_types', []):
if op_type == 'default':
expanded_op_types.extend(default_layers.weighted_modules)
else:
expanded_op_types.append(op_type)
return expanded_op_types
class PrunerModuleWrapper(torch.nn.Module): class PrunerModuleWrapper(torch.nn.Module):
def __init__(self, module, module_name, module_type, config, pruner): def __init__(self, module, module_name, module_type, config, pruner):
...@@ -229,11 +230,12 @@ class PrunerModuleWrapper(torch.nn.Module): ...@@ -229,11 +230,12 @@ class PrunerModuleWrapper(torch.nn.Module):
# register buffer for mask # register buffer for mask
self.register_buffer("weight_mask", torch.ones(self.module.weight.shape)) self.register_buffer("weight_mask", torch.ones(self.module.weight.shape))
self.registered_buffers['weight_mask'] = self.weight_mask
if hasattr(self.module, 'bias') and self.module.bias is not None: if hasattr(self.module, 'bias') and self.module.bias is not None:
self.register_buffer("bias_mask", torch.ones(self.module.bias.shape)) self.register_buffer("bias_mask", torch.ones(self.module.bias.shape))
else: else:
self.register_buffer("bias_mask", None) self.register_buffer("bias_mask", None)
self.registered_buffers['weight_mask'] = self.weight_mask
self.registered_buffers['bias_mask'] = self.bias_mask self.registered_buffers['bias_mask'] = self.bias_mask
# register user specified buffer # register user specified buffer
for name in self.pruner.buffers: for name in self.pruner.buffers:
...@@ -297,7 +299,8 @@ class Pruner(Compressor): ...@@ -297,7 +299,8 @@ class Pruner(Compressor):
""" """
_logger.info("compressing module %s.", layer.name) _logger.info("compressing module %s.", layer.name)
wrapper = PrunerModuleWrapper(layer.module, layer.name, layer.type, config, self) wrapper = PrunerModuleWrapper(layer.module, layer.name, layer.type, config, self)
assert hasattr(layer.module, 'weight') assert hasattr(layer.module, 'weight'), "module %s does not have 'weight' attribute" % layer.name
# move newly registered buffers to the same device of weight
wrapper.to(layer.module.weight.device) wrapper.to(layer.module.weight.device)
return wrapper return wrapper
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment