Unverified Commit e457047c authored by QuanluZhang's avatar QuanluZhang Committed by GitHub
Browse files

[retiarii] update debug info, and add license (#3438)

parent 539a7cd7
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import logging import logging
import os import os
from typing import Any, Callable from typing import Any, Callable
...@@ -99,7 +102,7 @@ class RetiariiAdvisor(MsgDispatcherBase): ...@@ -99,7 +102,7 @@ class RetiariiAdvisor(MsgDispatcherBase):
'parameters': parameters, 'parameters': parameters,
'parameter_source': 'algorithm' 'parameter_source': 'algorithm'
} }
_logger.info('New trial sent: %s', new_trial) _logger.debug('New trial sent: %s', new_trial)
send(CommandType.NewTrialJob, json_dumps(new_trial)) send(CommandType.NewTrialJob, json_dumps(new_trial))
if self.send_trial_callback is not None: if self.send_trial_callback is not None:
self.send_trial_callback(parameters) # pylint: disable=not-callable self.send_trial_callback(parameters) # pylint: disable=not-callable
...@@ -109,21 +112,21 @@ class RetiariiAdvisor(MsgDispatcherBase): ...@@ -109,21 +112,21 @@ class RetiariiAdvisor(MsgDispatcherBase):
send(CommandType.NoMoreTrialJobs, '') send(CommandType.NoMoreTrialJobs, '')
def handle_request_trial_jobs(self, num_trials): def handle_request_trial_jobs(self, num_trials):
_logger.info('Request trial jobs: %s', num_trials) _logger.debug('Request trial jobs: %s', num_trials)
if self.request_trial_jobs_callback is not None: if self.request_trial_jobs_callback is not None:
self.request_trial_jobs_callback(num_trials) # pylint: disable=not-callable self.request_trial_jobs_callback(num_trials) # pylint: disable=not-callable
def handle_update_search_space(self, data): def handle_update_search_space(self, data):
_logger.info('Received search space: %s', data) _logger.debug('Received search space: %s', data)
self.search_space = data self.search_space = data
def handle_trial_end(self, data): def handle_trial_end(self, data):
_logger.info('Trial end: %s', data) _logger.debug('Trial end: %s', data)
self.trial_end_callback(json_loads(data['hyper_params'])['parameter_id'], # pylint: disable=not-callable self.trial_end_callback(json_loads(data['hyper_params'])['parameter_id'], # pylint: disable=not-callable
data['event'] == 'SUCCEEDED') data['event'] == 'SUCCEEDED')
def handle_report_metric_data(self, data): def handle_report_metric_data(self, data):
_logger.info('Metric reported: %s', data) _logger.debug('Metric reported: %s', data)
if data['type'] == MetricType.REQUEST_PARAMETER: if data['type'] == MetricType.REQUEST_PARAMETER:
raise ValueError('Request parameter not supported') raise ValueError('Request parameter not supported')
elif data['type'] == MetricType.PERIODICAL: elif data['type'] == MetricType.PERIODICAL:
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import json import json
from typing import NewType, Any from typing import NewType, Any
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from typing import (Any, Iterable, List, Optional) from typing import (Any, Iterable, List, Optional)
from .graph import Model from .graph import Model
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from collections import OrderedDict from collections import OrderedDict
from typing import Any, List, Union, Dict from typing import Any, List, Union, Dict
import warnings import warnings
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from typing import Any, List, Optional, Tuple from typing import Any, List, Optional, Tuple
from ...mutator import Mutator from ...mutator import Mutator
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import torch import torch
import torch.nn as nn import torch.nn as nn
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .interface import BaseOneShotTrainer from .interface import BaseOneShotTrainer
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import abc import abc
from typing import Any from typing import Any
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .darts import DartsTrainer from .darts import DartsTrainer
from .enas import EnasTrainer from .enas import EnasTrainer
from .proxyless import ProxylessTrainer from .proxyless import ProxylessTrainer
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from typing import (Any, Dict, List) from typing import (Any, Dict, List)
from . import debug_configs from . import debug_configs
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from ..operation import TensorFlowOperation from ..operation import TensorFlowOperation
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from typing import (Any, List) from typing import (Any, List)
import torch import torch
...@@ -369,7 +372,6 @@ class TensorOps(PyTorchOperation): ...@@ -369,7 +372,6 @@ class TensorOps(PyTorchOperation):
return TensorOpExceptions[self.type](output, inputs) return TensorOpExceptions[self.type](output, inputs)
op_name = self.type.split('::')[-1] op_name = self.type.split('::')[-1]
args_str = ', '.join([f'{name}={inputs[i+1]}' for i, (name, t, default) in enumerate(matched_args)]) args_str = ', '.join([f'{name}={inputs[i+1]}' for i, (name, t, default) in enumerate(matched_args)])
print(args_str)
return f'{output} = {inputs[0]}.{op_name}({args_str})' return f'{output} = {inputs[0]}.{op_name}({args_str})'
class TorchOps(PyTorchOperation): class TorchOps(PyTorchOperation):
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import abc import abc
import functools import functools
import inspect import inspect
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .base import BaseStrategy from .base import BaseStrategy
from .bruteforce import Random, GridSearch from .bruteforce import Random, GridSearch
from .evolution import RegularizedEvolution from .evolution import RegularizedEvolution
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import abc import abc
from typing import List from typing import List
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import copy import copy
import itertools import itertools
import logging import logging
...@@ -36,7 +39,7 @@ def random_generator(search_space: Dict[Any, List[Any]], dedup=True, retries=500 ...@@ -36,7 +39,7 @@ def random_generator(search_space: Dict[Any, List[Any]], dedup=True, retries=500
history.add(selected) history.add(selected)
break break
if retry_count + 1 == retries: if retry_count + 1 == retries:
_logger.info('Random generation has run out of patience. There is nothing to search. Exiting.') _logger.debug('Random generation has run out of patience. There is nothing to search. Exiting.')
return return
yield {key: value for key, value in zip(keys, selected)} yield {key: value for key, value in zip(keys, selected)}
...@@ -58,7 +61,7 @@ class GridSearch(BaseStrategy): ...@@ -58,7 +61,7 @@ class GridSearch(BaseStrategy):
def run(self, base_model, applied_mutators): def run(self, base_model, applied_mutators):
search_space = dry_run_for_search_space(base_model, applied_mutators) search_space = dry_run_for_search_space(base_model, applied_mutators)
for sample in grid_generator(search_space, shuffle=self.shuffle): for sample in grid_generator(search_space, shuffle=self.shuffle):
_logger.info('New model created. Waiting for resource. %s', str(sample)) _logger.debug('New model created. Waiting for resource. %s', str(sample))
if query_available_resources() <= 0: if query_available_resources() <= 0:
time.sleep(self._polling_interval) time.sleep(self._polling_interval)
submit_models(get_targeted_model(base_model, applied_mutators, sample)) submit_models(get_targeted_model(base_model, applied_mutators, sample))
...@@ -101,7 +104,7 @@ class Random(BaseStrategy): ...@@ -101,7 +104,7 @@ class Random(BaseStrategy):
model = base_model model = base_model
for mutator in applied_mutators: for mutator in applied_mutators:
model = mutator.apply(model) model = mutator.apply(model)
_logger.info('New model created. Applied mutators are: %s', str(applied_mutators)) _logger.debug('New model created. Applied mutators are: %s', str(applied_mutators))
submit_models(model) submit_models(model)
else: else:
time.sleep(self._polling_interval) time.sleep(self._polling_interval)
...@@ -109,7 +112,7 @@ class Random(BaseStrategy): ...@@ -109,7 +112,7 @@ class Random(BaseStrategy):
_logger.info('Random search running in fixed size mode. Dedup: %s.', 'on' if self.dedup else 'off') _logger.info('Random search running in fixed size mode. Dedup: %s.', 'on' if self.dedup else 'off')
search_space = dry_run_for_search_space(base_model, applied_mutators) search_space = dry_run_for_search_space(base_model, applied_mutators)
for sample in random_generator(search_space, dedup=self.dedup): for sample in random_generator(search_space, dedup=self.dedup):
_logger.info('New model created. Waiting for resource. %s', str(sample)) _logger.debug('New model created. Waiting for resource. %s', str(sample))
if query_available_resources() <= 0: if query_available_resources() <= 0:
time.sleep(self._polling_interval) time.sleep(self._polling_interval)
submit_models(get_targeted_model(base_model, applied_mutators, sample)) submit_models(get_targeted_model(base_model, applied_mutators, sample))
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import collections import collections
import dataclasses import dataclasses
import logging import logging
...@@ -122,7 +125,7 @@ class RegularizedEvolution(BaseStrategy): ...@@ -122,7 +125,7 @@ class RegularizedEvolution(BaseStrategy):
break break
def _submit_config(self, config, base_model, mutators): def _submit_config(self, config, base_model, mutators):
_logger.info('Model submitted to running queue: %s', config) _logger.debug('Model submitted to running queue: %s', config)
model = get_targeted_model(base_model, mutators, config) model = get_targeted_model(base_model, mutators, config)
submit_models(model) submit_models(model)
self._running_models.append((config, model)) self._running_models.append((config, model))
...@@ -138,7 +141,7 @@ class RegularizedEvolution(BaseStrategy): ...@@ -138,7 +141,7 @@ class RegularizedEvolution(BaseStrategy):
metric = model.metric metric = model.metric
if metric is not None: if metric is not None:
individual = Individual(config, metric) individual = Individual(config, metric)
_logger.info('Individual created: %s', str(individual)) _logger.debug('Individual created: %s', str(individual))
self._population.append(individual) self._population.append(individual)
if len(self._population) > self.population_size: if len(self._population) > self.population_size:
self._population.popleft() self._population.popleft()
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import logging import logging
import time import time
...@@ -55,7 +58,7 @@ class TPEStrategy(BaseStrategy): ...@@ -55,7 +58,7 @@ class TPEStrategy(BaseStrategy):
avail_resource = query_available_resources() avail_resource = query_available_resources()
if avail_resource > 0: if avail_resource > 0:
model = base_model model = base_model
_logger.info('New model created. Applied mutators: %s', str(applied_mutators)) _logger.debug('New model created. Applied mutators: %s', str(applied_mutators))
self.tpe_sampler.generate_samples(self.model_id) self.tpe_sampler.generate_samples(self.model_id)
for mutator in applied_mutators: for mutator in applied_mutators:
mutator.bind_sampler(self.tpe_sampler) mutator.bind_sampler(self.tpe_sampler)
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import collections import collections
from typing import Dict, Any, List from typing import Dict, Any, List
from ..graph import Model from ..graph import Model
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
""" """
Entrypoint for trials. Entrypoint for trials.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment