Unverified Commit e457047c authored by QuanluZhang's avatar QuanluZhang Committed by GitHub
Browse files

[retiarii] update debug info, and add license (#3438)

parent 539a7cd7
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import logging
import os
from typing import Any, Callable
......@@ -99,7 +102,7 @@ class RetiariiAdvisor(MsgDispatcherBase):
'parameters': parameters,
'parameter_source': 'algorithm'
}
_logger.info('New trial sent: %s', new_trial)
_logger.debug('New trial sent: %s', new_trial)
send(CommandType.NewTrialJob, json_dumps(new_trial))
if self.send_trial_callback is not None:
self.send_trial_callback(parameters) # pylint: disable=not-callable
......@@ -109,21 +112,21 @@ class RetiariiAdvisor(MsgDispatcherBase):
send(CommandType.NoMoreTrialJobs, '')
def handle_request_trial_jobs(self, num_trials):
_logger.info('Request trial jobs: %s', num_trials)
_logger.debug('Request trial jobs: %s', num_trials)
if self.request_trial_jobs_callback is not None:
self.request_trial_jobs_callback(num_trials) # pylint: disable=not-callable
def handle_update_search_space(self, data):
_logger.info('Received search space: %s', data)
_logger.debug('Received search space: %s', data)
self.search_space = data
def handle_trial_end(self, data):
_logger.info('Trial end: %s', data)
_logger.debug('Trial end: %s', data)
self.trial_end_callback(json_loads(data['hyper_params'])['parameter_id'], # pylint: disable=not-callable
data['event'] == 'SUCCEEDED')
def handle_report_metric_data(self, data):
_logger.info('Metric reported: %s', data)
_logger.debug('Metric reported: %s', data)
if data['type'] == MetricType.REQUEST_PARAMETER:
raise ValueError('Request parameter not supported')
elif data['type'] == MetricType.PERIODICAL:
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import json
from typing import NewType, Any
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from typing import (Any, Iterable, List, Optional)
from .graph import Model
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from collections import OrderedDict
from typing import Any, List, Union, Dict
import warnings
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from typing import Any, List, Optional, Tuple
from ...mutator import Mutator
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import torch
import torch.nn as nn
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .interface import BaseOneShotTrainer
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import abc
from typing import Any
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .darts import DartsTrainer
from .enas import EnasTrainer
from .proxyless import ProxylessTrainer
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from typing import (Any, Dict, List)
from . import debug_configs
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from ..operation import TensorFlowOperation
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from typing import (Any, List)
import torch
......@@ -369,7 +372,6 @@ class TensorOps(PyTorchOperation):
return TensorOpExceptions[self.type](output, inputs)
op_name = self.type.split('::')[-1]
args_str = ', '.join([f'{name}={inputs[i+1]}' for i, (name, t, default) in enumerate(matched_args)])
print(args_str)
return f'{output} = {inputs[0]}.{op_name}({args_str})'
class TorchOps(PyTorchOperation):
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import abc
import functools
import inspect
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .base import BaseStrategy
from .bruteforce import Random, GridSearch
from .evolution import RegularizedEvolution
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import abc
from typing import List
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import copy
import itertools
import logging
......@@ -36,7 +39,7 @@ def random_generator(search_space: Dict[Any, List[Any]], dedup=True, retries=500
history.add(selected)
break
if retry_count + 1 == retries:
_logger.info('Random generation has run out of patience. There is nothing to search. Exiting.')
_logger.debug('Random generation has run out of patience. There is nothing to search. Exiting.')
return
yield {key: value for key, value in zip(keys, selected)}
......@@ -58,7 +61,7 @@ class GridSearch(BaseStrategy):
def run(self, base_model, applied_mutators):
search_space = dry_run_for_search_space(base_model, applied_mutators)
for sample in grid_generator(search_space, shuffle=self.shuffle):
_logger.info('New model created. Waiting for resource. %s', str(sample))
_logger.debug('New model created. Waiting for resource. %s', str(sample))
if query_available_resources() <= 0:
time.sleep(self._polling_interval)
submit_models(get_targeted_model(base_model, applied_mutators, sample))
......@@ -101,7 +104,7 @@ class Random(BaseStrategy):
model = base_model
for mutator in applied_mutators:
model = mutator.apply(model)
_logger.info('New model created. Applied mutators are: %s', str(applied_mutators))
_logger.debug('New model created. Applied mutators are: %s', str(applied_mutators))
submit_models(model)
else:
time.sleep(self._polling_interval)
......@@ -109,7 +112,7 @@ class Random(BaseStrategy):
_logger.info('Random search running in fixed size mode. Dedup: %s.', 'on' if self.dedup else 'off')
search_space = dry_run_for_search_space(base_model, applied_mutators)
for sample in random_generator(search_space, dedup=self.dedup):
_logger.info('New model created. Waiting for resource. %s', str(sample))
_logger.debug('New model created. Waiting for resource. %s', str(sample))
if query_available_resources() <= 0:
time.sleep(self._polling_interval)
submit_models(get_targeted_model(base_model, applied_mutators, sample))
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import collections
import dataclasses
import logging
......@@ -122,7 +125,7 @@ class RegularizedEvolution(BaseStrategy):
break
def _submit_config(self, config, base_model, mutators):
_logger.info('Model submitted to running queue: %s', config)
_logger.debug('Model submitted to running queue: %s', config)
model = get_targeted_model(base_model, mutators, config)
submit_models(model)
self._running_models.append((config, model))
......@@ -138,7 +141,7 @@ class RegularizedEvolution(BaseStrategy):
metric = model.metric
if metric is not None:
individual = Individual(config, metric)
_logger.info('Individual created: %s', str(individual))
_logger.debug('Individual created: %s', str(individual))
self._population.append(individual)
if len(self._population) > self.population_size:
self._population.popleft()
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import logging
import time
......@@ -55,7 +58,7 @@ class TPEStrategy(BaseStrategy):
avail_resource = query_available_resources()
if avail_resource > 0:
model = base_model
_logger.info('New model created. Applied mutators: %s', str(applied_mutators))
_logger.debug('New model created. Applied mutators: %s', str(applied_mutators))
self.tpe_sampler.generate_samples(self.model_id)
for mutator in applied_mutators:
mutator.bind_sampler(self.tpe_sampler)
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import collections
from typing import Dict, Any, List
from ..graph import Model
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
"""
Entrypoint for trials.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment