Unverified commit 4f66d0c1 authored by SparkSnail, committed by GitHub

Merge pull request #229 from microsoft/master

merge master
parents 4132f620 049634f7
 {
   "alpha": {
     "_type": "quniform",
-    "_value": [1.0, 2.0, 0.1]
+    "_value": [1.0, 2.0, 0.05]
   },
   "beta": {
     "_type": "quniform",
-    "_value": [1.0, 1.5, 0.1]
+    "_value": [1.0, 1.5, 0.05]
   },
   "gamma": {
     "_type": "quniform",
-    "_value": [1.0, 1.5, 0.1]
+    "_value": [1.0, 1.5, 0.05]
   }
 }
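Halving the quniform step from 0.1 to 0.05 roughly doubles the number of grid candidates per parameter. A minimal sketch of how a quniform triple [low, high, q] can expand into a candidate list (an illustration of the idea, not NNI's exact expansion code):

    import numpy as np

    def expand_quniform(low, high, q):
        # Enumerate multiples of q and clip them into [low, high] (assumed semantics).
        values = np.clip(np.arange(round(low / q), round(high / q) + 1) * q, low, high)
        return sorted(set(values.tolist()))

    print(len(expand_quniform(1.0, 2.0, 0.1)))   # 11 candidates
    print(len(expand_quniform(1.0, 2.0, 0.05)))  # 21 candidates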
@@ -14,11 +14,11 @@ class FixedProductTuner(GridSearchTuner):
         super().__init__()
         self.product = product
 
-    def expand_parameters(self, para):
+    def _expand_parameters(self, para):
         """
         Filter out all qualified parameters
         """
-        para = super().expand_parameters(para)
+        para = super()._expand_parameters(para)
         if all([key in para[0] for key in ["alpha", "beta", "gamma"]]):  # if this is an interested set
             ret_para = []
             for p in para:
...
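The hunk above is truncated; judging from the class name and the stored self.product, the remainder presumably keeps only parameter sets whose alpha * beta * gamma matches the configured product. A hypothetical sketch of that kind of filter (the function name, tolerance, and exact condition are assumptions, not the tuner's code):

    def filter_by_product(candidates, product, tol=1e-6):
        # Keep only parameter sets whose alpha * beta * gamma is (approximately) the target product.
        return [p for p in candidates
                if abs(p["alpha"] * p["beta"] * p["gamma"] - product) < tol]

    grid = [{"alpha": 1.0, "beta": 1.5, "gamma": 1.0},
            {"alpha": 1.2, "beta": 1.25, "gamma": 1.0},
            {"alpha": 1.1, "beta": 1.5, "gamma": 1.5}]
    print(filter_by_product(grid, 1.5))  # keeps the first two sets only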
@@ -4,6 +4,7 @@
 'use strict';
 
 import * as assert from 'assert';
+import * as JSON5 from 'json5';
 import { Deferred } from 'ts-deferred';
 
 import * as component from '../common/component';
@@ -131,7 +132,7 @@ class NNIDataStore implements DataStore {
     }
 
     public async storeMetricData(trialJobId: string, data: string): Promise<void> {
-        const metrics: MetricData = JSON.parse(data);
+        const metrics: MetricData = JSON5.parse(data);
         // REQUEST_PARAMETER is used to request new parameters for multiphase trial job,
         // it is not metrics, so it is skipped here.
         if (metrics.type === 'REQUEST_PARAMETER') {
@@ -140,7 +141,7 @@ class NNIDataStore implements DataStore {
         }
         assert(trialJobId === metrics.trial_job_id);
         try {
-            await this.db.storeMetricData(trialJobId, JSON.stringify({
+            await this.db.storeMetricData(trialJobId, JSON5.stringify({
                 trialJobId: metrics.trial_job_id,
                 parameterId: metrics.parameter_id,
                 type: metrics.type,
...
@@ -5,6 +5,7 @@
 import * as assert from 'assert';
 import * as fs from 'fs';
+import * as JSON5 from 'json5';
 import * as path from 'path';
 import * as sqlite3 from 'sqlite3';
 import { Deferred } from 'ts-deferred';
@@ -202,10 +203,10 @@ class SqlDB implements Database {
 
     public storeMetricData(trialJobId: string, data: string): Promise<void> {
         const sql: string = 'insert into MetricData values (?,?,?,?,?,?)';
-        const json: MetricDataRecord = JSON.parse(data);
-        const args: any[] = [Date.now(), json.trialJobId, json.parameterId, json.type, json.sequence, JSON.stringify(json.data)];
-        this.log.trace(`storeMetricData: SQL: ${sql}, args: ${JSON.stringify(args)}`);
+        const json: MetricDataRecord = JSON5.parse(data);
+        const args: any[] = [Date.now(), json.trialJobId, json.parameterId, json.type, json.sequence, JSON5.stringify(json.data)];
+        this.log.trace(`storeMetricData: SQL: ${sql}, args: ${JSON5.stringify(args)}`);
         const deferred: Deferred<void> = new Deferred<void>();
         this.db.run(sql, args, (err: Error | null) => { this.resolve(deferred, err); });
...
@@ -16,6 +16,7 @@
     "express": "^4.16.3",
     "express-joi-validator": "^2.0.0",
     "js-base64": "^2.4.9",
+    "json5": "^2.1.1",
     "kubernetes-client": "^6.5.0",
     "rx": "^4.1.0",
     "sqlite3": "^4.0.2",
@@ -34,6 +35,7 @@
     "@types/express": "^4.16.0",
     "@types/glob": "^7.1.1",
     "@types/js-base64": "^2.3.1",
+    "@types/json5": "^0.0.30",
     "@types/mocha": "^5.2.5",
     "@types/node": "10.12.18",
     "@types/request": "^2.47.1",
...
@@ -151,18 +151,20 @@ abstract class PAITrainingService implements TrainingService {
 
     public cancelTrialJob(trialJobId: string, isEarlyStopped: boolean = false): Promise<void> {
         const trialJobDetail: PAITrialJobDetail | undefined = this.trialJobsMap.get(trialJobId);
-        const deferred: Deferred<void> = new Deferred<void>();
         if (trialJobDetail === undefined) {
-            this.log.error(`cancelTrialJob: trial job id ${trialJobId} not found`);
-            return Promise.reject();
+            return Promise.reject(new Error(`cancelTrialJob: trial job id ${trialJobId} not found`));
         }
         if (this.paiClusterConfig === undefined) {
-            throw new Error('PAI Cluster config is not initialized');
+            return Promise.reject(new Error('PAI Cluster config is not initialized'));
         }
         if (this.paiToken === undefined) {
-            throw new Error('PAI token is not initialized');
+            return Promise.reject(new Error('PAI token is not initialized'));
+        }
+        if (trialJobDetail.status === 'UNKNOWN') {
+            trialJobDetail.status = 'USER_CANCELED';
+            return Promise.resolve();
         }
 
         const stopJobRequest: request.Options = {
@@ -179,6 +181,7 @@ abstract class PAITrainingService implements TrainingService {
 
         // Set trialjobDetail's early stopped field, to mark the job's cancellation source
         trialJobDetail.isEarlyStopped = isEarlyStopped;
+        const deferred: Deferred<void> = new Deferred<void>();
 
         request(stopJobRequest, (error: Error, response: request.Response, body: any) => {
             if ((error !== undefined && error !== null) || response.statusCode >= 400) {
...
@@ -277,6 +277,12 @@ class RemoteMachineTrainingService implements TrainingService {
             throw new Error(`Invalid job id ${trialJobId}, cannot find ssh client`);
         }
 
+        if (trialJob.status === 'UNKNOWN') {
+            this.releaseTrialSSHClient(trialJob);
+            trialJob.status = 'USER_CANCELED';
+            return;
+        }
+
         const jobpidPath: string = this.getJobPidPath(trialJob.id);
         try {
             // Mark the toEarlyStop tag here
...
@@ -157,6 +157,10 @@
   version "7.0.3"
   resolved "https://registry.yarnpkg.com/@types/json-schema/-/json-schema-7.0.3.tgz#bdfd69d61e464dcc81b25159c270d75a73c1a636"
 
+"@types/json5@^0.0.30":
+  version "0.0.30"
+  resolved "https://registry.yarnpkg.com/@types/json5/-/json5-0.0.30.tgz#44cb52f32a809734ca562e685c6473b5754a7818"
+
 "@types/mime@*":
   version "2.0.0"
   resolved "https://registry.yarnpkg.com/@types/mime/-/mime-2.0.0.tgz#5a7306e367c539b9f6543499de8dd519fac37a8b"
@@ -1840,9 +1844,9 @@ growl@1.10.5:
   version "1.10.5"
   resolved "https://registry.yarnpkg.com/growl/-/growl-1.10.5.tgz#f2735dc2283674fa67478b10181059355c369e5e"
 
-handlebars@^4.0.11, handlebars@^4.3.0:
-  version "4.5.3"
-  resolved "https://registry.yarnpkg.com/handlebars/-/handlebars-4.5.3.tgz#5cf75bd8714f7605713511a56be7c349becb0482"
+handlebars@^4.0.11, handlebars@^4.5.3:
+  version "4.7.2"
+  resolved "https://registry.yarnpkg.com/handlebars/-/handlebars-4.7.2.tgz#01127b3840156a0927058779482031afe0e730d7"
   dependencies:
     neo-async "^2.6.0"
     optimist "^0.6.1"
@@ -2371,6 +2375,12 @@ json-stringify-safe@~5.0.1:
   version "5.0.1"
   resolved "https://registry.yarnpkg.com/json-stringify-safe/-/json-stringify-safe-5.0.1.tgz#1296a2d58fd45f19a0f6ce01d65701e2c735b6eb"
 
+json5@^2.1.1:
+  version "2.1.1"
+  resolved "https://registry.yarnpkg.com/json5/-/json5-2.1.1.tgz#81b6cb04e9ba496f1c7005d07b4368a2638f90b6"
+  dependencies:
+    minimist "^1.2.0"
+
 jsonparse@^1.2.0:
   version "1.3.1"
   resolved "https://registry.yarnpkg.com/jsonparse/-/jsonparse-1.3.1.tgz#3f4dae4a91fac315f71062f8521cc239f1366280"
...
@@ -102,6 +102,9 @@ class GridSearchTuner(Tuner):
         """
         Parse type of randint parameter and return a list
         """
+        if param_value[0] >= param_value[1]:
+            raise ValueError("Randint should contain at least 1 candidate, but [%s, %s) contains none." %
+                             (param_value[0], param_value[1]))
         return np.arange(param_value[0], param_value[1]).tolist()
 
     def _expand_parameters(self, para):
...
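The new guard fails fast on an empty randint range instead of silently producing an empty grid. A standalone sketch of the behaviour (mirroring the added check, not importing the tuner):

    import numpy as np

    def parse_randint(param_value):
        # randint is a half-open integer range [lower, upper); reject empty ranges.
        lower, upper = param_value
        if lower >= upper:
            raise ValueError("Randint should contain at least 1 candidate, "
                             "but [%s, %s) contains none." % (lower, upper))
        return np.arange(lower, upper).tolist()

    print(parse_randint([10, 13]))   # [10, 11, 12]
    # parse_randint([5, 5]) would raise ValueError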
@@ -118,6 +118,8 @@ def json2vals(in_x, vals, out_y, name=NodeType.ROOT):
                           vals[NodeType.VALUE],
                           out_y,
                           name=name + '[%d]' % _index)
+            if _type == 'randint':
+                out_y[name] -= in_x[NodeType.VALUE][0]
         else:
             for key in in_x.keys():
                 json2vals(in_x[key], vals[key], out_y,
...
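Subtracting the lower bound converts a randint value drawn from [lower, upper) back into a 0-based index in [0, upper - lower), which appears to be the form the underlying optimizer works with. A small illustration (helper names are hypothetical):

    def randint_value_to_index(value, lower):
        # A value from randint [lower, upper) corresponds to index value - lower.
        return value - lower

    def randint_index_to_value(index, lower):
        # Inverse mapping, from a 0-based index back to the actual value.
        return index + lower

    assert randint_value_to_index(17, lower=10) == 7
    assert randint_index_to_value(7, lower=10) == 17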
@@ -11,7 +11,7 @@ from .msg_dispatcher_base import MsgDispatcherBase
 from .assessor import AssessResult
 from .common import multi_thread_enabled, multi_phase_enabled
 from .env_vars import dispatcher_env_vars
-from .utils import MetricType
+from .utils import MetricType, to_json
 
 _logger = logging.getLogger(__name__)
 
@@ -62,7 +62,7 @@ def _pack_parameter(parameter_id, params, customized=False, trial_job_id=None, p
         ret['parameter_index'] = parameter_index
     else:
         ret['parameter_index'] = 0
-    return json_tricks.dumps(ret)
+    return to_json(ret)
 
 
 class MsgDispatcher(MsgDispatcherBase):
...
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .mutator import RegularizedDartsMutator, RegularizedMutatorParallel, DartsDiscreteMutator
from .trainer import CdartsTrainer
\ No newline at end of file
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import torch
from apex.parallel import DistributedDataParallel # pylint: disable=import-error
from nni.nas.pytorch.darts import DartsMutator # pylint: disable=wrong-import-order
from nni.nas.pytorch.mutables import LayerChoice # pylint: disable=wrong-import-order
from nni.nas.pytorch.mutator import Mutator # pylint: disable=wrong-import-order
class RegularizedDartsMutator(DartsMutator):
"""
This is :class:`~nni.nas.pytorch.darts.DartsMutator` basically, with two differences.
1. Choices can be cut (bypassed). This is done by ``cut_choices``. Cut choices will not be used in the
forward pass and thus consume no memory.
2. Regularization on choices, to prevent the mutator from overfitting on some choices.
"""
def reset(self):
"""
Warnings
--------
``reset`` has been renamed to :func:`~reset_with_loss`, which returns the regularization loss on reset; call that instead.
"""
raise ValueError("You should probably call `reset_with_loss`.")
def cut_choices(self, cut_num=2):
"""
Cut the choices with the smallest weights.
``cut_num`` is the cumulative number of choices cut so far; e.g., if the first call cuts 2,
the second call should pass 4 to cut another two.
Parameters
----------
cut_num : int
Number of choices to cut, so far.
Warnings
--------
Though the parameters are set to :math:`-\infty` to be bypassed, they will still receive a gradient of 0,
which introduces a ``nan`` problem when calling ``optimizer.step()``. A simple workaround is to
reset ``nan`` to :math:`-\infty` each time after the parameters are updated.
"""
# `cut_choices` is implemented but not used in current implementation of CdartsTrainer
for mutable in self.mutables:
if isinstance(mutable, LayerChoice):
_, idx = torch.topk(-self.choices[mutable.key], cut_num)
with torch.no_grad():
for i in idx:
self.choices[mutable.key][i] = -float("inf")
def reset_with_loss(self):
"""
Resample and return the loss. If the loss is 0, ``None`` is returned to avoid device issues.
Currently the loss penalty is proportional to the L1-norm of the parameters of modules whose type
name contains one of the substrings ``poolwithoutbn``, ``identity``, or ``dilconv``.
"""
self._cache, reg_loss = self.sample_search()
return reg_loss
def sample_search(self):
result = super().sample_search()
loss = []
for mutable in self.mutables:
if isinstance(mutable, LayerChoice):
def need_reg(choice):
return any(t in str(type(choice)).lower() for t in ["poolwithoutbn", "identity", "dilconv"])
for i, choice in enumerate(mutable.choices):
if need_reg(choice):
norm = torch.abs(self.choices[mutable.key][i])
if norm < 1E10:
loss.append(norm)
if not loss:
return result, None
return result, sum(loss)
def export(self, logger=None):
"""
Export the architecture; the genotype will be printed with the given logger.
Returns
-------
tuple of (dict, object)
A mapping from mutable keys to decisions, and the genotypes produced by ``plot_genotype``
(``None`` if the model has no ``plot_genotype`` or no logger is given).
"""
result = self.sample_final()
genotypes = None
if hasattr(self.model, "plot_genotype") and logger is not None:
    genotypes = self.model.plot_genotype(result, logger)
return result, genotypes
class RegularizedMutatorParallel(DistributedDataParallel):
"""
Parallelize :class:`~RegularizedDartsMutator`.
This parallelizes the :func:`~RegularizedDartsMutator.reset_with_loss` method, while keeping
:func:`~RegularizedDartsMutator.cut_choices` and :func:`~RegularizedDartsMutator.export`
easily accessible.
"""
def reset_with_loss(self):
"""
Parallelized :func:`~RegularizedDartsMutator.reset_with_loss`.
"""
result = self.module.reset_with_loss()
self.callback_queued = False
return result
def cut_choices(self, *args, **kwargs):
"""
Parallelized :func:`~RegularizedDartsMutator.cut_choices`.
"""
self.module.cut_choices(*args, **kwargs)
def export(self, logger):
"""
Parallelized :func:`~RegularizedDartsMutator.export`.
"""
return self.module.export(logger)
class DartsDiscreteMutator(Mutator):
"""
A mutator that applies the final sampling result of a parent mutator to another model for training.
"""
def __init__(self, model, parent_mutator):
"""
Initialization.
Parameters
----------
model : nn.Module
The model to apply the mutator.
parent_mutator : Mutator
The mutator that provides the ``sample_final`` method, which will be called to get the architecture.
"""
super().__init__(model)
self.__dict__["parent_mutator"] = parent_mutator  # bypass Module registration so the parent's parameters are not included
def sample_search(self):
return self.parent_mutator.sample_final()
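A minimal sketch (toy tensors, standalone from the mutator) of why setting a choice's architecture weight to -inf effectively bypasses it: after softmax the cut choices get exactly zero probability while the remaining ones are renormalized.

    import torch
    import torch.nn.functional as F

    # Toy architecture weights for a LayerChoice with four candidate ops.
    alpha = torch.tensor([0.3, -1.2, 0.8, 0.1])

    # Cut the two smallest weights, as cut_choices does with cut_num=2.
    _, idx = torch.topk(-alpha, 2)
    with torch.no_grad():
        alpha[idx] = float("-inf")

    print(F.softmax(alpha, dim=0))  # cut ops have probability exactly 0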
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import json
import logging
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import apex # pylint: disable=import-error
from apex.parallel import DistributedDataParallel # pylint: disable=import-error
from nni.nas.pytorch.cdarts import RegularizedDartsMutator, RegularizedMutatorParallel, DartsDiscreteMutator # pylint: disable=wrong-import-order
from nni.nas.pytorch.utils import AverageMeterGroup # pylint: disable=wrong-import-order
from .utils import CyclicIterator, TorchTensorEncoder, accuracy, reduce_metrics
PHASE_SMALL = "small"
PHASE_LARGE = "large"
class InteractiveKLLoss(nn.Module):
def __init__(self, temperature):
super().__init__()
self.temperature = temperature
# self.kl_loss = nn.KLDivLoss(reduction = 'batchmean')
self.kl_loss = nn.KLDivLoss()
def forward(self, student, teacher):
return self.kl_loss(F.log_softmax(student / self.temperature, dim=1),
F.softmax(teacher / self.temperature, dim=1))
class CdartsTrainer(object):
def __init__(self, model_small, model_large, criterion, loaders, samplers, logger=None,
regular_coeff=5, regular_ratio=0.2, warmup_epochs=2, fix_head=True,
epochs=32, steps_per_epoch=None, loss_alpha=2, loss_T=2, distributed=True,
log_frequency=10, grad_clip=5.0, interactive_type='kl', output_path='./outputs',
w_lr=0.2, w_momentum=0.9, w_weight_decay=3e-4, alpha_lr=0.2, alpha_weight_decay=1e-4,
nasnet_lr=0.2, local_rank=0, share_module=True):
"""
Initialize a CdartsTrainer.
Parameters
----------
model_small : nn.Module
PyTorch model to be trained. This is the search network of CDARTS.
model_large : nn.Module
PyTorch model to be trained. This is the evaluation network of CDARTS.
criterion : callable
Receives logits and ground truth label, return a loss tensor, e.g., ``nn.CrossEntropyLoss()``.
loaders : list of torch.utils.data.DataLoader
List of train data and valid data loaders, for training weights and architecture weights respectively.
samplers : list of torch.utils.data.Sampler
List of train data and valid data samplers. This can be PyTorch standard samplers if not distributed.
In distributed mode, sampler needs to have ``set_epoch`` method. Refer to data utils in CDARTS example for details.
logger : logging.Logger
The logger for logging. Will use nni logger by default (if logger is ``None``).
regular_coeff : float
The coefficient of regular loss.
regular_ratio : float
The ratio of regular loss.
warmup_epochs : int
Number of epochs to warm up the search network.
fix_head : bool
``True`` to freeze the parameters of the auxiliary heads; otherwise they are trained.
epochs : int
Number of epochs planned for training.
steps_per_epoch : int
Steps of one epoch.
loss_alpha : float
Coefficient balancing the classification loss against the interactive (distillation) loss.
loss_T : float
Temperature of the interactive (KL) loss.
distributed : bool
``True`` if using distributed training, else non-distributed training.
log_frequency : int
Number of steps between two logging messages.
grad_clip : float
Gradient clipping for weights.
interactive_type : string
``kl`` or ``smoothl1``.
output_path : string
Directory where checkpoints and exported genotypes are written.
w_lr : float
Learning rate of the search network parameters.
w_momentum : float
Momentum of the search and the evaluation network.
w_weight_decay : float
Weight decay of the search and the evaluation network parameters.
alpha_lr : float
Learning rate of the architecture parameters.
alpha_weight_decay : float
Weight decay of the architecture parameters.
nasnet_lr : float
Learning rate of the evaluation network parameters.
local_rank : int
Rank of the current process in distributed training (0 is the main process).
share_module : bool
``True`` to share the stem and auxiliary heads between the search and evaluation networks.
"""
if logger is None:
logger = logging.getLogger(__name__)
train_loader, valid_loader = loaders
train_sampler, valid_sampler = samplers
self.train_loader = CyclicIterator(train_loader, train_sampler, distributed)
self.valid_loader = CyclicIterator(valid_loader, valid_sampler, distributed)
self.regular_coeff = regular_coeff
self.regular_ratio = regular_ratio
self.warmup_epochs = warmup_epochs
self.fix_head = fix_head
self.epochs = epochs
self.steps_per_epoch = steps_per_epoch
if self.steps_per_epoch is None:
self.steps_per_epoch = min(len(self.train_loader), len(self.valid_loader))
self.loss_alpha = loss_alpha
self.grad_clip = grad_clip
if interactive_type == "kl":
self.interactive_loss = InteractiveKLLoss(loss_T)
elif interactive_type == "smoothl1":
self.interactive_loss = nn.SmoothL1Loss()
self.loss_T = loss_T
self.distributed = distributed
self.log_frequency = log_frequency
self.main_proc = not distributed or local_rank == 0
self.logger = logger
self.checkpoint_dir = output_path
if self.main_proc:
os.makedirs(self.checkpoint_dir, exist_ok=True)
if distributed:
torch.distributed.barrier()
self.model_small = model_small
self.model_large = model_large
if self.fix_head:
for param in self.model_small.aux_head.parameters():
param.requires_grad = False
for param in self.model_large.aux_head.parameters():
param.requires_grad = False
self.mutator_small = RegularizedDartsMutator(self.model_small).cuda()
self.mutator_large = DartsDiscreteMutator(self.model_large, self.mutator_small).cuda()
self.criterion = criterion
self.optimizer_small = torch.optim.SGD(self.model_small.parameters(), w_lr,
momentum=w_momentum, weight_decay=w_weight_decay)
self.optimizer_large = torch.optim.SGD(self.model_large.parameters(), nasnet_lr,
momentum=w_momentum, weight_decay=w_weight_decay)
self.optimizer_alpha = torch.optim.Adam(self.mutator_small.parameters(), alpha_lr,
betas=(0.5, 0.999), weight_decay=alpha_weight_decay)
if distributed:
apex.parallel.convert_syncbn_model(self.model_small)
apex.parallel.convert_syncbn_model(self.model_large)
self.model_small = DistributedDataParallel(self.model_small, delay_allreduce=True)
self.model_large = DistributedDataParallel(self.model_large, delay_allreduce=True)
self.mutator_small = RegularizedMutatorParallel(self.mutator_small, delay_allreduce=True)
if share_module:
self.model_small.callback_queued = True
self.model_large.callback_queued = True
# the large mutator is never optimized, so it does not need to be parallelized
def _warmup(self, phase, epoch):
assert phase in [PHASE_SMALL, PHASE_LARGE]
if phase == PHASE_SMALL:
model, optimizer = self.model_small, self.optimizer_small
elif phase == PHASE_LARGE:
model, optimizer = self.model_large, self.optimizer_large
model.train()
meters = AverageMeterGroup()
for step in range(self.steps_per_epoch):
x, y = next(self.train_loader)
x, y = x.cuda(), y.cuda()
optimizer.zero_grad()
logits_main, _ = model(x)
loss = self.criterion(logits_main, y)
loss.backward()
self._clip_grad_norm(model)
optimizer.step()
prec1, prec5 = accuracy(logits_main, y, topk=(1, 5))
metrics = {"prec1": prec1, "prec5": prec5, "loss": loss}
metrics = reduce_metrics(metrics, self.distributed)
meters.update(metrics)
if self.main_proc and (step % self.log_frequency == 0 or step + 1 == self.steps_per_epoch):
self.logger.info("Epoch [%d/%d] Step [%d/%d] (%s) %s", epoch + 1, self.epochs,
step + 1, self.steps_per_epoch, phase, meters)
def _clip_grad_norm(self, model):
if isinstance(model, DistributedDataParallel):
nn.utils.clip_grad_norm_(model.module.parameters(), self.grad_clip)
else:
nn.utils.clip_grad_norm_(model.parameters(), self.grad_clip)
def _reset_nan(self, parameters):
with torch.no_grad():
for param in parameters:
for i, p in enumerate(param):
if p != p: # equivalent to `isnan(p)`
param[i] = float("-inf")
def _joint_train(self, epoch):
self.model_large.train()
self.model_small.train()
meters = AverageMeterGroup()
for step in range(self.steps_per_epoch):
trn_x, trn_y = next(self.train_loader)
val_x, val_y = next(self.valid_loader)
trn_x, trn_y = trn_x.cuda(), trn_y.cuda()
val_x, val_y = val_x.cuda(), val_y.cuda()
# step 1. optimize architecture
self.optimizer_alpha.zero_grad()
self.optimizer_large.zero_grad()
reg_decay = max(self.regular_coeff * (1 - float(epoch - self.warmup_epochs) / (
(self.epochs - self.warmup_epochs) * self.regular_ratio)), 0)
loss_regular = self.mutator_small.reset_with_loss()
if loss_regular:
loss_regular *= reg_decay
logits_search, emsemble_logits_search = self.model_small(val_x)
logits_main, emsemble_logits_main = self.model_large(val_x)
loss_cls = (self.criterion(logits_search, val_y) + self.criterion(logits_main, val_y)) / self.loss_alpha
loss_interactive = self.interactive_loss(emsemble_logits_search, emsemble_logits_main) * (self.loss_T ** 2) * self.loss_alpha
loss = loss_cls + loss_interactive + loss_regular
loss.backward()
self._clip_grad_norm(self.model_large)
self.optimizer_large.step()
self.optimizer_alpha.step()
# NOTE: need to call here `self._reset_nan(self.mutator_small.parameters())` if `cut_choices`
# step 2. optimize op weights
self.optimizer_small.zero_grad()
with torch.no_grad():
# resample architecture since parameters have been changed
self.mutator_small.reset_with_loss()
logits_search_train, _ = self.model_small(trn_x)
loss_weight = self.criterion(logits_search_train, trn_y)
loss_weight.backward()
self._clip_grad_norm(self.model_small)
self.optimizer_small.step()
metrics = {"loss_cls": loss_cls, "loss_interactive": loss_interactive,
"loss_regular": loss_regular, "loss_weight": loss_weight}
metrics = reduce_metrics(metrics, self.distributed)
meters.update(metrics)
if self.main_proc and (step % self.log_frequency == 0 or step + 1 == self.steps_per_epoch):
self.logger.info("Epoch [%d/%d] Step [%d/%d] (joint) %s", epoch + 1, self.epochs,
step + 1, self.steps_per_epoch, meters)
def train(self):
for epoch in range(self.epochs):
if epoch < self.warmup_epochs:
with torch.no_grad(): # otherwise grads will be retained on the architecture params
self.mutator_small.reset_with_loss()
self._warmup(PHASE_SMALL, epoch)
else:
with torch.no_grad():
self.mutator_large.reset()
self._warmup(PHASE_LARGE, epoch)
self._joint_train(epoch)
self.export(os.path.join(self.checkpoint_dir, "epoch_{:02d}.json".format(epoch)),
os.path.join(self.checkpoint_dir, "epoch_{:02d}.genotypes".format(epoch)))
def export(self, file, genotype_file):
if self.main_proc:
mutator_export, genotypes = self.mutator_small.export(self.logger)
with open(file, "w") as f:
json.dump(mutator_export, f, indent=2, sort_keys=True, cls=TorchTensorEncoder)
with open(genotype_file, "w") as f:
f.write(str(genotypes))
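The coefficient applied to the regularization loss in _joint_train decays linearly after the warmup epochs and is clamped at zero once the window controlled by regular_ratio is used up. A standalone sketch of that schedule, using the constructor defaults:

    def reg_decay(epoch, regular_coeff=5, regular_ratio=0.2, warmup_epochs=2, epochs=32):
        # Same expression as in CdartsTrainer._joint_train: linear decay after warmup, clamped at 0.
        return max(regular_coeff * (1 - float(epoch - warmup_epochs) /
                                    ((epochs - warmup_epochs) * regular_ratio)), 0)

    for epoch in [2, 4, 6, 8, 16]:
        print(epoch, round(reg_decay(epoch), 3))  # 5.0, 3.333, 1.667, 0.0, 0.0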
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import json
import os
import torch
import torch.distributed as dist
class CyclicIterator:
def __init__(self, loader, sampler, distributed):
self.loader = loader
self.sampler = sampler
self.epoch = 0
self.distributed = distributed
self._next_epoch()
def _next_epoch(self):
if self.distributed:
self.sampler.set_epoch(self.epoch)
self.iterator = iter(self.loader)
self.epoch += 1
def __len__(self):
return len(self.loader)
def __iter__(self):
return self
def __next__(self):
try:
return next(self.iterator)
except StopIteration:
self._next_epoch()
return next(self.iterator)
class TorchTensorEncoder(json.JSONEncoder):
def default(self, o): # pylint: disable=method-hidden
if isinstance(o, torch.Tensor):
return o.tolist()
return super().default(o)
def accuracy(output, target, topk=(1,)):
""" Computes the precision@k for the specified values of k """
maxk = max(topk)
batch_size = target.size(0)
_, pred = output.topk(maxk, 1, True, True)
pred = pred.t()
# one-hot case
if target.ndimension() > 1:
target = target.max(1)[1]
correct = pred.eq(target.view(1, -1).expand_as(pred))
res = []
for k in topk:
correct_k = correct[:k].view(-1).float().sum(0)
res.append(correct_k.mul_(1.0 / batch_size))
return res
def reduce_tensor(tensor):
rt = tensor.clone()
dist.all_reduce(rt, op=dist.ReduceOp.SUM)
rt /= float(os.environ["WORLD_SIZE"])
return rt
def reduce_metrics(metrics, distributed=False):
if distributed:
return {k: reduce_tensor(v).item() for k, v in metrics.items()}
return {k: v.item() for k, v in metrics.items()}
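A quick usage sketch for the accuracy helper above, with toy logits (it assumes the function is in scope, e.g. pasted into the same session):

    import torch

    logits = torch.tensor([[2.0, 1.0, 0.1],
                           [0.2, 2.5, 0.3],
                           [1.5, 0.2, 0.1],
                           [0.1, 0.2, 3.0]])
    target = torch.tensor([0, 1, 2, 2])

    top1, top2 = accuracy(logits, target, topk=(1, 2))
    print(top1.item(), top2.item())  # 0.75 0.75 -- three of four predictions are correct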
@@ -6,10 +6,10 @@ import sys
 import json
 import time
 import subprocess
-import json_tricks
 
 from ..common import init_logger
 from ..env_vars import trial_env_vars
+from ..utils import to_json
 
 _sysdir = trial_env_vars.NNI_SYS_DIR
 if not os.path.exists(os.path.join(_sysdir, '.nni')):
@@ -30,7 +30,7 @@ _multiphase = trial_env_vars.MULTI_PHASE
 _param_index = 0
 
 def request_next_parameter():
-    metric = json_tricks.dumps({
+    metric = to_json({
        'trial_job_id': trial_env_vars.NNI_TRIAL_JOB_ID,
        'type': 'REQUEST_PARAMETER',
        'sequence': 0,
...
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT license.
 
-import json_tricks
+from .utils import to_json
 from .env_vars import trial_env_vars
 from . import platform
@@ -110,7 +109,7 @@ def report_intermediate_result(metric):
     global _intermediate_seq
     assert _params or trial_env_vars.NNI_PLATFORM is None, \
         'nni.get_next_parameter() needs to be called before report_intermediate_result'
-    metric = json_tricks.dumps({
+    metric = to_json({
         'parameter_id': _params['parameter_id'] if _params else None,
         'trial_job_id': trial_env_vars.NNI_TRIAL_JOB_ID,
         'type': 'PERIODICAL',
@@ -120,7 +119,6 @@ def report_intermediate_result(metric):
     _intermediate_seq += 1
     platform.send_metric(metric)
 
-
 def report_final_result(metric):
     """
     Reports final result to NNI.
@@ -132,7 +130,7 @@ def report_final_result(metric):
     """
     assert _params or trial_env_vars.NNI_PLATFORM is None, \
         'nni.get_next_parameter() needs to be called before report_final_result'
-    metric = json_tricks.dumps({
+    metric = to_json({
         'parameter_id': _params['parameter_id'] if _params else None,
         'trial_job_id': trial_env_vars.NNI_TRIAL_JOB_ID,
         'type': 'FINAL',
...
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT license.
 
+"""
+utils.py
+"""
+
 import os
+import functools
 from enum import Enum, unique
 
+import json_tricks
+
 from .common import init_logger
 from .env_vars import dispatcher_env_vars
 
+to_json = functools.partial(json_tricks.dumps, allow_nan=True)
+
 @unique
 class OptimizeMode(Enum):
     """Optimize Mode class
...
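to_json wraps json_tricks.dumps with allow_nan=True, so NaN and Infinity metrics are serialized as bare NaN / Infinity tokens instead of raising. Strict JSON parsers (JavaScript's JSON.parse among them) reject those tokens, which is why the TypeScript side of this commit switches to JSON5.parse. A quick illustration with the standard library (json_tricks behaves analogously here):

    import json

    metric = {"type": "FINAL", "value": float("nan")}

    # A strict dump refuses NaN outright.
    try:
        json.dumps(metric, allow_nan=False)
    except ValueError as err:
        print("strict dump fails:", err)

    # With allow_nan=True the output contains a bare NaN token, which is not
    # valid strict JSON and needs a lenient parser (e.g. JSON5) on the other end.
    print(json.dumps(metric, allow_nan=True))  # {"type": "FINAL", "value": NaN}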
 {
   "choice_str": {
     "_type": "choice",
-    "_value": ["cat", "dog", "elephant", "cow", "sheep", "panda"],
-    "fail": ["metis", "gp"]
+    "_value": ["cat", "dog", "elephant", "cow", "sheep", "panda"]
   },
   "choice_int": {
     "_type": "choice",
@@ -10,8 +9,7 @@
   },
   "choice_mixed": {
     "_type": "choice",
-    "_value": [0.3, "cat", 1, null],
-    "fail": ["metis", "gp"]
+    "_value": [0.3, "cat", 1, null]
   },
   "choice_float": {
     "_type": "choice",
...