Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
1c56fea8
"vscode:/vscode.git/clone" did not exist on "8d6f237018952a2aa3de56b5c523b598bf1a2fbb"
Unverified
Commit
1c56fea8
authored
Jun 24, 2019
by
chicm-ms
Committed by
GitHub
Jun 24, 2019
Browse files
Merge pull request #21 from microsoft/master
pull code
parents
12410686
97829ccd
Changes
63
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
169 additions
and
466 deletions
+169
-466
src/nni_manager/training_service/pai/paiData.ts
src/nni_manager/training_service/pai/paiData.ts
+4
-4
src/nni_manager/training_service/pai/paiJobRestServer.ts
src/nni_manager/training_service/pai/paiJobRestServer.ts
+38
-0
src/nni_manager/training_service/pai/paiTrainingService.ts
src/nni_manager/training_service/pai/paiTrainingService.ts
+64
-6
src/nni_manager/types/tail-stream/index.d.ts
src/nni_manager/types/tail-stream/index.d.ts
+2
-1
src/sdk/pynni/nni/__main__.py
src/sdk/pynni/nni/__main__.py
+4
-6
src/sdk/pynni/nni/batch_tuner/batch_tuner.py
src/sdk/pynni/nni/batch_tuner/batch_tuner.py
+2
-2
src/sdk/pynni/nni/common.py
src/sdk/pynni/nni/common.py
+8
-0
src/sdk/pynni/nni/evolution_tuner/evolution_tuner.py
src/sdk/pynni/nni/evolution_tuner/evolution_tuner.py
+2
-2
src/sdk/pynni/nni/gridsearch_tuner/gridsearch_tuner.py
src/sdk/pynni/nni/gridsearch_tuner/gridsearch_tuner.py
+2
-2
src/sdk/pynni/nni/hyperopt_tuner/hyperopt_tuner.py
src/sdk/pynni/nni/hyperopt_tuner/hyperopt_tuner.py
+2
-2
src/sdk/pynni/nni/metis_tuner/metis_tuner.py
src/sdk/pynni/nni/metis_tuner/metis_tuner.py
+2
-2
src/sdk/pynni/nni/msg_dispatcher.py
src/sdk/pynni/nni/msg_dispatcher.py
+23
-7
src/sdk/pynni/nni/multi_phase/__init__.py
src/sdk/pynni/nni/multi_phase/__init__.py
+0
-0
src/sdk/pynni/nni/multi_phase/multi_phase_dispatcher.py
src/sdk/pynni/nni/multi_phase/multi_phase_dispatcher.py
+0
-198
src/sdk/pynni/nni/multi_phase/multi_phase_tuner.py
src/sdk/pynni/nni/multi_phase/multi_phase_tuner.py
+0
-106
src/sdk/pynni/nni/networkmorphism_tuner/networkmorphism_tuner.py
.../pynni/nni/networkmorphism_tuner/networkmorphism_tuner.py
+2
-2
src/sdk/pynni/nni/smac_tuner/smac_tuner.py
src/sdk/pynni/nni/smac_tuner/smac_tuner.py
+3
-3
src/sdk/pynni/nni/tuner.py
src/sdk/pynni/nni/tuner.py
+6
-6
src/sdk/pynni/tests/test_multi_phase_tuner.py
src/sdk/pynni/tests/test_multi_phase_tuner.py
+0
-110
src/sdk/pynni/tests/test_tuner.py
src/sdk/pynni/tests/test_tuner.py
+5
-7
No files found.
src/nni_manager/training_service/pai/paiData.ts
View file @
1c56fea8
...
...
@@ -64,11 +64,11 @@ else
fi`
;
export
const
PAI_TRIAL_COMMAND_FORMAT
:
string
=
`export NNI_PLATFORM=pai NNI_SYS_DIR={0} NNI_OUTPUT_DIR={1} NNI_TRIAL_JOB_ID={2} NNI_EXP_ID={3} NNI_TRIAL_SEQ_ID={4} \
`export NNI_PLATFORM=pai NNI_SYS_DIR={0} NNI_OUTPUT_DIR={1} NNI_TRIAL_JOB_ID={2} NNI_EXP_ID={3} NNI_TRIAL_SEQ_ID={4}
MULTI_PHASE={5}
\
&& cd $NNI_SYS_DIR && sh install_nni.sh \
&& python3 -m nni_trial_tool.trial_keeper --trial_command '{
5
}' --nnimanager_ip '{
6
}' --nnimanager_port '{
7
}' \
--pai_hdfs_output_dir '{
8
}' --pai_hdfs_host '{
9
}' --pai_user_name {1
0
} --nni_hdfs_exp_dir '{1
1
}' --webhdfs_path '/webhdfs/api/v1' \
--nni_manager_version '{1
2
}' --log_collection '{1
3
}'`
;
&& python3 -m nni_trial_tool.trial_keeper --trial_command '{
6
}' --nnimanager_ip '{
7
}' --nnimanager_port '{
8
}' \
--pai_hdfs_output_dir '{
9
}' --pai_hdfs_host '{
10
}' --pai_user_name {1
1
} --nni_hdfs_exp_dir '{1
2
}' --webhdfs_path '/webhdfs/api/v1' \
--nni_manager_version '{1
3
}' --log_collection '{1
4
}'`
;
export
const
PAI_OUTPUT_DIR_FORMAT
:
string
=
`hdfs://{0}:9000/`
;
...
...
src/nni_manager/training_service/pai/paiJobRestServer.ts
View file @
1c56fea8
...
...
@@ -19,17 +19,26 @@
'
use strict
'
;
import
{
Request
,
Response
,
Router
}
from
'
express
'
;
import
{
Inject
}
from
'
typescript-ioc
'
;
import
*
as
component
from
'
../../common/component
'
;
import
{
ClusterJobRestServer
}
from
'
../common/clusterJobRestServer
'
;
import
{
PAITrainingService
}
from
'
./paiTrainingService
'
;
export
interface
ParameterFileMeta
{
readonly
experimentId
:
string
;
readonly
trialId
:
string
;
readonly
filePath
:
string
;
}
/**
* PAI Training service Rest server, provides rest API to support pai job metrics update
*
*/
@
component
.
Singleton
export
class
PAIJobRestServer
extends
ClusterJobRestServer
{
private
parameterFileMetaList
:
ParameterFileMeta
[]
=
[];
@
Inject
private
readonly
paiTrainingService
:
PAITrainingService
;
...
...
@@ -52,4 +61,33 @@ export class PAIJobRestServer extends ClusterJobRestServer {
});
}
}
protected
createRestHandler
():
Router
{
const
router
:
Router
=
super
.
createRestHandler
();
router
.
post
(
`/parameter-file-meta`
,
(
req
:
Request
,
res
:
Response
)
=>
{
try
{
this
.
log
.
info
(
`POST /parameter-file-meta, body is
${
JSON
.
stringify
(
req
.
body
)}
`
);
this
.
parameterFileMetaList
.
push
(
req
.
body
);
res
.
send
();
}
catch
(
err
)
{
this
.
log
.
error
(
`POST parameter-file-meta error:
${
err
}
`
);
res
.
status
(
500
);
res
.
send
(
err
.
message
);
}
});
router
.
get
(
`/parameter-file-meta`
,
(
req
:
Request
,
res
:
Response
)
=>
{
try
{
this
.
log
.
info
(
`GET /parameter-file-meta`
);
res
.
send
(
this
.
parameterFileMetaList
);
}
catch
(
err
)
{
this
.
log
.
error
(
`GET parameter-file-meta error:
${
err
}
`
);
res
.
status
(
500
);
res
.
send
(
err
.
message
);
}
});
return
router
;
}
}
src/nni_manager/training_service/pai/paiTrainingService.ts
View file @
1c56fea8
...
...
@@ -33,7 +33,7 @@ import { MethodNotImplementedError } from '../../common/errors';
import
{
getExperimentId
,
getInitTrialSequenceId
}
from
'
../../common/experimentStartupInfo
'
;
import
{
getLogger
,
Logger
}
from
'
../../common/log
'
;
import
{
JobApplicationForm
,
NNIManagerIpConfig
,
TrainingService
,
HyperParameters
,
JobApplicationForm
,
NNIManagerIpConfig
,
TrainingService
,
TrialJobApplicationForm
,
TrialJobDetail
,
TrialJobMetric
}
from
'
../../common/trainingService
'
;
import
{
delay
,
generateParamFileName
,
...
...
@@ -45,7 +45,7 @@ import { HDFSClientUtility } from './hdfsClientUtility';
import
{
NNIPAITrialConfig
,
PAIClusterConfig
,
PAIJobConfig
,
PAITaskRole
}
from
'
./paiConfig
'
;
import
{
PAI_LOG_PATH_FORMAT
,
PAI_OUTPUT_DIR_FORMAT
,
PAI_TRIAL_COMMAND_FORMAT
,
PAITrialJobDetail
}
from
'
./paiData
'
;
import
{
PAIJobInfoCollector
}
from
'
./paiJobInfoCollector
'
;
import
{
PAIJobRestServer
}
from
'
./paiJobRestServer
'
;
import
{
PAIJobRestServer
,
ParameterFileMeta
}
from
'
./paiJobRestServer
'
;
import
*
as
WebHDFS
from
'
webhdfs
'
;
...
...
@@ -79,6 +79,7 @@ class PAITrainingService implements TrainingService {
private
copyExpCodeDirPromise
?:
Promise
<
void
>
;
private
versionCheck
:
boolean
=
true
;
private
logCollection
:
string
;
private
isMultiPhase
:
boolean
=
false
;
constructor
()
{
this
.
log
=
getLogger
();
...
...
@@ -179,12 +180,22 @@ class PAITrainingService implements TrainingService {
return
deferred
.
promise
;
}
public
updateTrialJob
(
trialJobId
:
string
,
form
:
JobApplicationForm
):
Promise
<
TrialJobDetail
>
{
throw
new
MethodNotImplementedError
();
public
async
updateTrialJob
(
trialJobId
:
string
,
form
:
JobApplicationForm
):
Promise
<
TrialJobDetail
>
{
const
trialJobDetail
:
undefined
|
TrialJobDetail
=
this
.
trialJobsMap
.
get
(
trialJobId
);
if
(
trialJobDetail
===
undefined
)
{
throw
new
Error
(
`updateTrialJob failed:
${
trialJobId
}
not found`
);
}
if
(
form
.
jobType
===
'
TRIAL
'
)
{
await
this
.
writeParameterFile
(
trialJobId
,
(
<
TrialJobApplicationForm
>
form
).
hyperParameters
);
}
else
{
throw
new
Error
(
`updateTrialJob failed: jobType
${
form
.
jobType
}
not supported.`
);
}
return
trialJobDetail
;
}
public
get
isMultiPhaseJobSupported
():
boolean
{
return
fals
e
;
return
tru
e
;
}
// tslint:disable:no-http-string
...
...
@@ -336,6 +347,9 @@ class PAITrainingService implements TrainingService {
case
TrialConfigMetadataKey
.
LOG_COLLECTION
:
this
.
logCollection
=
value
;
break
;
case
TrialConfigMetadataKey
.
MULTI_PHASE
:
this
.
isMultiPhase
=
(
value
===
'
true
'
||
value
===
'
True
'
);
break
;
default
:
//Reject for unknown keys
throw
new
Error
(
`Uknown key:
${
key
}
`
);
...
...
@@ -445,6 +459,7 @@ class PAITrainingService implements TrainingService {
trialJobId
,
this
.
experimentId
,
trialJobDetail
.
sequenceId
,
this
.
isMultiPhase
,
this
.
paiTrialConfig
.
command
,
nniManagerIp
,
this
.
paiRestServerPort
,
...
...
@@ -632,7 +647,50 @@ class PAITrainingService implements TrainingService {
return
Promise
.
race
([
timeoutDelay
,
deferred
.
promise
])
.
finally
(()
=>
{
clearTimeout
(
timeoutId
);
});
}
// tslint:enable:no-any no-unsafe-any no-http-string
private
async
writeParameterFile
(
trialJobId
:
string
,
hyperParameters
:
HyperParameters
):
Promise
<
void
>
{
if
(
this
.
paiClusterConfig
===
undefined
)
{
throw
new
Error
(
'
PAI Cluster config is not initialized
'
);
}
if
(
this
.
paiTrialConfig
===
undefined
)
{
throw
new
Error
(
'
PAI trial config is not initialized
'
);
}
const
trialLocalTempFolder
:
string
=
path
.
join
(
getExperimentRootDir
(),
'
trials-local
'
,
trialJobId
);
const
hpFileName
:
string
=
generateParamFileName
(
hyperParameters
);
const
localFilepath
:
string
=
path
.
join
(
trialLocalTempFolder
,
hpFileName
);
await
fs
.
promises
.
writeFile
(
localFilepath
,
hyperParameters
.
value
,
{
encoding
:
'
utf8
'
});
const
hdfsCodeDir
:
string
=
HDFSClientUtility
.
getHdfsTrialWorkDir
(
this
.
paiClusterConfig
.
userName
,
trialJobId
);
const
hdfsHpFilePath
:
string
=
path
.
join
(
hdfsCodeDir
,
hpFileName
);
await
HDFSClientUtility
.
copyFileToHdfs
(
localFilepath
,
hdfsHpFilePath
,
this
.
hdfsClient
);
await
this
.
postParameterFileMeta
({
experimentId
:
this
.
experimentId
,
trialId
:
trialJobId
,
filePath
:
hdfsHpFilePath
});
}
private
postParameterFileMeta
(
parameterFileMeta
:
ParameterFileMeta
):
Promise
<
void
>
{
const
deferred
:
Deferred
<
void
>
=
new
Deferred
<
void
>
();
const
restServer
:
PAIJobRestServer
=
component
.
get
(
PAIJobRestServer
);
const
req
:
request
.
Options
=
{
uri
:
`
${
restServer
.
endPoint
}${
restServer
.
apiRootUrl
}
/parameter-file-meta`
,
method
:
'
POST
'
,
json
:
true
,
body
:
parameterFileMeta
};
request
(
req
,
(
err
:
Error
,
res
:
request
.
Response
)
=>
{
if
(
err
)
{
deferred
.
reject
(
err
);
}
else
{
deferred
.
resolve
();
}
});
return
deferred
.
promise
;
}
}
export
{
PAITrainingService
};
src/nni_manager/types/tail-stream/index.d.ts
View file @
1c56fea8
declare
module
'
tail-stream
'
{
export
interface
Stream
{
on
(
type
:
'
data
'
,
callback
:
(
data
:
Buffer
)
=>
void
):
void
;
destroy
():
void
;
end
(
data
:
number
):
void
;
emit
(
data
:
string
):
void
;
}
export
function
createReadStream
(
path
:
string
):
Stream
;
}
\ No newline at end of file
src/sdk/pynni/nni/__main__.py
View file @
1c56fea8
...
...
@@ -28,9 +28,8 @@ import json
import
importlib
from
.constants
import
ModuleName
,
ClassName
,
ClassArgs
,
AdvisorModuleName
,
AdvisorClassName
from
nni.common
import
enable_multi_thread
from
nni.common
import
enable_multi_thread
,
enable_multi_phase
from
nni.msg_dispatcher
import
MsgDispatcher
from
nni.multi_phase.multi_phase_dispatcher
import
MultiPhaseMsgDispatcher
logger
=
logging
.
getLogger
(
'nni.main'
)
logger
.
debug
(
'START'
)
...
...
@@ -126,6 +125,8 @@ def main():
args
=
parse_args
()
if
args
.
multi_thread
:
enable_multi_thread
()
if
args
.
multi_phase
:
enable_multi_phase
()
if
args
.
advisor_class_name
:
# advisor is enabled and starts to run
...
...
@@ -180,9 +181,6 @@ def main():
if
assessor
is
None
:
raise
AssertionError
(
'Failed to create Assessor instance'
)
if
args
.
multi_phase
:
dispatcher
=
MultiPhaseMsgDispatcher
(
tuner
,
assessor
)
else
:
dispatcher
=
MsgDispatcher
(
tuner
,
assessor
)
try
:
...
...
src/sdk/pynni/nni/batch_tuner/batch_tuner.py
View file @
1c56fea8
...
...
@@ -78,7 +78,7 @@ class BatchTuner(Tuner):
"""
self
.
values
=
self
.
is_valid
(
search_space
)
def
generate_parameters
(
self
,
parameter_id
):
def
generate_parameters
(
self
,
parameter_id
,
**
kwargs
):
"""Returns a dict of trial (hyper-)parameters, as a serializable object.
Parameters
...
...
@@ -90,7 +90,7 @@ class BatchTuner(Tuner):
raise
nni
.
NoMoreTrialError
(
'no more parameters now.'
)
return
self
.
values
[
self
.
count
]
def
receive_trial_result
(
self
,
parameter_id
,
parameters
,
value
):
def
receive_trial_result
(
self
,
parameter_id
,
parameters
,
value
,
**
kwargs
):
pass
def
import_data
(
self
,
data
):
...
...
src/sdk/pynni/nni/common.py
View file @
1c56fea8
...
...
@@ -69,6 +69,7 @@ def init_logger(logger_file_path, log_level_name='info'):
sys
.
stdout
=
_LoggerFileWrapper
(
logger_file
)
_multi_thread
=
False
_multi_phase
=
False
def
enable_multi_thread
():
global
_multi_thread
...
...
@@ -76,3 +77,10 @@ def enable_multi_thread():
def
multi_thread_enabled
():
return
_multi_thread
def
enable_multi_phase
():
global
_multi_phase
_multi_phase
=
True
def
multi_phase_enabled
():
return
_multi_phase
src/sdk/pynni/nni/evolution_tuner/evolution_tuner.py
View file @
1c56fea8
...
...
@@ -188,7 +188,7 @@ class EvolutionTuner(Tuner):
self
.
searchspace_json
,
is_rand
,
self
.
random_state
)
self
.
population
.
append
(
Individual
(
config
=
config
))
def
generate_parameters
(
self
,
parameter_id
):
def
generate_parameters
(
self
,
parameter_id
,
**
kwargs
):
"""Returns a dict of trial (hyper-)parameters, as a serializable object.
Parameters
...
...
@@ -232,7 +232,7 @@ class EvolutionTuner(Tuner):
config
=
split_index
(
total_config
)
return
config
def
receive_trial_result
(
self
,
parameter_id
,
parameters
,
value
):
def
receive_trial_result
(
self
,
parameter_id
,
parameters
,
value
,
**
kwargs
):
'''Record the result from a trial
Parameters
...
...
src/sdk/pynni/nni/gridsearch_tuner/gridsearch_tuner.py
View file @
1c56fea8
...
...
@@ -137,7 +137,7 @@ class GridSearchTuner(Tuner):
'''
self
.
expanded_search_space
=
self
.
json2parameter
(
search_space
)
def
generate_parameters
(
self
,
parameter_id
):
def
generate_parameters
(
self
,
parameter_id
,
**
kwargs
):
self
.
count
+=
1
while
(
self
.
count
<=
len
(
self
.
expanded_search_space
)
-
1
):
_params_tuple
=
convert_dict2tuple
(
self
.
expanded_search_space
[
self
.
count
])
...
...
@@ -147,7 +147,7 @@ class GridSearchTuner(Tuner):
return
self
.
expanded_search_space
[
self
.
count
]
raise
nni
.
NoMoreTrialError
(
'no more parameters now.'
)
def
receive_trial_result
(
self
,
parameter_id
,
parameters
,
value
):
def
receive_trial_result
(
self
,
parameter_id
,
parameters
,
value
,
**
kwargs
):
pass
def
import_data
(
self
,
data
):
...
...
src/sdk/pynni/nni/hyperopt_tuner/hyperopt_tuner.py
View file @
1c56fea8
...
...
@@ -248,7 +248,7 @@ class HyperoptTuner(Tuner):
verbose
=
0
)
self
.
rval
.
catch_eval_exceptions
=
False
def
generate_parameters
(
self
,
parameter_id
):
def
generate_parameters
(
self
,
parameter_id
,
**
kwargs
):
"""
Returns a set of trial (hyper-)parameters, as a serializable object.
...
...
@@ -269,7 +269,7 @@ class HyperoptTuner(Tuner):
params
=
split_index
(
total_params
)
return
params
def
receive_trial_result
(
self
,
parameter_id
,
parameters
,
value
):
def
receive_trial_result
(
self
,
parameter_id
,
parameters
,
value
,
**
kwargs
):
"""
Record an observation of the objective function
...
...
src/sdk/pynni/nni/metis_tuner/metis_tuner.py
View file @
1c56fea8
...
...
@@ -174,7 +174,7 @@ class MetisTuner(Tuner):
return
output
def
generate_parameters
(
self
,
parameter_id
):
def
generate_parameters
(
self
,
parameter_id
,
**
kwargs
):
"""Generate next parameter for trial
If the number of trial result is lower than cold start number,
metis will first random generate some parameters.
...
...
@@ -205,7 +205,7 @@ class MetisTuner(Tuner):
return
results
def
receive_trial_result
(
self
,
parameter_id
,
parameters
,
value
):
def
receive_trial_result
(
self
,
parameter_id
,
parameters
,
value
,
**
kwargs
):
"""Tuner receive result from trial.
Parameters
...
...
src/sdk/pynni/nni/msg_dispatcher.py
View file @
1c56fea8
...
...
@@ -18,7 +18,6 @@
# OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
# ==================================================================================================
import
os
import
logging
from
collections
import
defaultdict
import
json_tricks
...
...
@@ -26,7 +25,7 @@ import json_tricks
from
.protocol
import
CommandType
,
send
from
.msg_dispatcher_base
import
MsgDispatcherBase
from
.assessor
import
AssessResult
from
.common
import
multi_thread_enabled
from
.common
import
multi_thread_enabled
,
multi_phase_enabled
from
.env_vars
import
dispatcher_env_vars
_logger
=
logging
.
getLogger
(
__name__
)
...
...
@@ -61,13 +60,19 @@ def _create_parameter_id():
_next_parameter_id
+=
1
return
_next_parameter_id
-
1
def
_pack_parameter
(
parameter_id
,
params
,
customized
=
False
):
def
_pack_parameter
(
parameter_id
,
params
,
customized
=
False
,
trial_job_id
=
None
,
parameter_index
=
None
):
_trial_params
[
parameter_id
]
=
params
ret
=
{
'parameter_id'
:
parameter_id
,
'parameter_source'
:
'customized'
if
customized
else
'algorithm'
,
'parameters'
:
params
}
if
trial_job_id
is
not
None
:
ret
[
'trial_job_id'
]
=
trial_job_id
if
parameter_index
is
not
None
:
ret
[
'parameter_index'
]
=
parameter_index
else
:
ret
[
'parameter_index'
]
=
0
return
json_tricks
.
dumps
(
ret
)
class
MsgDispatcher
(
MsgDispatcherBase
):
...
...
@@ -133,8 +138,13 @@ class MsgDispatcher(MsgDispatcherBase):
elif
data
[
'type'
]
==
'PERIODICAL'
:
if
self
.
assessor
is
not
None
:
self
.
_handle_intermediate_metric_data
(
data
)
else
:
pass
elif
data
[
'type'
]
==
'REQUEST_PARAMETER'
:
assert
multi_phase_enabled
()
assert
data
[
'trial_job_id'
]
is
not
None
assert
data
[
'parameter_index'
]
is
not
None
param_id
=
_create_parameter_id
()
param
=
self
.
tuner
.
generate_parameters
(
param_id
,
trial_job_id
=
data
[
'trial_job_id'
])
send
(
CommandType
.
SendTrialJobParameter
,
_pack_parameter
(
param_id
,
param
,
trial_job_id
=
data
[
'trial_job_id'
],
parameter_index
=
data
[
'parameter_index'
]))
else
:
raise
ValueError
(
'Data type not supported: {}'
.
format
(
data
[
'type'
]))
...
...
@@ -160,7 +170,13 @@ class MsgDispatcher(MsgDispatcherBase):
id_
=
data
[
'parameter_id'
]
value
=
data
[
'value'
]
if
id_
in
_customized_parameter_ids
:
if
multi_phase_enabled
():
self
.
tuner
.
receive_customized_trial_result
(
id_
,
_trial_params
[
id_
],
value
,
trial_job_id
=
data
[
'trial_job_id'
])
else
:
self
.
tuner
.
receive_customized_trial_result
(
id_
,
_trial_params
[
id_
],
value
)
else
:
if
multi_phase_enabled
():
self
.
tuner
.
receive_trial_result
(
id_
,
_trial_params
[
id_
],
value
,
trial_job_id
=
data
[
'trial_job_id'
])
else
:
self
.
tuner
.
receive_trial_result
(
id_
,
_trial_params
[
id_
],
value
)
...
...
src/sdk/pynni/nni/multi_phase/__init__.py
deleted
100644 → 0
View file @
12410686
src/sdk/pynni/nni/multi_phase/multi_phase_dispatcher.py
deleted
100644 → 0
View file @
12410686
# Copyright (c) Microsoft Corporation. All rights reserved.
#
# MIT License
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and
# associated documentation files (the "Software"), to deal in the Software without restriction,
# including without limitation the rights to use, copy, modify, merge, publish, distribute,
# sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all copies or
# substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT
# NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT
# OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
# ==================================================================================================
import
logging
from
collections
import
defaultdict
import
json_tricks
from
nni.protocol
import
CommandType
,
send
from
nni.msg_dispatcher_base
import
MsgDispatcherBase
from
nni.assessor
import
AssessResult
_logger
=
logging
.
getLogger
(
__name__
)
# Assessor global variables
_trial_history
=
defaultdict
(
dict
)
'''key: trial job ID; value: intermediate results, mapping from sequence number to data'''
_ended_trials
=
set
()
'''trial_job_id of all ended trials.
We need this because NNI manager may send metrics after reporting a trial ended.
TODO: move this logic to NNI manager
'''
def
_sort_history
(
history
):
ret
=
[
]
for
i
,
_
in
enumerate
(
history
):
if
i
in
history
:
ret
.
append
(
history
[
i
])
else
:
break
return
ret
# Tuner global variables
_next_parameter_id
=
0
_trial_params
=
{}
'''key: trial job ID; value: parameters'''
_customized_parameter_ids
=
set
()
def
_create_parameter_id
():
global
_next_parameter_id
# pylint: disable=global-statement
_next_parameter_id
+=
1
return
_next_parameter_id
-
1
def
_pack_parameter
(
parameter_id
,
params
,
customized
=
False
,
trial_job_id
=
None
,
parameter_index
=
None
):
_trial_params
[
parameter_id
]
=
params
ret
=
{
'parameter_id'
:
parameter_id
,
'parameter_source'
:
'customized'
if
customized
else
'algorithm'
,
'parameters'
:
params
}
if
trial_job_id
is
not
None
:
ret
[
'trial_job_id'
]
=
trial_job_id
if
parameter_index
is
not
None
:
ret
[
'parameter_index'
]
=
parameter_index
else
:
ret
[
'parameter_index'
]
=
0
return
json_tricks
.
dumps
(
ret
)
class
MultiPhaseMsgDispatcher
(
MsgDispatcherBase
):
def
__init__
(
self
,
tuner
,
assessor
=
None
):
super
(
MultiPhaseMsgDispatcher
,
self
).
__init__
()
self
.
tuner
=
tuner
self
.
assessor
=
assessor
if
assessor
is
None
:
_logger
.
debug
(
'Assessor is not configured'
)
def
load_checkpoint
(
self
):
self
.
tuner
.
load_checkpoint
()
if
self
.
assessor
is
not
None
:
self
.
assessor
.
load_checkpoint
()
def
save_checkpoint
(
self
):
self
.
tuner
.
save_checkpoint
()
if
self
.
assessor
is
not
None
:
self
.
assessor
.
save_checkpoint
()
def
handle_initialize
(
self
,
data
):
'''
data is search space
'''
self
.
tuner
.
update_search_space
(
data
)
send
(
CommandType
.
Initialized
,
''
)
return
True
def
handle_request_trial_jobs
(
self
,
data
):
# data: number or trial jobs
ids
=
[
_create_parameter_id
()
for
_
in
range
(
data
)]
params_list
=
self
.
tuner
.
generate_multiple_parameters
(
ids
)
assert
len
(
ids
)
==
len
(
params_list
)
for
i
,
_
in
enumerate
(
ids
):
send
(
CommandType
.
NewTrialJob
,
_pack_parameter
(
ids
[
i
],
params_list
[
i
]))
return
True
def
handle_update_search_space
(
self
,
data
):
self
.
tuner
.
update_search_space
(
data
)
return
True
def
handle_import_data
(
self
,
data
):
"""import additional data for tuning
data: a list of dictionarys, each of which has at least two keys, 'parameter' and 'value'
"""
self
.
tuner
.
import_data
(
data
)
return
True
def
handle_add_customized_trial
(
self
,
data
):
# data: parameters
id_
=
_create_parameter_id
()
_customized_parameter_ids
.
add
(
id_
)
send
(
CommandType
.
NewTrialJob
,
_pack_parameter
(
id_
,
data
,
customized
=
True
))
return
True
def
handle_report_metric_data
(
self
,
data
):
trial_job_id
=
data
[
'trial_job_id'
]
if
data
[
'type'
]
==
'FINAL'
:
id_
=
data
[
'parameter_id'
]
if
id_
in
_customized_parameter_ids
:
self
.
tuner
.
receive_customized_trial_result
(
id_
,
_trial_params
[
id_
],
data
[
'value'
],
trial_job_id
)
else
:
self
.
tuner
.
receive_trial_result
(
id_
,
_trial_params
[
id_
],
data
[
'value'
],
trial_job_id
)
elif
data
[
'type'
]
==
'PERIODICAL'
:
if
self
.
assessor
is
not
None
:
self
.
_handle_intermediate_metric_data
(
data
)
else
:
pass
elif
data
[
'type'
]
==
'REQUEST_PARAMETER'
:
assert
data
[
'trial_job_id'
]
is
not
None
assert
data
[
'parameter_index'
]
is
not
None
param_id
=
_create_parameter_id
()
param
=
self
.
tuner
.
generate_parameters
(
param_id
,
trial_job_id
)
send
(
CommandType
.
SendTrialJobParameter
,
_pack_parameter
(
param_id
,
param
,
trial_job_id
=
data
[
'trial_job_id'
],
parameter_index
=
data
[
'parameter_index'
]))
else
:
raise
ValueError
(
'Data type not supported: {}'
.
format
(
data
[
'type'
]))
return
True
def
handle_trial_end
(
self
,
data
):
trial_job_id
=
data
[
'trial_job_id'
]
_ended_trials
.
add
(
trial_job_id
)
if
trial_job_id
in
_trial_history
:
_trial_history
.
pop
(
trial_job_id
)
if
self
.
assessor
is
not
None
:
self
.
assessor
.
trial_end
(
trial_job_id
,
data
[
'event'
]
==
'SUCCEEDED'
)
if
self
.
tuner
is
not
None
:
self
.
tuner
.
trial_end
(
json_tricks
.
loads
(
data
[
'hyper_params'
])[
'parameter_id'
],
data
[
'event'
]
==
'SUCCEEDED'
,
trial_job_id
)
return
True
def
handle_import_data
(
self
,
data
):
pass
def
_handle_intermediate_metric_data
(
self
,
data
):
if
data
[
'type'
]
!=
'PERIODICAL'
:
return
True
if
self
.
assessor
is
None
:
return
True
trial_job_id
=
data
[
'trial_job_id'
]
if
trial_job_id
in
_ended_trials
:
return
True
history
=
_trial_history
[
trial_job_id
]
history
[
data
[
'sequence'
]]
=
data
[
'value'
]
ordered_history
=
_sort_history
(
history
)
if
len
(
ordered_history
)
<
data
[
'sequence'
]:
# no user-visible update since last time
return
True
try
:
result
=
self
.
assessor
.
assess_trial
(
trial_job_id
,
ordered_history
)
except
Exception
as
e
:
_logger
.
exception
(
'Assessor error'
)
if
isinstance
(
result
,
bool
):
result
=
AssessResult
.
Good
if
result
else
AssessResult
.
Bad
elif
not
isinstance
(
result
,
AssessResult
):
msg
=
'Result of Assessor.assess_trial must be an object of AssessResult, not %s'
raise
RuntimeError
(
msg
%
type
(
result
))
if
result
is
AssessResult
.
Bad
:
_logger
.
debug
(
'BAD, kill %s'
,
trial_job_id
)
send
(
CommandType
.
KillTrialJob
,
json_tricks
.
dumps
(
trial_job_id
))
else
:
_logger
.
debug
(
'GOOD'
)
src/sdk/pynni/nni/multi_phase/multi_phase_tuner.py
deleted
100644 → 0
View file @
12410686
# Copyright (c) Microsoft Corporation. All rights reserved.
#
# MIT License
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and
# associated documentation files (the "Software"), to deal in the Software without restriction,
# including without limitation the rights to use, copy, modify, merge, publish, distribute,
# sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all copies or
# substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT
# NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT
# OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
# ==================================================================================================
import
logging
from
nni.recoverable
import
Recoverable
_logger
=
logging
.
getLogger
(
__name__
)
class
MultiPhaseTuner
(
Recoverable
):
# pylint: disable=no-self-use,unused-argument
def
generate_parameters
(
self
,
parameter_id
,
trial_job_id
=
None
):
"""Returns a set of trial (hyper-)parameters, as a serializable object.
User code must override either this function or 'generate_multiple_parameters()'.
parameter_id: identifier of the parameter (int)
"""
raise
NotImplementedError
(
'Tuner: generate_parameters not implemented'
)
def
generate_multiple_parameters
(
self
,
parameter_id_list
):
"""Returns multiple sets of trial (hyper-)parameters, as iterable of serializable objects.
Call 'generate_parameters()' by 'count' times by default.
User code must override either this function or 'generate_parameters()'.
parameter_id_list: list of int
"""
return
[
self
.
generate_parameters
(
parameter_id
)
for
parameter_id
in
parameter_id_list
]
def
receive_trial_result
(
self
,
parameter_id
,
parameters
,
value
,
trial_job_id
):
"""Invoked when a trial reports its final result. Must override.
parameter_id: identifier of the parameter (int)
parameters: object created by 'generate_parameters()'
value: object reported by trial
trial_job_id: identifier of the trial (str)
"""
raise
NotImplementedError
(
'Tuner: receive_trial_result not implemented'
)
def
receive_customized_trial_result
(
self
,
parameter_id
,
parameters
,
value
,
trial_job_id
):
"""Invoked when a trial added by WebUI reports its final result. Do nothing by default.
parameter_id: identifier of the parameter (int)
parameters: object created by user
value: object reported by trial
trial_job_id: identifier of the trial (str)
"""
_logger
.
info
(
'Customized trial job %s ignored by tuner'
,
parameter_id
)
def
trial_end
(
self
,
parameter_id
,
success
,
trial_job_id
):
"""Invoked when a trial is completed or terminated. Do nothing by default.
parameter_id: identifier of the parameter (int)
success: True if the trial successfully completed; False if failed or terminated
trial_job_id: identifier of the trial (str)
"""
pass
def
update_search_space
(
self
,
search_space
):
"""Update the search space of tuner. Must override.
search_space: JSON object
"""
raise
NotImplementedError
(
'Tuner: update_search_space not implemented'
)
def
import_data
(
self
,
data
):
"""Import additional data for tuning
data: a list of dictionarys, each of which has at least two keys, 'parameter' and 'value'
"""
pass
def
load_checkpoint
(
self
):
"""Load the checkpoint of tuner.
path: checkpoint directory for tuner
"""
checkpoin_path
=
self
.
get_checkpoint_path
()
_logger
.
info
(
'Load checkpoint ignored by tuner, checkpoint path: %s'
%
checkpoin_path
)
def
save_checkpoint
(
self
):
"""Save the checkpoint of tuner.
path: checkpoint directory for tuner
"""
checkpoin_path
=
self
.
get_checkpoint_path
()
_logger
.
info
(
'Save checkpoint ignored by tuner, checkpoint path: %s'
%
checkpoin_path
)
def
_on_exit
(
self
):
pass
def
_on_error
(
self
):
pass
def
import_data
(
self
,
data
):
pass
src/sdk/pynni/nni/networkmorphism_tuner/networkmorphism_tuner.py
View file @
1c56fea8
...
...
@@ -123,7 +123,7 @@ class NetworkMorphismTuner(Tuner):
"""
self
.
search_space
=
search_space
def
generate_parameters
(
self
,
parameter_id
):
def
generate_parameters
(
self
,
parameter_id
,
**
kwargs
):
"""
Returns a set of trial neural architecture, as a serializable object.
...
...
@@ -152,7 +152,7 @@ class NetworkMorphismTuner(Tuner):
return
json_out
def
receive_trial_result
(
self
,
parameter_id
,
parameters
,
value
):
def
receive_trial_result
(
self
,
parameter_id
,
parameters
,
value
,
**
kwargs
):
""" Record an observation of the objective function.
Parameters
...
...
src/sdk/pynni/nni/smac_tuner/smac_tuner.py
View file @
1c56fea8
...
...
@@ -151,7 +151,7 @@ class SMACTuner(Tuner):
else
:
self
.
logger
.
warning
(
'update search space is not supported.'
)
def
receive_trial_result
(
self
,
parameter_id
,
parameters
,
value
):
def
receive_trial_result
(
self
,
parameter_id
,
parameters
,
value
,
**
kwargs
):
"""receive_trial_result
Parameters
...
...
@@ -209,7 +209,7 @@ class SMACTuner(Tuner):
converted_dict
[
key
]
=
value
return
converted_dict
def
generate_parameters
(
self
,
parameter_id
):
def
generate_parameters
(
self
,
parameter_id
,
**
kwargs
):
"""generate one instance of hyperparameters
Parameters
...
...
@@ -232,7 +232,7 @@ class SMACTuner(Tuner):
self
.
total_data
[
parameter_id
]
=
challenger
return
self
.
convert_loguniform_categorical
(
challenger
.
get_dictionary
())
def
generate_multiple_parameters
(
self
,
parameter_id_list
):
def
generate_multiple_parameters
(
self
,
parameter_id_list
,
**
kwargs
):
"""generate mutiple instances of hyperparameters
Parameters
...
...
src/sdk/pynni/nni/tuner.py
View file @
1c56fea8
...
...
@@ -30,14 +30,14 @@ _logger = logging.getLogger(__name__)
class
Tuner
(
Recoverable
):
# pylint: disable=no-self-use,unused-argument
def
generate_parameters
(
self
,
parameter_id
):
def
generate_parameters
(
self
,
parameter_id
,
**
kwargs
):
"""Returns a set of trial (hyper-)parameters, as a serializable object.
User code must override either this function or 'generate_multiple_parameters()'.
parameter_id: int
"""
raise
NotImplementedError
(
'Tuner: generate_parameters not implemented'
)
def
generate_multiple_parameters
(
self
,
parameter_id_list
):
def
generate_multiple_parameters
(
self
,
parameter_id_list
,
**
kwargs
):
"""Returns multiple sets of trial (hyper-)parameters, as iterable of serializable objects.
Call 'generate_parameters()' by 'count' times by default.
User code must override either this function or 'generate_parameters()'.
...
...
@@ -49,13 +49,13 @@ class Tuner(Recoverable):
for
parameter_id
in
parameter_id_list
:
try
:
_logger
.
debug
(
"generating param for {}"
.
format
(
parameter_id
))
res
=
self
.
generate_parameters
(
parameter_id
)
res
=
self
.
generate_parameters
(
parameter_id
,
**
kwargs
)
except
nni
.
NoMoreTrialError
:
return
result
result
.
append
(
res
)
return
result
def
receive_trial_result
(
self
,
parameter_id
,
parameters
,
value
):
def
receive_trial_result
(
self
,
parameter_id
,
parameters
,
value
,
**
kwargs
):
"""Invoked when a trial reports its final result. Must override.
parameter_id: int
parameters: object created by 'generate_parameters()'
...
...
@@ -63,7 +63,7 @@ class Tuner(Recoverable):
"""
raise
NotImplementedError
(
'Tuner: receive_trial_result not implemented'
)
def
receive_customized_trial_result
(
self
,
parameter_id
,
parameters
,
value
):
def
receive_customized_trial_result
(
self
,
parameter_id
,
parameters
,
value
,
**
kwargs
):
"""Invoked when a trial added by WebUI reports its final result. Do nothing by default.
parameter_id: int
parameters: object created by user
...
...
@@ -71,7 +71,7 @@ class Tuner(Recoverable):
"""
_logger
.
info
(
'Customized trial job %s ignored by tuner'
,
parameter_id
)
def
trial_end
(
self
,
parameter_id
,
success
):
def
trial_end
(
self
,
parameter_id
,
success
,
**
kwargs
):
"""Invoked when a trial is completed or terminated. Do nothing by default.
parameter_id: int
success: True if the trial successfully completed; False if failed or terminated
...
...
src/sdk/pynni/tests/test_multi_phase_tuner.py
deleted
100644 → 0
View file @
12410686
# Copyright (c) Microsoft Corporation. All rights reserved.
#
# MIT License
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and
# associated documentation files (the "Software"), to deal in the Software without restriction,
# including without limitation the rights to use, copy, modify, merge, publish, distribute,
# sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all copies or
# substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT
# NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT
# OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
# ==================================================================================================
import
logging
import
random
from
io
import
BytesIO
import
nni
import
nni.protocol
from
nni.protocol
import
CommandType
,
send
,
receive
from
nni.multi_phase.multi_phase_tuner
import
MultiPhaseTuner
from
nni.multi_phase.multi_phase_dispatcher
import
MultiPhaseMsgDispatcher
from
unittest
import
TestCase
,
main
class
NaiveMultiPhaseTuner
(
MultiPhaseTuner
):
'''
supports only choices
'''
def
__init__
(
self
):
self
.
search_space
=
None
def
generate_parameters
(
self
,
parameter_id
,
trial_job_id
=
None
):
"""Returns a set of trial (hyper-)parameters, as a serializable object.
User code must override either this function or 'generate_multiple_parameters()'.
parameter_id: int
"""
generated_parameters
=
{}
if
self
.
search_space
is
None
:
raise
AssertionError
(
'Search space not specified'
)
for
k
in
self
.
search_space
:
param
=
self
.
search_space
[
k
]
if
not
param
[
'_type'
]
==
'choice'
:
raise
ValueError
(
'Only choice type is supported'
)
param_values
=
param
[
'_value'
]
generated_parameters
[
k
]
=
param_values
[
random
.
randint
(
0
,
len
(
param_values
)
-
1
)]
logging
.
getLogger
(
__name__
).
debug
(
generated_parameters
)
return
generated_parameters
def
receive_trial_result
(
self
,
parameter_id
,
parameters
,
value
,
trial_job_id
):
logging
.
getLogger
(
__name__
).
debug
(
'receive_trial_result: {},{},{},{}'
.
format
(
parameter_id
,
parameters
,
value
,
trial_job_id
))
def
receive_customized_trial_result
(
self
,
parameter_id
,
parameters
,
value
,
trial_job_id
):
pass
def
update_search_space
(
self
,
search_space
):
self
.
search_space
=
search_space
_in_buf
=
BytesIO
()
_out_buf
=
BytesIO
()
def
_reverse_io
():
_in_buf
.
seek
(
0
)
_out_buf
.
seek
(
0
)
nni
.
protocol
.
_out_file
=
_in_buf
nni
.
protocol
.
_in_file
=
_out_buf
def
_restore_io
():
_in_buf
.
seek
(
0
)
_out_buf
.
seek
(
0
)
nni
.
protocol
.
_in_file
=
_in_buf
nni
.
protocol
.
_out_file
=
_out_buf
def
_test_tuner
():
_reverse_io
()
# now we are sending to Tuner's incoming stream
send
(
CommandType
.
UpdateSearchSpace
,
"{
\"
learning_rate
\"
: {
\"
_value
\"
: [0.0001, 0.001, 0.002, 0.005, 0.01],
\"
_type
\"
:
\"
choice
\"
},
\"
optimizer
\"
: {
\"
_value
\"
: [
\"
Adam
\"
,
\"
SGD
\"
],
\"
_type
\"
:
\"
choice
\"
}}"
)
send
(
CommandType
.
RequestTrialJobs
,
'2'
)
send
(
CommandType
.
ReportMetricData
,
'{"parameter_id":0,"type":"PERIODICAL","value":10,"trial_job_id":"abc"}'
)
send
(
CommandType
.
ReportMetricData
,
'{"parameter_id":1,"type":"FINAL","value":11,"trial_job_id":"abc"}'
)
send
(
CommandType
.
AddCustomizedTrialJob
,
'{"param":-1}'
)
send
(
CommandType
.
ReportMetricData
,
'{"parameter_id":2,"type":"FINAL","value":22,"trial_job_id":"abc"}'
)
send
(
CommandType
.
RequestTrialJobs
,
'1'
)
send
(
CommandType
.
TrialEnd
,
'{"trial_job_id":"abc"}'
)
_restore_io
()
tuner
=
NaiveMultiPhaseTuner
()
dispatcher
=
MultiPhaseMsgDispatcher
(
tuner
)
dispatcher
.
run
()
_reverse_io
()
# now we are receiving from Tuner's outgoing stream
command
,
data
=
receive
()
# this one is customized
print
(
command
,
data
)
class
MultiPhaseTestCase
(
TestCase
):
def
test_tuner
(
self
):
_test_tuner
()
if
__name__
==
'__main__'
:
main
()
\ No newline at end of file
src/sdk/pynni/tests/test_tuner.py
View file @
1c56fea8
...
...
@@ -35,7 +35,7 @@ class NaiveTuner(Tuner):
self
.
trial_results
=
[
]
self
.
search_space
=
None
def
generate_parameters
(
self
,
parameter_id
):
def
generate_parameters
(
self
,
parameter_id
,
**
kwargs
):
# report Tuner's internal states to generated parameters,
# so we don't need to pause the main loop
self
.
param
+=
2
...
...
@@ -45,7 +45,7 @@ class NaiveTuner(Tuner):
'search_space'
:
self
.
search_space
}
def
receive_trial_result
(
self
,
parameter_id
,
parameters
,
value
):
def
receive_trial_result
(
self
,
parameter_id
,
parameters
,
value
,
**
kwargs
):
reward
=
extract_scalar_reward
(
value
)
self
.
trial_results
.
append
((
parameter_id
,
parameters
[
'param'
],
reward
,
False
))
...
...
@@ -103,11 +103,9 @@ class TunerTestCase(TestCase):
command
,
data
=
receive
()
# this one is customized
data
=
json
.
loads
(
data
)
self
.
assertIs
(
command
,
CommandType
.
NewTrialJob
)
self
.
assertEqual
(
data
,
{
'parameter_id'
:
2
,
'parameter_source'
:
'customized'
,
'parameters'
:
{
'param'
:
-
1
}
})
self
.
assertEqual
(
data
[
'parameter_id'
],
2
)
self
.
assertEqual
(
data
[
'parameter_source'
],
'customized'
)
self
.
assertEqual
(
data
[
'parameters'
],
{
'param'
:
-
1
})
self
.
_assert_params
(
3
,
6
,
[[
1
,
4
,
11
,
False
],
[
2
,
-
1
,
22
,
True
]],
{
'name'
:
'SS0'
})
...
...
Prev
1
2
3
4
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment