Commit 1fd61e51 authored by v-liguo's avatar v-liguo
Browse files

Merge branch 'master' of https://github.com/Microsoft/nni into dev-fix-hyper-bug

merge
parents bf8d6dc6 8016c710
...@@ -8,6 +8,8 @@ Trial ...@@ -8,6 +8,8 @@ Trial
.. autofunction:: nni.get_current_parameter .. autofunction:: nni.get_current_parameter
.. autofunction:: nni.report_intermediate_result .. autofunction:: nni.report_intermediate_result
.. autofunction:: nni.report_final_result .. autofunction:: nni.report_final_result
.. autofunction:: nni.get_experiment_id
.. autofunction:: nni.get_trial_id
.. autofunction:: nni.get_sequence_id .. autofunction:: nni.get_sequence_id
......
...@@ -201,6 +201,10 @@ class FrameworkControllerTrainingService extends KubernetesTrainingService imple ...@@ -201,6 +201,10 @@ class FrameworkControllerTrainingService extends KubernetesTrainingService imple
throw new Error('Kubeflow Cluster config is not initialized'); throw new Error('Kubeflow Cluster config is not initialized');
} }
if (this.fcTrialConfig === undefined) {
throw new Error('Kubeflow trial config is not initialized');
}
let trialJobOutputUrl: string = ''; let trialJobOutputUrl: string = '';
if (this.fcClusterConfig.storageType === 'azureStorage') { if (this.fcClusterConfig.storageType === 'azureStorage') {
...@@ -208,12 +212,15 @@ class FrameworkControllerTrainingService extends KubernetesTrainingService imple ...@@ -208,12 +212,15 @@ class FrameworkControllerTrainingService extends KubernetesTrainingService imple
throw new Error('azureStorageClient is not initialized'); throw new Error('azureStorageClient is not initialized');
} }
try { try {
//upload local files to azure storage //upload local files, including scripts for running the trial and configuration (e.g., hyperparameters) for the trial, to azure storage
await AzureStorageClientUtility.uploadDirectory( await AzureStorageClientUtility.uploadDirectory(
this.azureStorageClient, `nni/${getExperimentId()}/${trialJobId}`, this.azureStorageShare, `${trialLocalTempFolder}`); this.azureStorageClient, `nni/${getExperimentId()}/${trialJobId}`, this.azureStorageShare, `${trialLocalTempFolder}`);
//upload code files to azure storage
await AzureStorageClientUtility.uploadDirectory(
this.azureStorageClient, `nni/${getExperimentId()}/${trialJobId}`, this.azureStorageShare, `${this.fcTrialConfig.codeDir}`);
trialJobOutputUrl = `https://${this.azureStorageAccountName}.file.core.windows.net/\ trialJobOutputUrl = `https://${this.azureStorageAccountName}.file.core.windows.net/` +
${this.azureStorageShare}/${path.join('nni', getExperimentId(), trialJobId, 'output')}`; `${this.azureStorageShare}/${path.join('nni', getExperimentId(), trialJobId, 'output')}`;
} catch (error) { } catch (error) {
this.log.error(error); this.log.error(error);
...@@ -226,7 +233,8 @@ class FrameworkControllerTrainingService extends KubernetesTrainingService imple ...@@ -226,7 +233,8 @@ class FrameworkControllerTrainingService extends KubernetesTrainingService imple
await cpp.exec(`mkdir -p ${this.trialLocalNFSTempFolder}/nni/${getExperimentId()}/${trialJobId}`); await cpp.exec(`mkdir -p ${this.trialLocalNFSTempFolder}/nni/${getExperimentId()}/${trialJobId}`);
// Copy code files from local dir to NFS mounted dir // Copy code files from local dir to NFS mounted dir
await cpp.exec(`cp -r ${trialLocalTempFolder}/* ${this.trialLocalNFSTempFolder}/nni/${getExperimentId()}/${trialJobId}/.`); await cpp.exec(`cp -r ${trialLocalTempFolder}/* ${this.trialLocalNFSTempFolder}/nni/${getExperimentId()}/${trialJobId}/.`);
// Copy codeDir to NFS mounted dir
await cpp.exec(`cp -r ${this.fcTrialConfig.codeDir}/* ${this.trialLocalNFSTempFolder}/nni/${getExperimentId()}/${trialJobId}/.`);
const nfsConfig: NFSConfig = nfsFrameworkControllerClusterConfig.nfs; const nfsConfig: NFSConfig = nfsFrameworkControllerClusterConfig.nfs;
trialJobOutputUrl = `nfs://${nfsConfig.server}:${path.join(nfsConfig.path, 'nni', getExperimentId(), trialJobId, 'output')}`; trialJobOutputUrl = `nfs://${nfsConfig.server}:${path.join(nfsConfig.path, 'nni', getExperimentId(), trialJobId, 'output')}`;
} }
...@@ -257,13 +265,12 @@ class FrameworkControllerTrainingService extends KubernetesTrainingService imple ...@@ -257,13 +265,12 @@ class FrameworkControllerTrainingService extends KubernetesTrainingService imple
throw new Error('frameworkcontroller trial config is not initialized'); throw new Error('frameworkcontroller trial config is not initialized');
} }
await cpp.exec(`mkdir -p ${path.dirname(trialLocalTempFolder)}`); await cpp.exec(`mkdir -p ${trialLocalTempFolder}`);
await cpp.exec(`cp -r ${this.fcTrialConfig.codeDir} ${trialLocalTempFolder}`);
const installScriptContent : string = CONTAINER_INSTALL_NNI_SHELL_FORMAT; const installScriptContent : string = CONTAINER_INSTALL_NNI_SHELL_FORMAT;
// Write NNI installation file to local tmp files // Write NNI installation file to local tmp files
await fs.promises.writeFile(path.join(trialLocalTempFolder, 'install_nni.sh'), installScriptContent, { encoding: 'utf8' }); await fs.promises.writeFile(path.join(trialLocalTempFolder, 'install_nni.sh'), installScriptContent, { encoding: 'utf8' });
// Create tmp trial working folder locally. // Create tmp trial working folder locally.
await cpp.exec(`mkdir -p ${trialLocalTempFolder}`);
for (const taskRole of this.fcTrialConfig.taskRoles) { for (const taskRole of this.fcTrialConfig.taskRoles) {
const runScriptContent: string = const runScriptContent: string =
......
...@@ -201,6 +201,10 @@ class KubeflowTrainingService extends KubernetesTrainingService implements Kuber ...@@ -201,6 +201,10 @@ class KubeflowTrainingService extends KubernetesTrainingService implements Kuber
throw new Error('Kubeflow Cluster config is not initialized'); throw new Error('Kubeflow Cluster config is not initialized');
} }
if (this.kubeflowTrialConfig === undefined) {
throw new Error('Kubeflow Trial config is not initialized');
}
let trialJobOutputUrl: string = ''; let trialJobOutputUrl: string = '';
assert(this.kubeflowClusterConfig.storage === undefined assert(this.kubeflowClusterConfig.storage === undefined
...@@ -212,13 +216,17 @@ class KubeflowTrainingService extends KubernetesTrainingService implements Kuber ...@@ -212,13 +216,17 @@ class KubeflowTrainingService extends KubernetesTrainingService implements Kuber
throw new Error('azureStorageClient is not initialized'); throw new Error('azureStorageClient is not initialized');
} }
try { try {
//upload local files to azure storage //upload local files, including scripts for running the trial and configuration (e.g., hyperparameters) for the trial, to azure storage
await AzureStorageClientUtility.uploadDirectory(this.azureStorageClient, await AzureStorageClientUtility.uploadDirectory(this.azureStorageClient,
`nni/${getExperimentId()}/${trialJobId}`, this.azureStorageShare, `nni/${getExperimentId()}/${trialJobId}`, this.azureStorageShare,
`${trialLocalTempFolder}`); `${trialLocalTempFolder}`);
//upload code files to azure storage
await AzureStorageClientUtility.uploadDirectory(this.azureStorageClient,
`nni/${getExperimentId()}/${trialJobId}`, this.azureStorageShare,
`${this.kubeflowTrialConfig.codeDir}`);
trialJobOutputUrl = `https://${this.azureStorageAccountName}.file.core.windows.net/${this.azureStorageShare}\ trialJobOutputUrl = `https://${this.azureStorageAccountName}.file.core.windows.net/${this.azureStorageShare}` +
/${path.join('nni', getExperimentId(), trialJobId, 'output')}`; `/${path.join('nni', getExperimentId(), trialJobId, 'output')}`;
} catch (error) { } catch (error) {
this.log.error(error); this.log.error(error);
...@@ -228,9 +236,10 @@ class KubeflowTrainingService extends KubernetesTrainingService implements Kuber ...@@ -228,9 +236,10 @@ class KubeflowTrainingService extends KubernetesTrainingService implements Kuber
const nfsKubeflowClusterConfig: KubeflowClusterConfigNFS = <KubeflowClusterConfigNFS>this.kubeflowClusterConfig; const nfsKubeflowClusterConfig: KubeflowClusterConfigNFS = <KubeflowClusterConfigNFS>this.kubeflowClusterConfig;
// Creat work dir for current trial in NFS directory // Creat work dir for current trial in NFS directory
await cpp.exec(`mkdir -p ${this.trialLocalNFSTempFolder}/nni/${getExperimentId()}/${trialJobId}`); await cpp.exec(`mkdir -p ${this.trialLocalNFSTempFolder}/nni/${getExperimentId()}/${trialJobId}`);
// Copy code files from local dir to NFS mounted dir // Copy script files from local dir to NFS mounted dir
await cpp.exec(`cp -r ${trialLocalTempFolder}/* ${this.trialLocalNFSTempFolder}/nni/${getExperimentId()}/${trialJobId}/.`); await cpp.exec(`cp -r ${trialLocalTempFolder}/* ${this.trialLocalNFSTempFolder}/nni/${getExperimentId()}/${trialJobId}/.`);
// Copy codeDir to NFS mounted dir
await cpp.exec(`cp -r ${this.kubeflowTrialConfig.codeDir}/* ${this.trialLocalNFSTempFolder}/nni/${getExperimentId()}/${trialJobId}/.`);
const nfsConfig: NFSConfig = nfsKubeflowClusterConfig.nfs; const nfsConfig: NFSConfig = nfsKubeflowClusterConfig.nfs;
trialJobOutputUrl = `nfs://${nfsConfig.server}:${path.join(nfsConfig.path, 'nni', getExperimentId(), trialJobId, 'output')}`; trialJobOutputUrl = `nfs://${nfsConfig.server}:${path.join(nfsConfig.path, 'nni', getExperimentId(), trialJobId, 'output')}`;
} }
...@@ -255,13 +264,10 @@ class KubeflowTrainingService extends KubernetesTrainingService implements Kuber ...@@ -255,13 +264,10 @@ class KubeflowTrainingService extends KubernetesTrainingService implements Kuber
} }
//create tmp trial working folder locally. //create tmp trial working folder locally.
await cpp.exec(`mkdir -p ${path.dirname(trialLocalTempFolder)}`); await cpp.exec(`mkdir -p ${trialLocalTempFolder}`);
await cpp.exec(`cp -r ${kubeflowTrialConfig.codeDir} ${trialLocalTempFolder}`);
const runScriptContent : string = CONTAINER_INSTALL_NNI_SHELL_FORMAT; const runScriptContent : string = CONTAINER_INSTALL_NNI_SHELL_FORMAT;
// Write NNI installation file to local tmp files // Write NNI installation file to local tmp files
await fs.promises.writeFile(path.join(trialLocalTempFolder, 'install_nni.sh'), runScriptContent, { encoding: 'utf8' }); await fs.promises.writeFile(path.join(trialLocalTempFolder, 'install_nni.sh'), runScriptContent, { encoding: 'utf8' });
// Create tmp trial working folder locally.
await cpp.exec(`mkdir -p ${trialLocalTempFolder}`);
// Write worker file content run_worker.sh to local tmp folders // Write worker file content run_worker.sh to local tmp folders
if (kubeflowTrialConfig.worker !== undefined) { if (kubeflowTrialConfig.worker !== undefined) {
......
...@@ -26,7 +26,7 @@ import * as path from 'path'; ...@@ -26,7 +26,7 @@ import * as path from 'path';
import * as ts from 'tail-stream'; import * as ts from 'tail-stream';
import * as tkill from 'tree-kill'; import * as tkill from 'tree-kill';
import { NNIError, NNIErrorNames } from '../../common/errors'; import { NNIError, NNIErrorNames } from '../../common/errors';
import { getInitTrialSequenceId } from '../../common/experimentStartupInfo'; import { getExperimentId, getInitTrialSequenceId } from '../../common/experimentStartupInfo';
import { getLogger, Logger } from '../../common/log'; import { getLogger, Logger } from '../../common/log';
import { import {
HostJobApplicationForm, HyperParameters, JobApplicationForm, TrainingService, TrialJobApplicationForm, HostJobApplicationForm, HyperParameters, JobApplicationForm, TrainingService, TrialJobApplicationForm,
...@@ -126,6 +126,7 @@ class LocalTrainingService implements TrainingService { ...@@ -126,6 +126,7 @@ class LocalTrainingService implements TrainingService {
private stopping: boolean; private stopping: boolean;
private rootDir!: string; private rootDir!: string;
private trialSequenceId: number; private trialSequenceId: number;
private readonly experimentId! : string;
private gpuScheduler!: GPUScheduler; private gpuScheduler!: GPUScheduler;
private readonly occupiedGpuIndexNumMap: Map<number, number>; private readonly occupiedGpuIndexNumMap: Map<number, number>;
private designatedGpuIndices!: Set<number>; private designatedGpuIndices!: Set<number>;
...@@ -145,6 +146,7 @@ class LocalTrainingService implements TrainingService { ...@@ -145,6 +146,7 @@ class LocalTrainingService implements TrainingService {
this.stopping = false; this.stopping = false;
this.log = getLogger(); this.log = getLogger();
this.trialSequenceId = -1; this.trialSequenceId = -1;
this.experimentId = getExperimentId();
this.jobStreamMap = new Map<string, ts.Stream>(); this.jobStreamMap = new Map<string, ts.Stream>();
this.log.info('Construct local machine training service.'); this.log.info('Construct local machine training service.');
this.occupiedGpuIndexNumMap = new Map<number, number>(); this.occupiedGpuIndexNumMap = new Map<number, number>();
...@@ -400,6 +402,7 @@ class LocalTrainingService implements TrainingService { ...@@ -400,6 +402,7 @@ class LocalTrainingService implements TrainingService {
resource: { gpuIndices: number[] }): { key: string; value: string }[] { resource: { gpuIndices: number[] }): { key: string; value: string }[] {
const envVariables: { key: string; value: string }[] = [ const envVariables: { key: string; value: string }[] = [
{ key: 'NNI_PLATFORM', value: 'local' }, { key: 'NNI_PLATFORM', value: 'local' },
{ key: 'NNI_EXP_ID', value: this.experimentId },
{ key: 'NNI_SYS_DIR', value: trialJobDetail.workingDirectory }, { key: 'NNI_SYS_DIR', value: trialJobDetail.workingDirectory },
{ key: 'NNI_TRIAL_JOB_ID', value: trialJobDetail.id }, { key: 'NNI_TRIAL_JOB_ID', value: trialJobDetail.id },
{ key: 'NNI_OUTPUT_DIR', value: trialJobDetail.workingDirectory }, { key: 'NNI_OUTPUT_DIR', value: trialJobDetail.workingDirectory },
......
...@@ -427,7 +427,7 @@ class BOHB(MsgDispatcherBase): ...@@ -427,7 +427,7 @@ class BOHB(MsgDispatcherBase):
send(CommandType.NoMoreTrialJobs, json_tricks.dumps(ret)) send(CommandType.NoMoreTrialJobs, json_tricks.dumps(ret))
return None return None
assert self.generated_hyper_configs assert self.generated_hyper_configs
params = self.generated_hyper_configs.pop() params = self.generated_hyper_configs.pop(0)
ret = { ret = {
'parameter_id': params[0], 'parameter_id': params[0],
'parameter_source': 'algorithm', 'parameter_source': 'algorithm',
......
...@@ -24,6 +24,7 @@ from collections import namedtuple ...@@ -24,6 +24,7 @@ from collections import namedtuple
_trial_env_var_names = [ _trial_env_var_names = [
'NNI_PLATFORM', 'NNI_PLATFORM',
'NNI_EXP_ID',
'NNI_TRIAL_JOB_ID', 'NNI_TRIAL_JOB_ID',
'NNI_SYS_DIR', 'NNI_SYS_DIR',
'NNI_OUTPUT_DIR', 'NNI_OUTPUT_DIR',
......
...@@ -340,7 +340,7 @@ class Hyperband(MsgDispatcherBase): ...@@ -340,7 +340,7 @@ class Hyperband(MsgDispatcherBase):
self.curr_s -= 1 self.curr_s -= 1
assert self.generated_hyper_configs assert self.generated_hyper_configs
params = self.generated_hyper_configs.pop() params = self.generated_hyper_configs.pop(0)
ret = { ret = {
'parameter_id': params[0], 'parameter_id': params[0],
'parameter_source': 'algorithm', 'parameter_source': 'algorithm',
......
...@@ -94,5 +94,11 @@ def send_metric(string): ...@@ -94,5 +94,11 @@ def send_metric(string):
else: else:
subprocess.run(['touch', _metric_file.name], check = True) subprocess.run(['touch', _metric_file.name], check = True)
def get_experiment_id():
return trial_env_vars.NNI_EXP_ID
def get_trial_id():
return trial_env_vars.NNI_TRIAL_JOB_ID
def get_sequence_id(): def get_sequence_id():
return trial_env_vars.NNI_TRIAL_SEQ_ID return trial_env_vars.NNI_TRIAL_SEQ_ID
...@@ -25,6 +25,12 @@ import json_tricks ...@@ -25,6 +25,12 @@ import json_tricks
def get_next_parameter(): def get_next_parameter():
pass pass
def get_experiment_id():
pass
def get_trial_id():
pass
def get_sequence_id(): def get_sequence_id():
pass pass
......
...@@ -32,6 +32,12 @@ _last_metric = None ...@@ -32,6 +32,12 @@ _last_metric = None
def get_next_parameter(): def get_next_parameter():
return _params return _params
def get_experiment_id():
return 'fakeidex'
def get_trial_id():
return 'fakeidtr'
def get_sequence_id(): def get_sequence_id():
return 0 return 0
......
...@@ -30,11 +30,15 @@ __all__ = [ ...@@ -30,11 +30,15 @@ __all__ = [
'get_current_parameter', 'get_current_parameter',
'report_intermediate_result', 'report_intermediate_result',
'report_final_result', 'report_final_result',
'get_experiment_id',
'get_trial_id',
'get_sequence_id' 'get_sequence_id'
] ]
_params = None _params = None
_experiment_id = platform.get_experiment_id()
_trial_id = platform.get_trial_id()
_sequence_id = platform.get_sequence_id() _sequence_id = platform.get_sequence_id()
...@@ -52,6 +56,12 @@ def get_current_parameter(tag): ...@@ -52,6 +56,12 @@ def get_current_parameter(tag):
return None return None
return _params['parameters'][tag] return _params['parameters'][tag]
def get_experiment_id():
return _experiment_id
def get_trial_id():
return _trial_id
def get_sequence_id(): def get_sequence_id():
return _sequence_id return _sequence_id
......
...@@ -38,6 +38,12 @@ class TrialTestCase(TestCase): ...@@ -38,6 +38,12 @@ class TrialTestCase(TestCase):
nni.get_next_parameter() nni.get_next_parameter()
self.assertEqual(nni.get_current_parameter('x'), 123) self.assertEqual(nni.get_current_parameter('x'), 123)
def test_get_experiment_id(self):
self.assertEqual(nni.get_experiment_id(), 'fakeidex')
def test_get_trial_id(self):
self.assertEqual(nni.get_trial_id(), 'fakeidtr')
def test_get_sequence_id(self): def test_get_sequence_id(self):
self.assertEqual(nni.get_sequence_id(), 0) self.assertEqual(nni.get_sequence_id(), 0)
......
...@@ -104,7 +104,8 @@ class TrialsDetail extends React.Component<TrialsDetailProps, TrialDetailState> ...@@ -104,7 +104,8 @@ class TrialsDetail extends React.Component<TrialsDetailProps, TrialDetailState>
Object.keys(trialJobs).map(item => { Object.keys(trialJobs).map(item => {
let desc: Parameters = { let desc: Parameters = {
parameters: {}, parameters: {},
intermediate: [] intermediate: [],
progress: 1
}; };
let duration = 0; let duration = 0;
const id = trialJobs[item].id !== undefined const id = trialJobs[item].id !== undefined
...@@ -125,6 +126,7 @@ class TrialsDetail extends React.Component<TrialsDetailProps, TrialDetailState> ...@@ -125,6 +126,7 @@ class TrialsDetail extends React.Component<TrialsDetailProps, TrialDetailState>
const tempHyper = trialJobs[item].hyperParameters; const tempHyper = trialJobs[item].hyperParameters;
if (tempHyper !== undefined) { if (tempHyper !== undefined) {
const getPara = JSON.parse(tempHyper[tempHyper.length - 1]).parameters; const getPara = JSON.parse(tempHyper[tempHyper.length - 1]).parameters;
desc.progress = tempHyper.length;
if (typeof getPara === 'string') { if (typeof getPara === 'string') {
desc.parameters = JSON.parse(getPara); desc.parameters = JSON.parse(getPara);
} else { } else {
......
...@@ -97,6 +97,8 @@ class OpenRow extends React.Component<OpenRowProps, OpenRowState> { ...@@ -97,6 +97,8 @@ class OpenRow extends React.Component<OpenRowProps, OpenRowState> {
<br /> <br />
For the entire parameter set, please refer to the following " For the entire parameter set, please refer to the following "
<a href={trialink} target="_blank">{trialink}</a>". <a href={trialink} target="_blank">{trialink}</a>".
<br/>
Current Phase: {record.description.progress}.
</Row> </Row>
: :
<div /> <div />
......
...@@ -27,6 +27,7 @@ interface Parameters { ...@@ -27,6 +27,7 @@ interface Parameters {
parameters: ErrorParameter; parameters: ErrorParameter;
logPath?: string; logPath?: string;
intermediate: Array<number>; intermediate: Array<number>;
progress?: number;
} }
interface Experiment { interface Experiment {
......
...@@ -84,10 +84,12 @@ tuner_schema_dict = { ...@@ -84,10 +84,12 @@ tuner_schema_dict = {
'optimize_mode': setChoice('optimize_mode', 'maximize', 'minimize'), 'optimize_mode': setChoice('optimize_mode', 'maximize', 'minimize'),
Optional('population_size'): setNumberRange('population_size', int, 0, 99999), Optional('population_size'): setNumberRange('population_size', int, 0, 99999),
}, },
Optional('includeIntermediateResults'): setType('includeIntermediateResults', bool),
Optional('gpuNum'): setNumberRange('gpuNum', int, 0, 99999), Optional('gpuNum'): setNumberRange('gpuNum', int, 0, 99999),
}, },
('BatchTuner', 'GridSearch', 'Random'): { ('BatchTuner', 'GridSearch', 'Random'): {
'builtinTunerName': setChoice('builtinTunerName', 'BatchTuner', 'GridSearch', 'Random'), 'builtinTunerName': setChoice('builtinTunerName', 'BatchTuner', 'GridSearch', 'Random'),
Optional('includeIntermediateResults'): setType('includeIntermediateResults', bool),
Optional('gpuNum'): setNumberRange('gpuNum', int, 0, 99999), Optional('gpuNum'): setNumberRange('gpuNum', int, 0, 99999),
}, },
'NetworkMorphism': { 'NetworkMorphism': {
...@@ -99,6 +101,7 @@ tuner_schema_dict = { ...@@ -99,6 +101,7 @@ tuner_schema_dict = {
Optional('input_channel'): setType('input_channel', int), Optional('input_channel'): setType('input_channel', int),
Optional('n_output_node'): setType('n_output_node', int), Optional('n_output_node'): setType('n_output_node', int),
}, },
Optional('includeIntermediateResults'): setType('includeIntermediateResults', bool),
Optional('gpuNum'): setNumberRange('gpuNum', int, 0, 99999), Optional('gpuNum'): setNumberRange('gpuNum', int, 0, 99999),
}, },
'MetisTuner': { 'MetisTuner': {
...@@ -110,6 +113,7 @@ tuner_schema_dict = { ...@@ -110,6 +113,7 @@ tuner_schema_dict = {
Optional('selection_num_starting_points'): setType('selection_num_starting_points', int), Optional('selection_num_starting_points'): setType('selection_num_starting_points', int),
Optional('cold_start_num'): setType('cold_start_num', int), Optional('cold_start_num'): setType('cold_start_num', int),
}, },
Optional('includeIntermediateResults'): setType('includeIntermediateResults', bool),
Optional('gpuNum'): setNumberRange('gpuNum', int, 0, 99999), Optional('gpuNum'): setNumberRange('gpuNum', int, 0, 99999),
}, },
'GPTuner': { 'GPTuner': {
...@@ -125,6 +129,7 @@ tuner_schema_dict = { ...@@ -125,6 +129,7 @@ tuner_schema_dict = {
Optional('selection_num_warm_up'): setType('selection_num_warm_up', int), Optional('selection_num_warm_up'): setType('selection_num_warm_up', int),
Optional('selection_num_starting_points'): setType('selection_num_starting_points', int), Optional('selection_num_starting_points'): setType('selection_num_starting_points', int),
}, },
Optional('includeIntermediateResults'): setType('includeIntermediateResults', bool),
Optional('gpuNum'): setNumberRange('gpuNum', int, 0, 99999), Optional('gpuNum'): setNumberRange('gpuNum', int, 0, 99999),
}, },
'customized': { 'customized': {
...@@ -132,6 +137,7 @@ tuner_schema_dict = { ...@@ -132,6 +137,7 @@ tuner_schema_dict = {
'classFileName': setType('classFileName', str), 'classFileName': setType('classFileName', str),
'className': setType('className', str), 'className': setType('className', str),
Optional('classArgs'): dict, Optional('classArgs'): dict,
Optional('includeIntermediateResults'): setType('includeIntermediateResults', bool),
Optional('gpuNum'): setNumberRange('gpuNum', int, 0, 99999), Optional('gpuNum'): setNumberRange('gpuNum', int, 0, 99999),
} }
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment