Unverified Commit 4dc8c032 authored by Yuge Zhang's avatar Yuge Zhang Committed by GitHub
Browse files

Make brute-force strategies budget aware (#3805)

parent 95f4c863
...@@ -8,7 +8,7 @@ import random ...@@ -8,7 +8,7 @@ import random
import time import time
from typing import Any, Dict, List from typing import Any, Dict, List
from .. import Sampler, submit_models, query_available_resources from .. import Sampler, submit_models, query_available_resources, budget_exhausted
from .base import BaseStrategy from .base import BaseStrategy
from .utils import dry_run_for_search_space, get_targeted_model from .utils import dry_run_for_search_space, get_targeted_model
...@@ -63,6 +63,8 @@ class GridSearch(BaseStrategy): ...@@ -63,6 +63,8 @@ class GridSearch(BaseStrategy):
for sample in grid_generator(search_space, shuffle=self.shuffle): for sample in grid_generator(search_space, shuffle=self.shuffle):
_logger.debug('New model created. Waiting for resource. %s', str(sample)) _logger.debug('New model created. Waiting for resource. %s', str(sample))
while query_available_resources() <= 0: while query_available_resources() <= 0:
if budget_exhausted():
return
time.sleep(self._polling_interval) time.sleep(self._polling_interval)
submit_models(get_targeted_model(base_model, applied_mutators, sample)) submit_models(get_targeted_model(base_model, applied_mutators, sample))
...@@ -106,6 +108,8 @@ class Random(BaseStrategy): ...@@ -106,6 +108,8 @@ class Random(BaseStrategy):
model = mutator.apply(model) model = mutator.apply(model)
_logger.debug('New model created. Applied mutators are: %s', str(applied_mutators)) _logger.debug('New model created. Applied mutators are: %s', str(applied_mutators))
submit_models(model) submit_models(model)
elif budget_exhausted():
break
else: else:
time.sleep(self._polling_interval) time.sleep(self._polling_interval)
else: else:
...@@ -114,5 +118,7 @@ class Random(BaseStrategy): ...@@ -114,5 +118,7 @@ class Random(BaseStrategy):
for sample in random_generator(search_space, dedup=self.dedup): for sample in random_generator(search_space, dedup=self.dedup):
_logger.debug('New model created. Waiting for resource. %s', str(sample)) _logger.debug('New model created. Waiting for resource. %s', str(sample))
while query_available_resources() <= 0: while query_available_resources() <= 0:
if budget_exhausted():
return
time.sleep(self._polling_interval) time.sleep(self._polling_interval)
submit_models(get_targeted_model(base_model, applied_mutators, sample)) submit_models(get_targeted_model(base_model, applied_mutators, sample))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment