Unverified Commit bbcb1677 authored by Guoxin's avatar Guoxin Committed by GitHub
Browse files

fix SimulatedAnnealingPruner export mask issue (#2736)

parent 143c6615
......@@ -243,13 +243,11 @@ class SimulatedAnnealingPruner(Pruner):
_logger.info('current perturation magnitude:%s', magnitude)
while True:
perturbation = np.random.uniform(-magnitude,
magnitude, len(self.get_modules_wrapper()))
perturbation = np.random.uniform(-magnitude, magnitude, len(self.get_modules_wrapper()))
sparsities = np.clip(0, self._sparsities + perturbation, None)
_logger.debug("sparsities before rescalling:%s", sparsities)
sparsities = self._rescale_sparsities(
sparsities, target_sparsity=self._sparsity)
sparsities = self._rescale_sparsities(sparsities, target_sparsity=self._sparsity)
_logger.debug("sparsities after rescalling:%s", sparsities)
if sparsities is not None and sparsities[0] >= 0 and sparsities[-1] < 1:
......@@ -312,6 +310,8 @@ class SimulatedAnnealingPruner(Pruner):
# save the overall best masked model
self.bound_model = model_masked
# the ops with sparsity 0 are not included in this modules_wrapper
modules_wrapper_final = pruner.get_modules_wrapper()
break
# if not, accept with probability e^(-deltaE/current_temperature)
else:
......@@ -356,4 +356,8 @@ class SimulatedAnnealingPruner(Pruner):
if return_config_list:
return self._best_config_list
# This should be done only at the final stage,
# because the modules_wrapper with all the ops are used during the annealing process
self.modules_wrapper = modules_wrapper_final
return self.bound_model
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment