Unverified commit 6e643b00, authored by J-shang, committed by GitHub

fix pruning examples & optimize pruner memory usage (#4412)

parent f46f0cf4
@@ -72,9 +72,9 @@ def evaluator(model):
     print('Accuracy: {}%\n'.format(acc))
     return acc
-def optimizer_scheduler_generator(model, _lr=0.1, _momentum=0.9, _weight_decay=5e-4):
+def optimizer_scheduler_generator(model, _lr=0.1, _momentum=0.9, _weight_decay=5e-4, total_epoch=160):
     optimizer = torch.optim.SGD(model.parameters(), lr=_lr, momentum=_momentum, weight_decay=_weight_decay)
-    scheduler = MultiStepLR(optimizer, milestones=[int(args.pretrain_epochs * 0.5), int(args.pretrain_epochs * 0.75)], gamma=0.1)
+    scheduler = MultiStepLR(optimizer, milestones=[int(total_epoch * 0.5), int(total_epoch * 0.75)], gamma=0.1)
     return optimizer, scheduler
 if __name__ == '__main__':
@@ -90,7 +90,7 @@ if __name__ == '__main__':
     print('\n' + '=' * 50 + ' START TO TRAIN THE MODEL ' + '=' * 50)
     model = VGG().to(device)
-    optimizer, scheduler = optimizer_scheduler_generator(model)
+    optimizer, scheduler = optimizer_scheduler_generator(model, total_epoch=args.pretrain_epochs)
     criterion = torch.nn.CrossEntropyLoss()
     pre_best_acc = 0.0
     best_state_dict = None
@@ -117,9 +117,9 @@ if __name__ == '__main__':
     # make sure you have used nni.algorithms.compression.v2.pytorch.utils.trace_parameters to wrap the optimizer class before initialize
     traced_optimizer = trace_parameters(torch.optim.SGD)(model.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)
     if 'apoz' in args.pruner:
-        pruner = ActivationAPoZRankPruner(model, config_list, trainer, traced_optimizer, criterion, training_batches=1)
+        pruner = ActivationAPoZRankPruner(model, config_list, trainer, traced_optimizer, criterion, training_batches=20)
     else:
-        pruner = ActivationMeanRankPruner(model, config_list, trainer, traced_optimizer, criterion, training_batches=1)
+        pruner = ActivationMeanRankPruner(model, config_list, trainer, traced_optimizer, criterion, training_batches=20)
     _, masks = pruner.compress()
     pruner.show_pruned_weights()
     pruner._unwrap_model()
@@ -129,7 +129,7 @@ if __name__ == '__main__':
     # Optimizer used in the pruner might be patched, so recommend to new an optimizer for fine-tuning stage.
     print('\n' + '=' * 50 + ' START TO FINE TUNE THE MODEL ' + '=' * 50)
-    optimizer, scheduler = optimizer_scheduler_generator(model, _lr=0.01)
+    optimizer, scheduler = optimizer_scheduler_generator(model, _lr=0.01, total_epoch=args.fine_tune_epochs)
     best_acc = 0.0
     g_epoch = 0
...
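Below is a minimal sketch (not part of the commit) of how the reworked optimizer_scheduler_generator is meant to be wired in the two training stages; the concrete epoch counts 160 and 20 are assumptions standing in for args.pretrain_epochs and args.fine_tune_epochs. Previously both stages built milestones from args.pretrain_epochs, so a fine-tuning run shorter than half of the pre-training budget never reached a learning-rate drop.

import torch
from torch.optim.lr_scheduler import MultiStepLR

def optimizer_scheduler_generator(model, _lr=0.1, _momentum=0.9, _weight_decay=5e-4, total_epoch=160):
    optimizer = torch.optim.SGD(model.parameters(), lr=_lr, momentum=_momentum, weight_decay=_weight_decay)
    # milestones now follow the epoch budget of the current stage
    scheduler = MultiStepLR(optimizer, milestones=[int(total_epoch * 0.5), int(total_epoch * 0.75)], gamma=0.1)
    return optimizer, scheduler

model = torch.nn.Linear(4, 4)  # stand-in for the VGG model used in the example
# pre-training: with total_epoch=160 the LR drops at epochs 80 and 120
optimizer, scheduler = optimizer_scheduler_generator(model, total_epoch=160)
# fine-tuning: with total_epoch=20 the LR drops at epochs 10 and 15
optimizer, scheduler = optimizer_scheduler_generator(model, _lr=0.01, total_epoch=20)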
@@ -71,9 +71,9 @@ def evaluator(model):
     print('Accuracy: {}%\n'.format(acc))
     return acc
-def optimizer_scheduler_generator(model, _lr=0.1, _momentum=0.9, _weight_decay=5e-4):
+def optimizer_scheduler_generator(model, _lr=0.1, _momentum=0.9, _weight_decay=5e-4, total_epoch=160):
     optimizer = torch.optim.SGD(model.parameters(), lr=_lr, momentum=_momentum, weight_decay=_weight_decay)
-    scheduler = MultiStepLR(optimizer, milestones=[int(args.pretrain_epochs * 0.5), int(args.pretrain_epochs * 0.75)], gamma=0.1)
+    scheduler = MultiStepLR(optimizer, milestones=[int(total_epoch * 0.5), int(total_epoch * 0.75)], gamma=0.1)
     return optimizer, scheduler
 if __name__ == '__main__':
@@ -86,7 +86,7 @@ if __name__ == '__main__':
     print('\n' + '=' * 50 + ' START TO TRAIN THE MODEL ' + '=' * 50)
     model = VGG().to(device)
-    optimizer, scheduler = optimizer_scheduler_generator(model)
+    optimizer, scheduler = optimizer_scheduler_generator(model, total_epoch=args.pretrain_epochs)
     criterion = torch.nn.CrossEntropyLoss()
     pre_best_acc = 0.0
     best_state_dict = None
@@ -125,7 +125,7 @@ if __name__ == '__main__':
     # Optimizer used in the pruner might be patched, so recommend to new an optimizer for fine-tuning stage.
     print('\n' + '=' * 50 + ' START TO FINE TUNE THE MODEL ' + '=' * 50)
-    optimizer, scheduler = optimizer_scheduler_generator(model, _lr=0.01)
+    optimizer, scheduler = optimizer_scheduler_generator(model, _lr=0.01, total_epoch=args.fine_tune_epochs)
     best_acc = 0.0
     g_epoch = 0
...
@@ -71,9 +71,9 @@ def evaluator(model):
     print('Accuracy: {}%\n'.format(acc))
     return acc
-def optimizer_scheduler_generator(model, _lr=0.1, _momentum=0.9, _weight_decay=5e-4):
+def optimizer_scheduler_generator(model, _lr=0.1, _momentum=0.9, _weight_decay=5e-4, total_epoch=160):
     optimizer = torch.optim.SGD(model.parameters(), lr=_lr, momentum=_momentum, weight_decay=_weight_decay)
-    scheduler = MultiStepLR(optimizer, milestones=[int(args.pretrain_epochs * 0.5), int(args.pretrain_epochs * 0.75)], gamma=0.1)
+    scheduler = MultiStepLR(optimizer, milestones=[int(total_epoch * 0.5), int(total_epoch * 0.75)], gamma=0.1)
     return optimizer, scheduler
 if __name__ == '__main__':
@@ -86,7 +86,7 @@ if __name__ == '__main__':
     print('\n' + '=' * 50 + ' START TO TRAIN THE MODEL ' + '=' * 50)
     model = VGG().to(device)
-    optimizer, scheduler = optimizer_scheduler_generator(model)
+    optimizer, scheduler = optimizer_scheduler_generator(model, total_epoch=args.pretrain_epochs)
     criterion = torch.nn.CrossEntropyLoss()
     pre_best_acc = 0.0
     best_state_dict = None
@@ -119,7 +119,7 @@ if __name__ == '__main__':
     # Optimizer used in the pruner might be patched, so recommend to new an optimizer for fine-tuning stage.
     print('\n' + '=' * 50 + ' START TO FINE TUNE THE MODEL ' + '=' * 50)
-    optimizer, scheduler = optimizer_scheduler_generator(model, _lr=0.01)
+    optimizer, scheduler = optimizer_scheduler_generator(model, _lr=0.01, total_epoch=args.fine_tune_epochs)
     best_acc = 0.0
     for i in range(args.fine_tune_epochs):
...
@@ -70,9 +70,9 @@ def evaluator(model):
     print('Accuracy: {}%\n'.format(acc))
     return acc
-def optimizer_scheduler_generator(model, _lr=0.1, _momentum=0.9, _weight_decay=5e-4):
+def optimizer_scheduler_generator(model, _lr=0.1, _momentum=0.9, _weight_decay=5e-4, total_epoch=160):
     optimizer = torch.optim.SGD(model.parameters(), lr=_lr, momentum=_momentum, weight_decay=_weight_decay)
-    scheduler = MultiStepLR(optimizer, milestones=[int(args.pretrain_epochs * 0.5), int(args.pretrain_epochs * 0.75)], gamma=0.1)
+    scheduler = MultiStepLR(optimizer, milestones=[int(total_epoch * 0.5), int(total_epoch * 0.75)], gamma=0.1)
     return optimizer, scheduler
 if __name__ == '__main__':
@@ -85,7 +85,7 @@ if __name__ == '__main__':
     print('\n' + '=' * 50 + ' START TO TRAIN THE MODEL ' + '=' * 50)
     model = VGG().to(device)
-    optimizer, scheduler = optimizer_scheduler_generator(model)
+    optimizer, scheduler = optimizer_scheduler_generator(model, total_epoch=args.pretrain_epochs)
     criterion = torch.nn.CrossEntropyLoss()
     pre_best_acc = 0.0
     best_state_dict = None
@@ -117,7 +117,7 @@ if __name__ == '__main__':
     # Optimizer used in the pruner might be patched, so recommend to new an optimizer for fine-tuning stage.
     print('\n' + '=' * 50 + ' START TO FINE TUNE THE MODEL ' + '=' * 50)
-    optimizer, scheduler = optimizer_scheduler_generator(model, _lr=0.01)
+    optimizer, scheduler = optimizer_scheduler_generator(model, _lr=0.01, total_epoch=args.fine_tune_epochs)
     best_acc = 0.0
     g_epoch = 0
...
@@ -71,9 +71,9 @@ def evaluator(model):
     print('Accuracy: {}%\n'.format(acc))
     return acc
-def optimizer_scheduler_generator(model, _lr=0.1, _momentum=0.9, _weight_decay=5e-4):
+def optimizer_scheduler_generator(model, _lr=0.1, _momentum=0.9, _weight_decay=5e-4, total_epoch=160):
     optimizer = torch.optim.SGD(model.parameters(), lr=_lr, momentum=_momentum, weight_decay=_weight_decay)
-    scheduler = MultiStepLR(optimizer, milestones=[int(args.pretrain_epochs * 0.5), int(args.pretrain_epochs * 0.75)], gamma=0.1)
+    scheduler = MultiStepLR(optimizer, milestones=[int(total_epoch * 0.5), int(total_epoch * 0.75)], gamma=0.1)
     return optimizer, scheduler
 if __name__ == '__main__':
@@ -89,7 +89,7 @@ if __name__ == '__main__':
     print('\n' + '=' * 50 + ' START TO TRAIN THE MODEL ' + '=' * 50)
     model = VGG().to(device)
-    optimizer, scheduler = optimizer_scheduler_generator(model)
+    optimizer, scheduler = optimizer_scheduler_generator(model, total_epoch=args.pretrain_epochs)
     criterion = torch.nn.CrossEntropyLoss()
     pre_best_acc = 0.0
     best_state_dict = None
@@ -125,7 +125,7 @@ if __name__ == '__main__':
     # Optimizer used in the pruner might be patched, so recommend to new an optimizer for fine-tuning stage.
     print('\n' + '=' * 50 + ' START TO FINE TUNE THE MODEL ' + '=' * 50)
-    optimizer, scheduler = optimizer_scheduler_generator(model, _lr=0.01)
+    optimizer, scheduler = optimizer_scheduler_generator(model, _lr=0.01, total_epoch=args.fine_tune_epochs)
     best_acc = 0.0
     for i in range(args.fine_tune_epochs):
...
@@ -72,9 +72,9 @@ def evaluator(model):
     print('Accuracy: {}%\n'.format(acc))
     return acc
-def optimizer_scheduler_generator(model, _lr=0.1, _momentum=0.9, _weight_decay=5e-4):
+def optimizer_scheduler_generator(model, _lr=0.1, _momentum=0.9, _weight_decay=5e-4, total_epoch=160):
     optimizer = torch.optim.SGD(model.parameters(), lr=_lr, momentum=_momentum, weight_decay=_weight_decay)
-    scheduler = MultiStepLR(optimizer, milestones=[int(args.pretrain_epochs * 0.5), int(args.pretrain_epochs * 0.75)], gamma=0.1)
+    scheduler = MultiStepLR(optimizer, milestones=[int(total_epoch * 0.5), int(total_epoch * 0.75)], gamma=0.1)
     return optimizer, scheduler
 if __name__ == '__main__':
@@ -87,7 +87,7 @@ if __name__ == '__main__':
     print('\n' + '=' * 50 + ' START TO TRAIN THE MODEL ' + '=' * 50)
     model = VGG().to(device)
-    optimizer, scheduler = optimizer_scheduler_generator(model)
+    optimizer, scheduler = optimizer_scheduler_generator(model, total_epoch=args.pretrain_epochs)
     criterion = torch.nn.CrossEntropyLoss()
     pre_best_acc = 0.0
     best_state_dict = None
@@ -124,7 +124,7 @@ if __name__ == '__main__':
     # Optimizer used in the pruner might be patched, so recommend to new an optimizer for fine-tuning stage.
     print('\n' + '=' * 50 + ' START TO FINE TUNE THE MODEL ' + '=' * 50)
-    optimizer, scheduler = optimizer_scheduler_generator(model, _lr=0.01)
+    optimizer, scheduler = optimizer_scheduler_generator(model, _lr=0.01, total_epoch=args.fine_tune_epochs)
     best_acc = 0.0
     g_epoch = 0
     for i in range(args.fine_tune_epochs):
...
@@ -72,9 +72,9 @@ def evaluator(model):
     print('Accuracy: {}%\n'.format(acc))
     return acc
-def optimizer_scheduler_generator(model, _lr=0.1, _momentum=0.9, _weight_decay=5e-4):
+def optimizer_scheduler_generator(model, _lr=0.1, _momentum=0.9, _weight_decay=5e-4, total_epoch=160):
     optimizer = torch.optim.SGD(model.parameters(), lr=_lr, momentum=_momentum, weight_decay=_weight_decay)
-    scheduler = MultiStepLR(optimizer, milestones=[int(args.pretrain_epochs * 0.5), int(args.pretrain_epochs * 0.75)], gamma=0.1)
+    scheduler = MultiStepLR(optimizer, milestones=[int(total_epoch * 0.5), int(total_epoch * 0.75)], gamma=0.1)
     return optimizer, scheduler
 if __name__ == '__main__':
@@ -87,7 +87,7 @@ if __name__ == '__main__':
     print('\n' + '=' * 50 + ' START TO TRAIN THE MODEL ' + '=' * 50)
     model = VGG().to(device)
-    optimizer, scheduler = optimizer_scheduler_generator(model)
+    optimizer, scheduler = optimizer_scheduler_generator(model, total_epoch=args.pretrain_epochs)
     criterion = torch.nn.CrossEntropyLoss()
     pre_best_acc = 0.0
     best_state_dict = None
@@ -113,7 +113,7 @@ if __name__ == '__main__':
     # make sure you have used nni.algorithms.compression.v2.pytorch.utils.trace_parameters to wrap the optimizer class before initialize
     traced_optimizer = trace_parameters(torch.optim.SGD)(model.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)
-    pruner = TaylorFOWeightPruner(model, config_list, trainer, traced_optimizer, criterion, training_batches=1)
+    pruner = TaylorFOWeightPruner(model, config_list, trainer, traced_optimizer, criterion, training_batches=20)
     _, masks = pruner.compress()
     pruner.show_pruned_weights()
     pruner._unwrap_model()
@@ -123,7 +123,7 @@ if __name__ == '__main__':
     # Optimizer used in the pruner might be patched, so recommend to new an optimizer for fine-tuning stage.
     print('\n' + '=' * 50 + ' START TO FINE TUNE THE MODEL ' + '=' * 50)
-    optimizer, scheduler = optimizer_scheduler_generator(model, _lr=0.01)
+    optimizer, scheduler = optimizer_scheduler_generator(model, _lr=0.01, total_epoch=args.fine_tune_epochs)
     best_acc = 0.0
     g_epoch = 0
...
@@ -524,11 +524,23 @@ class ActivationPruner(BasicPruner):
             raise 'Unsupported activatoin {}'.format(activation)
     def _collector(self, buffer: List) -> Callable[[Module, Tensor, Tensor], None]:
+        assert len(buffer) == 0, 'Buffer pass to activation pruner collector is not empty.'
+        # The length of the buffer used in this pruner will always be 2.
+        # buffer[0] is the number of how many batches are counted in buffer[1].
+        # buffer[1] is a tensor and the size of buffer[1] is same as the activation.
+        buffer.append(0)
         def collect_activation(_module: Module, _input: Tensor, output: Tensor):
-            if len(buffer) < self.training_batches:
-                buffer.append(self._activation(output.detach()))
+            if len(buffer) == 1:
+                buffer.append(torch.zeros_like(output))
+            if buffer[0] < self.training_batches:
+                buffer[1] += self._activation_trans(output)
+                buffer[0] += 1
         return collect_activation
+    def _activation_trans(self, output: Tensor) -> Tensor:
+        raise NotImplementedError()
     def reset_tools(self):
         collector_info = HookCollectorInfo([layer_info for layer_info, _ in self._detect_modules_to_compress()], 'forward', self._collector)
         if self.data_collector is None:
@@ -551,11 +563,19 @@ class ActivationPruner(BasicPruner):
 class ActivationAPoZRankPruner(ActivationPruner):
+    def _activation_trans(self, output: Tensor) -> Tensor:
+        # return a matrix that the position of zero in `output` is one, others is zero.
+        return torch.eq(self._activation(output.detach()), torch.zeros_like(output)).type_as(output)
     def _get_metrics_calculator(self) -> MetricsCalculator:
         return APoZRankMetricsCalculator(dim=1)
 class ActivationMeanRankPruner(ActivationPruner):
+    def _activation_trans(self, output: Tensor) -> Tensor:
+        # return the activation of `output` directly.
+        return self._activation(output.detach())
     def _get_metrics_calculator(self) -> MetricsCalculator:
         return MeanRankMetricsCalculator(dim=1)
@@ -647,9 +667,14 @@ class TaylorFOWeightPruner(BasicPruner):
         schema.validate(config_list)
     def _collector(self, buffer: List, weight_tensor: Tensor) -> Callable[[Tensor], None]:
+        assert len(buffer) == 0, 'Buffer pass to taylor pruner collector is not empty.'
+        buffer.append(0)
+        buffer.append(torch.zeros_like(weight_tensor))
         def collect_taylor(grad: Tensor):
-            if len(buffer) < self.training_batches:
-                buffer.append(self._calculate_taylor_expansion(weight_tensor, grad))
+            if buffer[0] < self.training_batches:
+                buffer[1] += self._calculate_taylor_expansion(weight_tensor, grad)
+                buffer[0] += 1
         return collect_taylor
     def _calculate_taylor_expansion(self, weight_tensor: Tensor, grad: Tensor) -> Tensor:
...
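For readers unfamiliar with the collector hooks, here is a minimal standalone sketch (plain PyTorch, not the NNI classes themselves) of the two-element buffer layout used above: buffer[0] counts the accumulated batches and buffer[1] keeps a running sum, so memory stays constant instead of growing by one stored tensor per batch. The layer, batch size, and relu transform are illustrative placeholders.

import torch
import torch.nn as nn

training_batches = 20
buffer = [0]  # will become [batch_count, running_sum]

def collect_activation(_module, _input, output):
    if len(buffer) == 1:
        buffer.append(torch.zeros_like(output))
    if buffer[0] < training_batches:
        # for the APoZ variant this would be a 0/1 "is zero" map instead of the raw activation
        buffer[1] += torch.relu(output.detach())
        buffer[0] += 1

layer = nn.Linear(8, 4)
handle = layer.register_forward_hook(collect_activation)
for _ in range(3):
    layer(torch.randn(16, 8))
handle.remove()
print(buffer[0], buffer[1].shape)  # 3 torch.Size([16, 4])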
@@ -75,19 +75,20 @@ class NormMetricsCalculator(MetricsCalculator):
 class MultiDataNormMetricsCalculator(NormMetricsCalculator):
     """
-    Sum each list of tensor in data at first, then calculate the specify norm for each sumed tensor.
-    TaylorFO pruner use this to calculate metric.
+    The data value format is a two-element list [batch_number, cumulative_data].
+    Directly use the cumulative_data as new_data to calculate norm metric.
+    TaylorFO pruner uses this to calculate metric.
     """
     def calculate_metrics(self, data: Dict[str, List[Tensor]]) -> Dict[str, Tensor]:
-        new_data = {name: sum(list_tensor) for name, list_tensor in data.items()}
+        new_data = {name: buffer[1] for name, buffer in data.items()}
         return super().calculate_metrics(new_data)
 class DistMetricsCalculator(MetricsCalculator):
     """
     Calculate the sum of specify distance for each element with all other elements in specify `dim` in each tensor in data.
-    FPGM pruner use this to calculate metric.
+    FPGM pruner uses this to calculate metric.
     """
     def __init__(self, p: float, dim: Union[int, List[int]]):
@@ -153,26 +154,23 @@ class DistMetricsCalculator(MetricsCalculator):
 class APoZRankMetricsCalculator(MetricsCalculator):
     """
-    This metric counts the zero number at the same position in the tensor list in data,
-    then sum the zero number on `dim` and calculate the non-zero rate.
+    The data value format is a two-element list [batch_number, batch_wise_zeros_count_sum].
+    This metric sum the zero number on `dim` then devide the (batch_number * across_dim_size) to calculate the non-zero rate.
     Note that the metric we return is (1 - apoz), because we assume a higher metric value has higher importance.
-    APoZRank pruner use this to calculate metric.
+    APoZRank pruner uses this to calculate metric.
     """
-    def calculate_metrics(self, data: Dict[str, List[Tensor]]) -> Dict[str, Tensor]:
+    def calculate_metrics(self, data: Dict[str, List]) -> Dict[str, Tensor]:
         metrics = {}
-        for name, tensor_list in data.items():
-            # NOTE: dim=0 means the batch dim is 0
-            activations = torch.cat(tensor_list, dim=0)
-            _eq_zero = torch.eq(activations, torch.zeros_like(activations))
-            keeped_dim = list(range(len(_eq_zero.size()))) if self.dim is None else self.dim
-            across_dim = list(range(len(_eq_zero.size())))
+        for name, (num, zero_counts) in data.items():
+            keeped_dim = list(range(len(zero_counts.size()))) if self.dim is None else self.dim
+            across_dim = list(range(len(zero_counts.size())))
             [across_dim.pop(i) for i in reversed(keeped_dim)]
-            # The element number on each [keeped_dim + 1] in _eq_zero
-            total_size = 1
-            for dim, dim_size in enumerate(_eq_zero.size()):
+            # The element number on each keeped_dim in zero_counts
+            total_size = num
+            for dim, dim_size in enumerate(zero_counts.size()):
                 if dim not in keeped_dim:
                     total_size *= dim_size
-            _apoz = torch.sum(_eq_zero, dim=across_dim).type_as(activations) / total_size
+            _apoz = torch.sum(zero_counts, dim=across_dim).type_as(zero_counts) / total_size
             # NOTE: the metric is (1 - apoz) because we assume the smaller metric value is more needed to be pruned.
             metrics[name] = torch.ones_like(_apoz) - _apoz
         return metrics
@@ -180,16 +178,15 @@ class APoZRankMetricsCalculator(MetricsCalculator):
 class MeanRankMetricsCalculator(MetricsCalculator):
     """
-    This metric simply concat the list of tensor on dim 0, and average on `dim`.
-    MeanRank pruner use this to calculate metric.
+    The data value format is a two-element list [batch_number, batch_wise_activation_sum].
+    This metric simply calculate the average on `self.dim`, then divide by the batch_number.
+    MeanRank pruner uses this to calculate metric.
     """
     def calculate_metrics(self, data: Dict[str, List[Tensor]]) -> Dict[str, Tensor]:
         metrics = {}
-        for name, tensor_list in data.items():
-            # NOTE: dim=0 means the batch dim is 0
-            activations = torch.cat(tensor_list, dim=0)
-            keeped_dim = list(range(len(activations.size()))) if self.dim is None else self.dim
-            across_dim = list(range(len(activations.size())))
+        for name, (num, activation_sum) in data.items():
+            keeped_dim = list(range(len(activation_sum.size()))) if self.dim is None else self.dim
+            across_dim = list(range(len(activation_sum.size())))
             [across_dim.pop(i) for i in reversed(keeped_dim)]
-            metrics[name] = torch.mean(activations, across_dim)
+            metrics[name] = torch.mean(activation_sum, across_dim) / num
         return metrics
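To make the new data layout concrete, here is a standalone recomputation (plain PyTorch, mirroring but not importing APoZRankMetricsCalculator) of the APoZ metric from a [batch_number, zero_count_sum] pair; the numbers reuse the '2' entry of the updated unit test below, so the expected output is tensor([1.0000, 1.0000, 0.7500]).

import torch

num = 2                                                   # batches accumulated in buffer[0]
zero_counts = torch.tensor([[0., 0., 1.], [0., 0., 0.]])  # buffer[1]: summed "is zero" maps, shape (batch, channel)

# dim=1 is kept, so the zero counts are summed over dim 0 and divided by
# batch_number * (product of the non-kept dim sizes) = 2 * 2
total_size = num * zero_counts.size(0)
apoz = zero_counts.sum(dim=0) / total_size
metric = 1 - apoz
print(metric)  # tensor([1.0000, 1.0000, 0.7500])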
@@ -139,12 +139,12 @@ class PruningToolsTestCase(unittest.TestCase):
         # Test MultiDataNormMetricsCalculator
         metrics_calculator = MultiDataNormMetricsCalculator(dim=0, p=1)
         data = {
-            '1': [torch.ones(3, 3, 3), torch.ones(3, 3, 3) * 2],
-            '2': [torch.ones(4, 4), torch.ones(4, 4) * 2]
+            '1': [2, torch.ones(3, 3, 3) * 2],
+            '2': [2, torch.ones(4, 4) * 2]
         }
         result = {
-            '1': torch.ones(3) * 27,
-            '2': torch.ones(4) * 12
+            '1': torch.ones(3) * 18,
+            '2': torch.ones(4) * 8
         }
         metrics = metrics_calculator.calculate_metrics(data)
         assert all(torch.equal(result[k], v) for k, v in metrics.items())
@@ -152,12 +152,12 @@ class PruningToolsTestCase(unittest.TestCase):
         # Test APoZRankMetricsCalculator
         metrics_calculator = APoZRankMetricsCalculator(dim=1)
         data = {
-            '1': [torch.tensor([[1, 0], [0, 1]], dtype=torch.float32), torch.tensor([[0, 1], [1, 0]], dtype=torch.float32)],
-            '2': [torch.tensor([[1, 0, 1], [0, 1, 0]], dtype=torch.float32), torch.tensor([[0, 0, 1], [0, 0, 0]], dtype=torch.float32)]
+            '1': [2, torch.tensor([[1, 1], [1, 1]], dtype=torch.float32)],
+            '2': [2, torch.tensor([[0, 0, 1], [0, 0, 0]], dtype=torch.float32)]
         }
         result = {
             '1': torch.tensor([0.5, 0.5], dtype=torch.float32),
-            '2': torch.tensor([0.25, 0.25, 0.5], dtype=torch.float32)
+            '2': torch.tensor([1, 1, 0.75], dtype=torch.float32)
         }
         metrics = metrics_calculator.calculate_metrics(data)
         assert all(torch.equal(result[k], v) for k, v in metrics.items())
@@ -165,12 +165,12 @@ class PruningToolsTestCase(unittest.TestCase):
         # Test MeanRankMetricsCalculator
         metrics_calculator = MeanRankMetricsCalculator(dim=1)
         data = {
-            '1': [torch.tensor([[1, 0], [0, 1]], dtype=torch.float32), torch.tensor([[0, 1], [1, 0]], dtype=torch.float32)],
-            '2': [torch.tensor([[1, 0, 1], [0, 1, 0]], dtype=torch.float32), torch.tensor([[0, 0, 1], [0, 0, 0]], dtype=torch.float32)]
+            '1': [2, torch.tensor([[0, 1], [1, 0]], dtype=torch.float32)],
+            '2': [2, torch.tensor([[0, 0, 1], [0, 0, 0]], dtype=torch.float32)]
         }
         result = {
-            '1': torch.tensor([0.5, 0.5], dtype=torch.float32),
-            '2': torch.tensor([0.25, 0.25, 0.5], dtype=torch.float32)
+            '1': torch.tensor([0.25, 0.25], dtype=torch.float32),
+            '2': torch.tensor([0, 0, 0.25], dtype=torch.float32)
         }
         metrics = metrics_calculator.calculate_metrics(data)
         assert all(torch.equal(result[k], v) for k, v in metrics.items())
...
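Similarly, a quick standalone check (plain PyTorch, mirroring MeanRankMetricsCalculator) of the MeanRank metric from a [batch_number, activation_sum] pair, reusing the '1' entry of the MeanRank test above; the expected output is tensor([0.2500, 0.2500]).

import torch

num = 2                                              # batches accumulated in buffer[0]
activation_sum = torch.tensor([[0., 1.], [1., 0.]])  # buffer[1]: summed activations, shape (batch, channel)

# dim=1 is kept, so average over dim 0 and then divide by the batch count
metric = activation_sum.mean(dim=0) / num
print(metric)  # tensor([0.2500, 0.2500])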