Unverified Commit 41faab39 authored by chicm-ms's avatar chicm-ms Committed by GitHub
Browse files

Fix lottery ticket (#2286)

parent 065d788c
...@@ -329,10 +329,6 @@ class LotteryTicketPruner(Pruner): ...@@ -329,10 +329,6 @@ class LotteryTicketPruner(Pruner):
reset_weights : bool reset_weights : bool
Whether reset weights and optimizer at the beginning of each round. Whether reset weights and optimizer at the beginning of each round.
""" """
super().__init__(model, config_list, optimizer)
self.curr_prune_iteration = None
self.prune_iterations = config_list[0]['prune_iterations']
# save init weights and optimizer # save init weights and optimizer
self.reset_weights = reset_weights self.reset_weights = reset_weights
if self.reset_weights: if self.reset_weights:
...@@ -344,6 +340,10 @@ class LotteryTicketPruner(Pruner): ...@@ -344,6 +340,10 @@ class LotteryTicketPruner(Pruner):
if lr_scheduler is not None: if lr_scheduler is not None:
self._scheduler_state = copy.deepcopy(lr_scheduler.state_dict()) self._scheduler_state = copy.deepcopy(lr_scheduler.state_dict())
super().__init__(model, config_list, optimizer)
self.curr_prune_iteration = None
self.prune_iterations = config_list[0]['prune_iterations']
def validate_config(self, model, config_list): def validate_config(self, model, config_list):
""" """
Parameters Parameters
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment