"...resnet50_tensorflow.git" did not exist on "25b9330e1fa15f7d239b36822ca623e36b53fe59"
Unverified Commit aa98462d authored by Yuge Zhang's avatar Yuge Zhang Committed by GitHub
Browse files

Dedup for evolution search (#5092)

parent bcc640c4
# Copyright (c) Microsoft Corporation. # Copyright (c) Microsoft Corporation.
# Licensed under the MIT license. # Licensed under the MIT license.
from __future__ import annotations
import collections import collections
import dataclasses import dataclasses
import logging import logging
import random import random
import time import time
from typing import Deque
from nni.nas.execution import query_available_resources, submit_models from nni.nas.execution import query_available_resources, submit_models
from nni.nas.execution.common import ModelStatus from nni.nas.execution.common import Model, ModelStatus
from .base import BaseStrategy from .base import BaseStrategy
from .utils import dry_run_for_search_space, get_targeted_model, filter_model from .utils import dry_run_for_search_space, get_targeted_model, filter_model
...@@ -43,6 +46,10 @@ class RegularizedEvolution(BaseStrategy): ...@@ -43,6 +46,10 @@ class RegularizedEvolution(BaseStrategy):
The number of individuals that should participate in each tournament. Default: 25. The number of individuals that should participate in each tournament. Default: 25.
mutation_prob : float mutation_prob : float
Probability that mutation happens in each dim. Default: 0.05 Probability that mutation happens in each dim. Default: 0.05
dedup : bool
Do not try the same configuration twice. Default: true.
dedup_retries : int
If dedup is true, retry the same configuration up to dedup_retries times. Default: 500.
on_failure : str on_failure : str
Can be one of "ignore" and "worst". If "ignore", simply give up the model and find a new one. Can be one of "ignore" and "worst". If "ignore", simply give up the model and find a new one.
If "worst", mark the model as -inf (if maximize, inf if minimize), so that the algorithm "learns" to avoid such model. If "worst", mark the model as -inf (if maximize, inf if minimize), so that the algorithm "learns" to avoid such model.
...@@ -52,7 +59,7 @@ class RegularizedEvolution(BaseStrategy): ...@@ -52,7 +59,7 @@ class RegularizedEvolution(BaseStrategy):
""" """
def __init__(self, optimize_mode='maximize', population_size=100, sample_size=25, cycles=20000, def __init__(self, optimize_mode='maximize', population_size=100, sample_size=25, cycles=20000,
mutation_prob=0.05, on_failure='ignore', model_filter=None): mutation_prob=0.05, dedup=False, dedup_retries=500, on_failure='ignore', model_filter=None):
assert optimize_mode in ['maximize', 'minimize'] assert optimize_mode in ['maximize', 'minimize']
assert on_failure in ['ignore', 'worst'] assert on_failure in ['ignore', 'worst']
assert sample_size < population_size assert sample_size < population_size
...@@ -61,13 +68,16 @@ class RegularizedEvolution(BaseStrategy): ...@@ -61,13 +68,16 @@ class RegularizedEvolution(BaseStrategy):
self.sample_size = sample_size self.sample_size = sample_size
self.cycles = cycles self.cycles = cycles
self.mutation_prob = mutation_prob self.mutation_prob = mutation_prob
self.dedup = dedup
self.dedup_retries = dedup_retries
self.on_failure = on_failure self.on_failure = on_failure
self._worst = float('-inf') if self.optimize_mode == 'maximize' else float('inf') self._worst = float('-inf') if self.optimize_mode == 'maximize' else float('inf')
self._success_count = 0 self._success_count = 0
self._population = collections.deque() self._history_configs: list[str] = [] # for dedup. has to be a list because keys are non-hashable.
self._running_models = [] self._population: Deque[Individual] = collections.deque()
self._running_models: list[tuple[dict, Model]] = []
self._polling_interval = 2. self._polling_interval = 2.
self.filter = model_filter self.filter = model_filter
...@@ -95,6 +105,18 @@ class RegularizedEvolution(BaseStrategy): ...@@ -95,6 +105,18 @@ class RegularizedEvolution(BaseStrategy):
parent = min(samples, key=lambda sample: sample.y) parent = min(samples, key=lambda sample: sample.y)
return parent.x return parent.x
def repeat_until_new_config(self, generator):
if not self.dedup:
# Do nothing if not deduplicating
return generator()
for _ in range(self.dedup_retries):
config = generator()
if config not in self._history_configs:
return config
_logger.warning('Deduplication failed. Generating an arbitrary config.')
return generator()
def run(self, base_model, applied_mutators): def run(self, base_model, applied_mutators):
search_space = dry_run_for_search_space(base_model, applied_mutators) search_space = dry_run_for_search_space(base_model, applied_mutators)
# Run the first population regardless concurrency # Run the first population regardless concurrency
...@@ -102,7 +124,7 @@ class RegularizedEvolution(BaseStrategy): ...@@ -102,7 +124,7 @@ class RegularizedEvolution(BaseStrategy):
while len(self._population) + len(self._running_models) <= self.population_size: while len(self._population) + len(self._running_models) <= self.population_size:
# try to submit new models # try to submit new models
while len(self._population) + len(self._running_models) < self.population_size: while len(self._population) + len(self._running_models) < self.population_size:
config = self.random(search_space) config = self.repeat_until_new_config(lambda: self.random(search_space))
self._submit_config(config, base_model, applied_mutators) self._submit_config(config, base_model, applied_mutators)
# collect results # collect results
self._move_succeeded_models_to_population() self._move_succeeded_models_to_population()
...@@ -117,7 +139,7 @@ class RegularizedEvolution(BaseStrategy): ...@@ -117,7 +139,7 @@ class RegularizedEvolution(BaseStrategy):
while self._success_count + len(self._running_models) <= self.cycles: while self._success_count + len(self._running_models) <= self.cycles:
# try to submit new models # try to submit new models
while query_available_resources() > 0 and self._success_count + len(self._running_models) < self.cycles: while query_available_resources() > 0 and self._success_count + len(self._running_models) < self.cycles:
config = self.mutate(self.best_parent(), search_space) config = self.repeat_until_new_config(lambda: self.mutate(self.best_parent(), search_space))
self._submit_config(config, base_model, applied_mutators) self._submit_config(config, base_model, applied_mutators)
# collect results # collect results
self._move_succeeded_models_to_population() self._move_succeeded_models_to_population()
...@@ -129,6 +151,7 @@ class RegularizedEvolution(BaseStrategy): ...@@ -129,6 +151,7 @@ class RegularizedEvolution(BaseStrategy):
def _submit_config(self, config, base_model, mutators): def _submit_config(self, config, base_model, mutators):
_logger.debug('Model submitted to running queue: %s', config) _logger.debug('Model submitted to running queue: %s', config)
self._history_configs.append(config)
model = get_targeted_model(base_model, mutators, config) model = get_targeted_model(base_model, mutators, config)
if not filter_model(self.filter, model): if not filter_model(self.filter, model):
if self.on_failure == "worst": if self.on_failure == "worst":
......
...@@ -136,6 +136,13 @@ def test_evolution(): ...@@ -136,6 +136,13 @@ def test_evolution():
wait_models(*engine.models) wait_models(*engine.models)
_reset_execution_engine() _reset_execution_engine()
evolution = strategy.RegularizedEvolution(population_size=5, sample_size=3, cycles=10, mutation_prob=0.5, dedup=True, on_failure='ignore')
engine = MockExecutionEngine(failure_prob=0.2)
_reset_execution_engine(engine)
evolution.run(*_get_model_and_mutators())
wait_models(*engine.models)
_reset_execution_engine()
evolution = strategy.RegularizedEvolution(population_size=5, sample_size=3, cycles=10, mutation_prob=0.5, on_failure='worst') evolution = strategy.RegularizedEvolution(population_size=5, sample_size=3, cycles=10, mutation_prob=0.5, on_failure='worst')
engine = MockExecutionEngine(failure_prob=0.4) engine = MockExecutionEngine(failure_prob=0.4)
_reset_execution_engine(engine) _reset_execution_engine(engine)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment