Unverified commit d452a166 authored by QuanluZhang, committed by GitHub

update lottery ticket pruner based on refactored compression code (#1989)

parent 6b0ecee6
@@ -71,6 +71,8 @@ if __name__ == '__main__':
     pruner = LotteryTicketPruner(model, configure_list, optimizer)
     pruner.compress()
+    #model = nn.DataParallel(model)
     for i in pruner.get_prune_iterations():
         pruner.prune_iteration_start()
         loss = 0
...
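Note: the commented-out `nn.DataParallel` line records the intended ordering: wrap the model for multi-GPU only after `compress()` has installed the pruner's module wrappers. A minimal sketch of the surrounding loop, assuming hypothetical `train`/`test` helpers from the example script:

```python
# Hedged sketch; train/test stand in for the example's own training and
# evaluation functions and are not part of this diff.
pruner = LotteryTicketPruner(model, configure_list, optimizer)
pruner.compress()
# model = nn.DataParallel(model)  # if needed, wrap only after compress()
for i in pruner.get_prune_iterations():
    pruner.prune_iteration_start()  # compute and apply masks for this round
    for epoch in range(10):
        loss = train(model, train_loader, optimizer, criterion)
        accuracy = test(model, test_loader, criterion)
```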
@@ -69,7 +69,7 @@ if __name__ == '__main__':
     train_loader = torch.utils.data.DataLoader(traindataset, batch_size=60, shuffle=True, num_workers=10, drop_last=False)
     test_loader = torch.utils.data.DataLoader(testdataset, batch_size=60, shuffle=False, num_workers=10, drop_last=True)
-    device = torch.device("cuda: 0" if torch.cuda.is_available() else "cpu")
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     model = fc1()
     criterion = nn.CrossEntropyLoss()
...
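Note on the device-string fix: the documented forms are `"cpu"`, `"cuda"`, and `"cuda:<index>"` with no whitespace; `"cuda: 0"` is not a documented format and is rejected by recent PyTorch releases. A quick illustration:

```python
import torch

device = torch.device("cuda")     # current default CUDA device
device0 = torch.device("cuda:0")  # explicit device index; no whitespace allowed
# torch.device("cuda: 0") is not a valid device string and raises
# RuntimeError on recent PyTorch versions.
```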
@@ -41,6 +41,7 @@ class Compressor:
         self.modules_to_compress = None
         self.modules_wrapper = None
         self.buffers = {}
+        self.is_wrapped = False

     def detect_modules_to_compress(self):
         """
@@ -63,6 +64,7 @@ class Compressor:
         """
         for wrapper in reversed(self.get_modules_wrapper()):
             _setattr(self.bound_model, wrapper.name, wrapper)
+        self.is_wrapped = True

     def _unwrap_model(self):
         """
@@ -71,6 +73,7 @@ class Compressor:
         """
         for wrapper in self.get_modules_wrapper():
             _setattr(self.bound_model, wrapper.name, wrapper.module)
+        self.is_wrapped = False

     def compress(self):
         """
@@ -263,7 +266,7 @@ class Pruner(Compressor):
     def __init__(self, model, config_list):
         super().__init__(model, config_list)

-    def calc_mask(self, layer, config):
+    def calc_mask(self, layer, config, **kwargs):
         """
         Pruners should overload this method to provide mask for weight tensors.
         The mask must have the same shape and type comparing to the weight.
@@ -291,9 +294,12 @@ class Pruner(Compressor):
             the configuration for generating the mask
         """
         _logger.info("compressing module %s.", layer.name)
-        return PrunerModuleWrapper(layer.module, layer.name, layer.type, config, self)
+        wrapper = PrunerModuleWrapper(layer.module, layer.name, layer.type, config, self)
+        assert hasattr(layer.module, 'weight')
+        wrapper.to(layer.module.weight.device)
+        return wrapper
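Note: moving the freshly built wrapper onto the weight's device keeps the mask buffers co-located with the parameters they multiply; otherwise a model already on GPU would hit a CPU/CUDA mismatch on the first masked forward pass. A small self-contained illustration of the invariant (the `Linear` layer is hypothetical):

```python
import torch

# The wrapper's mask must live on the same device as the wrapped weight,
# which is what wrapper.to(layer.module.weight.device) guarantees above.
layer = torch.nn.Linear(8, 4).to("cuda" if torch.cuda.is_available() else "cpu")
weight_mask = torch.ones_like(layer.weight)   # ones_like inherits the device
assert weight_mask.device == layer.weight.device
masked_weight = layer.weight * weight_mask    # no cross-device error
```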
-    def export_model(self, model_path, mask_path=None, onnx_path=None, input_shape=None):
+    def export_model(self, model_path, mask_path=None, onnx_path=None, input_shape=None, device=None):
         """
         Export pruned model weights, masks and onnx model(optional)
@@ -307,6 +313,9 @@ class Pruner(Compressor):
             (optional) path to save onnx model
         input_shape : list or tuple
             input shape to onnx model
+        device : torch.device
+            device of the model, used to place the dummy input tensor for exporting onnx file.
+            the tensor is placed on cpu if ``device`` is None
         """
         # if self.detect_modules_to_compress() and not self.mask_dict:
         #     _logger.warning('You may not use self.mask_dict in base Pruner class to record masks')
@@ -335,12 +344,29 @@ class Pruner(Compressor):
         if onnx_path is not None:
             assert input_shape is not None, 'input_shape must be specified to export onnx model'
             # input info needed
+            if device is None:
+                device = torch.device('cpu')
             input_data = torch.Tensor(*input_shape)
-            torch.onnx.export(self.bound_model, input_data, onnx_path)
+            torch.onnx.export(self.bound_model, input_data.to(device), onnx_path)
             _logger.info('Model in onnx with input shape %s saved to %s', input_data.shape, onnx_path)

         self._wrap_model()
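Note: a hedged usage sketch of the extended signature; paths and shape below are placeholders. `device` only affects the dummy input used for onnx tracing, so it should match wherever the model currently lives:

```python
# Hypothetical call for a model that has been moved to GPU; without
# device=..., the dummy input stays on CPU and tracing would fail.
pruner.export_model(model_path='pruned.pth',
                    mask_path='mask.pth',
                    onnx_path='pruned.onnx',
                    input_shape=[1, 1, 28, 28],
                    device=torch.device('cuda'))
```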
+    def load_model_state_dict(self, model_state):
+        """
+        Load the state dict saved from unwrapped model.
+
+        Parameters
+        ----------
+        model_state : dict
+            state dict saved from unwrapped model
+        """
+        if self.is_wrapped:
+            self._unwrap_model()
+            self.bound_model.load_state_dict(model_state)
+            self._wrap_model()
+        else:
+            self.bound_model.load_state_dict(model_state)
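Note: the unwrap/load/rewrap sequence is needed because wrapping changes the state-dict keys: the original module becomes a `module` submodule of its wrapper and mask buffers are added. An illustrative key layout for a hypothetical layer named `fc`:

```python
# Illustrative, not exhaustive:
#   unwrapped model: {'fc.weight', 'fc.bias'}
#   wrapped model:   {'fc.module.weight', 'fc.module.bias', 'fc.weight_mask'}
# A checkpoint saved from the bare model therefore cannot be loaded while
# the wrappers are installed, hence:
pruner.load_model_state_dict(torch.load('initial_weights.pth'))  # placeholder path
```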

 class QuantizerModuleWrapper(torch.nn.Module):
     def __init__(self, module, module_name, module_type, config, quantizer):
...
@@ -290,38 +290,23 @@ class LotteryTicketPruner(Pruner):
         prune_iterations = config['prune_iterations']
         return prune_iterations

-    def _print_masks(self, print_mask=False):
-        torch.set_printoptions(threshold=1000)
-        for op_name in self.mask_dict.keys():
-            mask = self.mask_dict[op_name]
-            print('op name: ', op_name)
-            if print_mask:
-                print('mask: ', mask)
-            # calculate current sparsity
-            mask_num = mask['weight'].sum().item()
-            mask_size = mask['weight'].numel()
-            print('sparsity: ', 1 - mask_num / mask_size)
-        torch.set_printoptions(profile='default')
-
     def _calc_sparsity(self, sparsity):
         keep_ratio_once = (1 - sparsity) ** (1 / self.prune_iterations)
         curr_keep_ratio = keep_ratio_once ** self.curr_prune_iteration
         return max(1 - curr_keep_ratio, 0)
-    def _calc_mask(self, weight, sparsity, op_name):
+    def _calc_mask(self, weight, sparsity, curr_w_mask):
         if self.curr_prune_iteration == 0:
             mask = torch.ones(weight.shape).type_as(weight)
         else:
             curr_sparsity = self._calc_sparsity(sparsity)
-            assert self.mask_dict.get(op_name) is not None
-            curr_mask = self.mask_dict.get(op_name)
-            w_abs = weight.abs() * curr_mask['weight']
+            w_abs = weight.abs() * curr_w_mask
             k = int(w_abs.numel() * curr_sparsity)
             threshold = torch.topk(w_abs.view(-1), k, largest=False).values.max()
             mask = torch.gt(w_abs, threshold).type_as(weight)
         return {'weight': mask}
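Note: the schedule behind `_calc_sparsity` is geometric: to reach target sparsity `s` in `n` rounds, each round keeps a fraction `(1 - s) ** (1 / n)` of the remaining weights, so after round `i` the overall sparsity is `1 - (1 - s) ** (i / n)`. A worked example:

```python
# Target sparsity 0.8 over 5 iterations: each round keeps
# (1 - 0.8) ** (1 / 5) ≈ 0.7248 of what survived so far.
sparsity, prune_iterations = 0.8, 5
keep_once = (1 - sparsity) ** (1 / prune_iterations)
for i in range(1, prune_iterations + 1):
    print(i, round(1 - keep_once ** i, 3))
# prints: 1 0.275, 2 0.475, 3 0.619, 4 0.724, 5 0.8
```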
-    def calc_mask(self, layer, config):
+    def calc_mask(self, layer, config, **kwargs):
         """
         Generate mask for the given ``weight``.
@@ -331,15 +316,17 @@ class LotteryTicketPruner(Pruner):
             The layer to be pruned
         config : dict
             Pruning configurations for this weight
+        kwargs : dict
+            Auxiliary information

         Returns
         -------
         tensor
-            The mask for this weight
+            The mask for this weight; it is ``None`` because this pruner
+            calculates and assigns masks in ``prune_iteration_start``,
+            so there is nothing to do in this function.
         """
-        assert self.mask_dict.get(layer.name) is not None, 'Please call iteration_start before training'
-        mask = self.mask_dict[layer.name]
-        return mask
+        return None
     def get_prune_iterations(self):
         """
@@ -364,16 +351,26 @@ class LotteryTicketPruner(Pruner):
         self.curr_prune_iteration += 1
         assert self.curr_prune_iteration < self.prune_iterations + 1, 'Exceed the configured prune_iterations'

+        modules_wrapper = self.get_modules_wrapper()
         modules_to_compress = self.detect_modules_to_compress()
         for layer, config in modules_to_compress:
+            module_wrapper = None
+            for wrapper in modules_wrapper:
+                if wrapper.name == layer.name:
+                    module_wrapper = wrapper
+                    break
+            assert module_wrapper is not None
             sparsity = config.get('sparsity')
-            mask = self._calc_mask(layer.module.weight.data, sparsity, layer.name)
-            self.mask_dict.update({layer.name: mask})
-        self._print_masks()
+            mask = self._calc_mask(layer.module.weight.data, sparsity, module_wrapper.weight_mask)
+            # TODO: directly use weight_mask is not good
+            module_wrapper.weight_mask.copy_(mask['weight'])
+            # there is no mask for bias

         # reinit weights back to original after new masks are generated
         if self.reset_weights:
-            self._model.load_state_dict(self._model_state)
+            # should use this member function to reset model weights
+            self.load_model_state_dict(self._model_state)
             self._optimizer.load_state_dict(self._optimizer_state)
             if self._lr_scheduler is not None:
                 self._lr_scheduler.load_state_dict(self._scheduler_state)
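Note: with masks now written directly into each wrapper's `weight_mask` buffer instead of a separate `mask_dict`, the per-layer update above reduces to the following, shown as a standalone sketch with a toy tensor:

```python
import torch

# Toy version of the in-place mask update done per layer above.
weight_mask = torch.ones(4, 4)           # buffer held by the module wrapper
weight = torch.randn(4, 4)
curr_sparsity = 0.275                    # round-1 value of the schedule
w_abs = weight.abs() * weight_mask       # already-pruned entries stay zero
k = int(w_abs.numel() * curr_sparsity)   # number of weights to drop (floors to 4)
threshold = torch.topk(w_abs.view(-1), k, largest=False).values.max()
weight_mask.copy_(torch.gt(w_abs, threshold).type_as(weight))
print(1 - weight_mask.sum().item() / weight_mask.numel())   # 0.25, since k floored
```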