"ts/webui/src/static/style/App.scss" did not exist on "1bfc7acf3ed4f910b0db4e436a216bc6663545aa"
Unverified Commit 2772751d authored by J-shang's avatar J-shang Committed by GitHub
Browse files

fix SA task generator bug (#4457)

parent 31f11f51
...@@ -217,8 +217,8 @@ class SimulatedAnnealingTaskGenerator(TaskGenerator): ...@@ -217,8 +217,8 @@ class SimulatedAnnealingTaskGenerator(TaskGenerator):
self._temp_config_list = [] self._temp_config_list = []
self._temp_sparsity_list = [] self._temp_sparsity_list = []
for config in self.target_sparsity_list: for config in self.target_sparsity_list:
sparsity_config, sparsity = self._init_config_sparsity(config) sparsity_config_list, sparsity = self._init_config_sparsity(config)
self._temp_config_list.extend(sparsity_config) self._temp_config_list.extend(sparsity_config_list)
self._temp_sparsity_list.append(sparsity) self._temp_sparsity_list.append(sparsity)
def _init_config_sparsity(self, config: Dict) -> Tuple[List[Dict], List]: def _init_config_sparsity(self, config: Dict) -> Tuple[List[Dict], List]:
...@@ -227,7 +227,10 @@ class SimulatedAnnealingTaskGenerator(TaskGenerator): ...@@ -227,7 +227,10 @@ class SimulatedAnnealingTaskGenerator(TaskGenerator):
op_names = config['op_names'] op_names = config['op_names']
if target_sparsity == 0: if target_sparsity == 0:
return [], [] sparsity_config_list = [deepcopy(config) for i in range(len(op_names))]
for sparsity_config, op_name in zip(sparsity_config_list, op_names):
sparsity_config.update({'total_sparsity': 0, 'op_names': [op_name]})
return sparsity_config_list, []
low_limit = 0 low_limit = 0
while True: while True:
...@@ -266,7 +269,10 @@ class SimulatedAnnealingTaskGenerator(TaskGenerator): ...@@ -266,7 +269,10 @@ class SimulatedAnnealingTaskGenerator(TaskGenerator):
sparsity = sorted(sparsity) sparsity = sorted(sparsity)
op_names = [k for k, _ in sorted(self.weights_numel.items(), key=lambda item: item[1]) if k in config['op_names']] op_names = [k for k, _ in sorted(self.weights_numel.items(), key=lambda item: item[1]) if k in config['op_names']]
assert len(sparsity) == len(op_names) assert len(sparsity) == len(op_names)
return [{'total_sparsity': sparsity, 'op_names': [op_name]} for sparsity, op_name in zip(sparsity, op_names)] sub_temp_config_list = [deepcopy(config) for i in range(len(op_names))]
for temp_config, sp, op_name in zip(sub_temp_config_list, sparsity, op_names):
temp_config.update({'total_sparsity': sp, 'op_names': [op_name]})
return sub_temp_config_list
def _update_with_perturbations(self): def _update_with_perturbations(self):
self._temp_config_list = [] self._temp_config_list = []
...@@ -275,6 +281,7 @@ class SimulatedAnnealingTaskGenerator(TaskGenerator): ...@@ -275,6 +281,7 @@ class SimulatedAnnealingTaskGenerator(TaskGenerator):
magnitude = self.current_temperature / self.start_temperature * self.perturbation_magnitude magnitude = self.current_temperature / self.start_temperature * self.perturbation_magnitude
for config, current_sparsity in zip(self.target_sparsity_list, self._current_sparsity_list): for config, current_sparsity in zip(self.target_sparsity_list, self._current_sparsity_list):
if len(current_sparsity) == 0: if len(current_sparsity) == 0:
self._temp_config_list.extend(deepcopy(config))
self._temp_sparsity_list.append([]) self._temp_sparsity_list.append([])
continue continue
while True: while True:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment