Unverified Commit aa98462d authored by Yuge Zhang's avatar Yuge Zhang Committed by GitHub
Browse files

Dedup for evolution search (#5092)

parent bcc640c4
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from __future__ import annotations
import collections
import dataclasses
import logging
import random
import time
from typing import Deque
from nni.nas.execution import query_available_resources, submit_models
from nni.nas.execution.common import ModelStatus
from nni.nas.execution.common import Model, ModelStatus
from .base import BaseStrategy
from .utils import dry_run_for_search_space, get_targeted_model, filter_model
......@@ -43,6 +46,10 @@ class RegularizedEvolution(BaseStrategy):
The number of individuals that should participate in each tournament. Default: 25.
mutation_prob : float
Probability that mutation happens in each dim. Default: 0.05
dedup : bool
Do not try the same configuration twice. Default: true.
dedup_retries : int
If dedup is true, retry the same configuration up to dedup_retries times. Default: 500.
on_failure : str
Can be one of "ignore" and "worst". If "ignore", simply give up the model and find a new one.
If "worst", mark the model as -inf (if maximize, inf if minimize), so that the algorithm "learns" to avoid such model.
......@@ -52,7 +59,7 @@ class RegularizedEvolution(BaseStrategy):
"""
def __init__(self, optimize_mode='maximize', population_size=100, sample_size=25, cycles=20000,
mutation_prob=0.05, on_failure='ignore', model_filter=None):
mutation_prob=0.05, dedup=False, dedup_retries=500, on_failure='ignore', model_filter=None):
assert optimize_mode in ['maximize', 'minimize']
assert on_failure in ['ignore', 'worst']
assert sample_size < population_size
......@@ -61,13 +68,16 @@ class RegularizedEvolution(BaseStrategy):
self.sample_size = sample_size
self.cycles = cycles
self.mutation_prob = mutation_prob
self.dedup = dedup
self.dedup_retries = dedup_retries
self.on_failure = on_failure
self._worst = float('-inf') if self.optimize_mode == 'maximize' else float('inf')
self._success_count = 0
self._population = collections.deque()
self._running_models = []
self._history_configs: list[str] = [] # for dedup. has to be a list because keys are non-hashable.
self._population: Deque[Individual] = collections.deque()
self._running_models: list[tuple[dict, Model]] = []
self._polling_interval = 2.
self.filter = model_filter
......@@ -95,6 +105,18 @@ class RegularizedEvolution(BaseStrategy):
parent = min(samples, key=lambda sample: sample.y)
return parent.x
def repeat_until_new_config(self, generator):
if not self.dedup:
# Do nothing if not deduplicating
return generator()
for _ in range(self.dedup_retries):
config = generator()
if config not in self._history_configs:
return config
_logger.warning('Deduplication failed. Generating an arbitrary config.')
return generator()
def run(self, base_model, applied_mutators):
search_space = dry_run_for_search_space(base_model, applied_mutators)
# Run the first population regardless concurrency
......@@ -102,7 +124,7 @@ class RegularizedEvolution(BaseStrategy):
while len(self._population) + len(self._running_models) <= self.population_size:
# try to submit new models
while len(self._population) + len(self._running_models) < self.population_size:
config = self.random(search_space)
config = self.repeat_until_new_config(lambda: self.random(search_space))
self._submit_config(config, base_model, applied_mutators)
# collect results
self._move_succeeded_models_to_population()
......@@ -117,7 +139,7 @@ class RegularizedEvolution(BaseStrategy):
while self._success_count + len(self._running_models) <= self.cycles:
# try to submit new models
while query_available_resources() > 0 and self._success_count + len(self._running_models) < self.cycles:
config = self.mutate(self.best_parent(), search_space)
config = self.repeat_until_new_config(lambda: self.mutate(self.best_parent(), search_space))
self._submit_config(config, base_model, applied_mutators)
# collect results
self._move_succeeded_models_to_population()
......@@ -129,6 +151,7 @@ class RegularizedEvolution(BaseStrategy):
def _submit_config(self, config, base_model, mutators):
_logger.debug('Model submitted to running queue: %s', config)
self._history_configs.append(config)
model = get_targeted_model(base_model, mutators, config)
if not filter_model(self.filter, model):
if self.on_failure == "worst":
......
......@@ -136,6 +136,13 @@ def test_evolution():
wait_models(*engine.models)
_reset_execution_engine()
evolution = strategy.RegularizedEvolution(population_size=5, sample_size=3, cycles=10, mutation_prob=0.5, dedup=True, on_failure='ignore')
engine = MockExecutionEngine(failure_prob=0.2)
_reset_execution_engine(engine)
evolution.run(*_get_model_and_mutators())
wait_models(*engine.models)
_reset_execution_engine()
evolution = strategy.RegularizedEvolution(population_size=5, sample_size=3, cycles=10, mutation_prob=0.5, on_failure='worst')
engine = MockExecutionEngine(failure_prob=0.4)
_reset_execution_engine(engine)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment