Unverified Commit 24fa4619 authored by QuanluZhang's avatar QuanluZhang Committed by GitHub
Browse files

Merge pull request #2081 from microsoft/v1.4

merge V1.4 back to master
parents aaaa2756 8ff039c2
...@@ -380,6 +380,8 @@ class Hyperband(MsgDispatcherBase): ...@@ -380,6 +380,8 @@ class Hyperband(MsgDispatcherBase):
ValueError ValueError
Data type not supported Data type not supported
""" """
if 'value' in data:
data['value'] = json_tricks.loads(data['value'])
if data['type'] == MetricType.REQUEST_PARAMETER: if data['type'] == MetricType.REQUEST_PARAMETER:
assert multi_phase_enabled() assert multi_phase_enabled()
assert data['trial_job_id'] is not None assert data['trial_job_id'] is not None
......
...@@ -113,6 +113,8 @@ class MsgDispatcher(MsgDispatcherBase): ...@@ -113,6 +113,8 @@ class MsgDispatcher(MsgDispatcherBase):
"""Import additional data for tuning """Import additional data for tuning
data: a list of dictionaries, each of which has at least two keys, 'parameter' and 'value' data: a list of dictionaries, each of which has at least two keys, 'parameter' and 'value'
""" """
for entry in data:
entry['value'] = json_tricks.loads(entry['value'])
self.tuner.import_data(data) self.tuner.import_data(data)
def handle_add_customized_trial(self, data): def handle_add_customized_trial(self, data):
...@@ -127,6 +129,9 @@ class MsgDispatcher(MsgDispatcherBase): ...@@ -127,6 +129,9 @@ class MsgDispatcher(MsgDispatcherBase):
- 'value': metric value reported by nni.report_final_result() - 'value': metric value reported by nni.report_final_result()
- 'type': report type, support {'FINAL', 'PERIODICAL'} - 'type': report type, support {'FINAL', 'PERIODICAL'}
""" """
# metrics value is dumped as json string in trial, so we need to decode it here
if 'value' in data:
data['value'] = json_tricks.loads(data['value'])
if data['type'] == MetricType.FINAL: if data['type'] == MetricType.FINAL:
self._handle_final_metric_data(data) self._handle_final_metric_data(data)
elif data['type'] == MetricType.PERIODICAL: elif data['type'] == MetricType.PERIODICAL:
......
...@@ -116,8 +116,6 @@ class AverageMeter: ...@@ -116,8 +116,6 @@ class AverageMeter:
n : int n : int
The weight of the new value. The weight of the new value.
""" """
if not isinstance(val, float) and not isinstance(val, int):
_logger.warning("Values passed to AverageMeter must be number, not %s.", type(val))
self.val = val self.val = val
self.sum += val * n self.sum += val * n
self.count += n self.count += n
......
...@@ -33,4 +33,7 @@ def init_params(params): ...@@ -33,4 +33,7 @@ def init_params(params):
_params = copy.deepcopy(params) _params = copy.deepcopy(params)
def get_last_metric(): def get_last_metric():
return json_tricks.loads(_last_metric) metrics = json_tricks.loads(_last_metric)
metrics['value'] = json_tricks.loads(metrics['value'])
return metrics
...@@ -114,7 +114,7 @@ def report_intermediate_result(metric): ...@@ -114,7 +114,7 @@ def report_intermediate_result(metric):
'trial_job_id': trial_env_vars.NNI_TRIAL_JOB_ID, 'trial_job_id': trial_env_vars.NNI_TRIAL_JOB_ID,
'type': 'PERIODICAL', 'type': 'PERIODICAL',
'sequence': _intermediate_seq, 'sequence': _intermediate_seq,
'value': metric 'value': to_json(metric)
}) })
_intermediate_seq += 1 _intermediate_seq += 1
platform.send_metric(metric) platform.send_metric(metric)
...@@ -135,6 +135,6 @@ def report_final_result(metric): ...@@ -135,6 +135,6 @@ def report_final_result(metric):
'trial_job_id': trial_env_vars.NNI_TRIAL_JOB_ID, 'trial_job_id': trial_env_vars.NNI_TRIAL_JOB_ID,
'type': 'FINAL', 'type': 'FINAL',
'sequence': 0, 'sequence': 0,
'value': metric 'value': to_json(metric)
}) })
platform.send_metric(metric) platform.send_metric(metric)
...@@ -47,9 +47,9 @@ def _restore_io(): ...@@ -47,9 +47,9 @@ def _restore_io():
class AssessorTestCase(TestCase): class AssessorTestCase(TestCase):
def test_assessor(self): def test_assessor(self):
_reverse_io() _reverse_io()
send(CommandType.ReportMetricData, '{"trial_job_id":"A","type":"PERIODICAL","sequence":0,"value":2}') send(CommandType.ReportMetricData, '{"trial_job_id":"A","type":"PERIODICAL","sequence":0,"value":"2"}')
send(CommandType.ReportMetricData, '{"trial_job_id":"B","type":"PERIODICAL","sequence":0,"value":2}') send(CommandType.ReportMetricData, '{"trial_job_id":"B","type":"PERIODICAL","sequence":0,"value":"2"}')
send(CommandType.ReportMetricData, '{"trial_job_id":"A","type":"PERIODICAL","sequence":1,"value":3}') send(CommandType.ReportMetricData, '{"trial_job_id":"A","type":"PERIODICAL","sequence":1,"value":"3"}')
send(CommandType.TrialEnd, '{"trial_job_id":"A","event":"SYS_CANCELED"}') send(CommandType.TrialEnd, '{"trial_job_id":"A","event":"SYS_CANCELED"}')
send(CommandType.TrialEnd, '{"trial_job_id":"B","event":"SUCCEEDED"}') send(CommandType.TrialEnd, '{"trial_job_id":"B","event":"SUCCEEDED"}')
send(CommandType.NewTrialJob, 'null') send(CommandType.NewTrialJob, 'null')
......
...@@ -59,8 +59,8 @@ class MsgDispatcherTestCase(TestCase): ...@@ -59,8 +59,8 @@ class MsgDispatcherTestCase(TestCase):
def test_msg_dispatcher(self): def test_msg_dispatcher(self):
_reverse_io() # now we are sending to Tuner's incoming stream _reverse_io() # now we are sending to Tuner's incoming stream
send(CommandType.RequestTrialJobs, '2') send(CommandType.RequestTrialJobs, '2')
send(CommandType.ReportMetricData, '{"parameter_id":0,"type":"PERIODICAL","value":10}') send(CommandType.ReportMetricData, '{"parameter_id":0,"type":"PERIODICAL","value":"10"}')
send(CommandType.ReportMetricData, '{"parameter_id":1,"type":"FINAL","value":11}') send(CommandType.ReportMetricData, '{"parameter_id":1,"type":"FINAL","value":"11"}')
send(CommandType.UpdateSearchSpace, '{"name":"SS0"}') send(CommandType.UpdateSearchSpace, '{"name":"SS0"}')
send(CommandType.RequestTrialJobs, '1') send(CommandType.RequestTrialJobs, '1')
send(CommandType.KillTrialJob, 'null') send(CommandType.KillTrialJob, 'null')
......
...@@ -117,7 +117,6 @@ class ChangeColumnComponent extends React.Component<ChangeColumnProps, ChangeCol ...@@ -117,7 +117,6 @@ class ChangeColumnComponent extends React.Component<ChangeColumnProps, ChangeCol
}); });
return ( return (
<div> <div>
<div>Hello</div>
<Dialog <Dialog
hidden={isHideDialog} // required field! hidden={isHideDialog} // required field!
dialogContentProps={{ dialogContentProps={{
...@@ -130,7 +129,7 @@ class ChangeColumnComponent extends React.Component<ChangeColumnProps, ChangeCol ...@@ -130,7 +129,7 @@ class ChangeColumnComponent extends React.Component<ChangeColumnProps, ChangeCol
styles: { main: { maxWidth: 450 } } styles: { main: { maxWidth: 450 } }
}} }}
> >
<div> <div className="columns-height">
{renderOptions.map(item => { {renderOptions.map(item => {
return <Checkbox key={item.label} {...item} styles={{ root: { marginBottom: 8 } }} /> return <Checkbox key={item.label} {...item} styles={{ root: { marginBottom: 8 } }} />
})} })}
......
...@@ -172,7 +172,7 @@ class NavCon extends React.Component<NavProps, NavState> { ...@@ -172,7 +172,7 @@ class NavCon extends React.Component<NavProps, NavState> {
/> />
<CommandBarButton <CommandBarButton
iconProps={infoIconAbout} iconProps={infoIconAbout}
text="about" text="About"
menuProps={aboutProps} menuProps={aboutProps}
/> />
</Stack> </Stack>
......
...@@ -56,7 +56,7 @@ class TrialsDetail extends React.Component<TrialsDetailProps, TrialDetailState> ...@@ -56,7 +56,7 @@ class TrialsDetail extends React.Component<TrialsDetailProps, TrialDetailState>
return; return;
} }
switch (this.state.searchType) { switch (this.state.searchType) {
case 'id': case 'Id':
filter = (trial): boolean => trial.info.id.toUpperCase().includes(targetValue.toUpperCase()); filter = (trial): boolean => trial.info.id.toUpperCase().includes(targetValue.toUpperCase());
break; break;
case 'Trial No.': case 'Trial No.':
...@@ -65,7 +65,7 @@ class TrialsDetail extends React.Component<TrialsDetailProps, TrialDetailState> ...@@ -65,7 +65,7 @@ class TrialsDetail extends React.Component<TrialsDetailProps, TrialDetailState>
case 'Status': case 'Status':
filter = (trial): boolean => trial.info.status.toUpperCase().includes(targetValue.toUpperCase()); filter = (trial): boolean => trial.info.status.toUpperCase().includes(targetValue.toUpperCase());
break; break;
case 'parameters': case 'Parameters':
// TODO: support filters like `x: 2` (instead of `"x": 2`) // TODO: support filters like `x: 2` (instead of `"x": 2`)
filter = (trial): boolean => JSON.stringify(trial.info.hyperParameters, null, 4).includes(targetValue); filter = (trial): boolean => JSON.stringify(trial.info.hyperParameters, null, 4).includes(targetValue);
break; break;
......
...@@ -54,7 +54,7 @@ interface TableListState { ...@@ -54,7 +54,7 @@ interface TableListState {
isShowCustomizedModal: boolean; isShowCustomizedModal: boolean;
copyTrialId: string; // user copy trial to submit a new customized trial copyTrialId: string; // user copy trial to submit a new customized trial
isCalloutVisible: boolean; // kill job button callout [kill or not kill job window] isCalloutVisible: boolean; // kill job button callout [kill or not kill job window]
intermediateKeys: string[]; // intermeidate modal: which key is choosed. intermediateKey: string; // intermeidate modal: which key is choosed.
isExpand: boolean; isExpand: boolean;
modalIntermediateWidth: number; modalIntermediateWidth: number;
modalIntermediateHeight: number; modalIntermediateHeight: number;
...@@ -86,7 +86,7 @@ class TableList extends React.Component<TableListProps, TableListState> { ...@@ -86,7 +86,7 @@ class TableList extends React.Component<TableListProps, TableListState> {
isShowCustomizedModal: false, isShowCustomizedModal: false,
isCalloutVisible: false, isCalloutVisible: false,
copyTrialId: '', copyTrialId: '',
intermediateKeys: ['default'], intermediateKey: 'default',
isExpand: false, isExpand: false,
modalIntermediateWidth: window.innerWidth, modalIntermediateWidth: window.innerWidth,
modalIntermediateHeight: window.innerHeight, modalIntermediateHeight: window.innerHeight,
...@@ -128,7 +128,7 @@ class TableList extends React.Component<TableListProps, TableListState> { ...@@ -128,7 +128,7 @@ class TableList extends React.Component<TableListProps, TableListState> {
name: 'Default metric', name: 'Default metric',
className: 'leftTitle', className: 'leftTitle',
key: 'accuracy', key: 'accuracy',
fieldName: 'accuracy', fieldName: 'latestAccuracy',
minWidth: 200, minWidth: 200,
maxWidth: 300, maxWidth: 300,
isResizable: true, isResizable: true,
...@@ -294,7 +294,7 @@ class TableList extends React.Component<TableListProps, TableListState> { ...@@ -294,7 +294,7 @@ class TableList extends React.Component<TableListProps, TableListState> {
const intermediate = intermediateGraphOption(intermediateArr, intermediateId); const intermediate = intermediateGraphOption(intermediateArr, intermediateId);
// re-render // re-render
this.setState({ this.setState({
intermediateKeys: [value], intermediateKey: value,
intermediateOption: intermediate intermediateOption: intermediate
}); });
} }
...@@ -388,29 +388,27 @@ class TableList extends React.Component<TableListProps, TableListState> { ...@@ -388,29 +388,27 @@ class TableList extends React.Component<TableListProps, TableListState> {
parameterStr.push(`${value} (search space)`); parameterStr.push(`${value} (search space)`);
}); });
} }
let allColumnList = COLUMNPro; // eslint-disable-line @typescript-eslint/no-unused-vars let allColumnList = COLUMNPro.concat(parameterStr);
allColumnList = COLUMNPro.concat(parameterStr);
// only succeed trials have final keys // only succeed trials have final keys
if (tableSource.filter(record => record.status === 'SUCCEEDED').length >= 1) { if (tableSource.filter(record => record.status === 'SUCCEEDED').length >= 1) {
const temp = tableSource.filter(record => record.status === 'SUCCEEDED')[0].accuracy; const temp = tableSource.filter(record => record.status === 'SUCCEEDED')[0].accDictionary;
if (temp !== undefined && typeof temp === 'object') { if (temp !== undefined && typeof temp === 'object') {
if (!isNaN(temp)) { // concat default column and finalkeys
// concat default column and finalkeys const item = Object.keys(temp);
const item = Object.keys(temp); // item: ['default', 'other-keys', 'maybe loss']
// item: ['default', 'other-keys', 'maybe loss'] if (item.length > 1) {
if (item.length > 1) { const want: string[] = [];
const want: string[] = []; item.forEach(value => {
item.forEach(value => { if (value !== 'default') {
if (value !== 'default') { want.push(value);
want.push(value); }
} });
}); allColumnList = allColumnList.concat(want);
allColumnList = COLUMNPro.concat(want);
}
} }
} }
} }
return allColumnList; return allColumnList;
} }
...@@ -522,8 +520,22 @@ class TableList extends React.Component<TableListProps, TableListState> { ...@@ -522,8 +520,22 @@ class TableList extends React.Component<TableListProps, TableListState> {
}); });
break; break;
default: default:
// FIXME showColumn.push({
alert('Unexpected column type'); name: item,
key: item,
fieldName: item,
minWidth: 100,
onRender: (record: TableRecord) => {
const accDictionary = record.accDictionary;
let other = '';
if (accDictionary !== undefined) {
other = accDictionary[item].toString();
}
return (
<div>{other}</div>
);
}
});
} }
} }
return showColumn; return showColumn;
...@@ -534,19 +546,22 @@ class TableList extends React.Component<TableListProps, TableListState> { ...@@ -534,19 +546,22 @@ class TableList extends React.Component<TableListProps, TableListState> {
} }
UNSAFE_componentWillReceiveProps(nextProps: TableListProps): void { UNSAFE_componentWillReceiveProps(nextProps: TableListProps): void {
const { columnList } = nextProps; const { columnList, tableSource } = nextProps;
this.setState({ tableColumns: this.initTableColumnList(columnList) }); this.setState({
tableSourceForSort: tableSource,
tableColumns: this.initTableColumnList(columnList),
allColumnList: this.getAllColumnKeys()
});
} }
render(): React.ReactNode { render(): React.ReactNode {
const { intermediateKeys, modalIntermediateWidth, modalIntermediateHeight, const { intermediateKey, modalIntermediateWidth, modalIntermediateHeight,
tableColumns, allColumnList, isShowColumn, modalVisible, tableColumns, allColumnList, isShowColumn, modalVisible,
selectRows, isShowCompareModal, intermediateOtherKeys, selectRows, isShowCompareModal, intermediateOtherKeys,
isShowCustomizedModal, copyTrialId, intermediateOption isShowCustomizedModal, copyTrialId, intermediateOption
} = this.state; } = this.state;
const { columnList } = this.props; const { columnList } = this.props;
const tableSource: Array<TableRecord> = JSON.parse(JSON.stringify(this.state.tableSourceForSort)); const tableSource: Array<TableRecord> = JSON.parse(JSON.stringify(this.state.tableSourceForSort));
return ( return (
<Stack> <Stack>
<div id="tableList"> <div id="tableList">
...@@ -580,11 +595,10 @@ class TableList extends React.Component<TableListProps, TableListState> { ...@@ -580,11 +595,10 @@ class TableList extends React.Component<TableListProps, TableListState> {
{ {
intermediateOtherKeys.length > 1 intermediateOtherKeys.length > 1
? ?
<Stack className="selectKeys" styles={{ root: { width: 800 } }}> <Stack horizontalAlign="end" className="selectKeys">
<Dropdown <Dropdown
className="select" className="select"
selectedKeys={intermediateKeys} selectedKey={intermediateKey}
onChange={this.selectOtherKeys}
options={ options={
intermediateOtherKeys.map((key, item) => { intermediateOtherKeys.map((key, item) => {
return { return {
...@@ -592,7 +606,7 @@ class TableList extends React.Component<TableListProps, TableListState> { ...@@ -592,7 +606,7 @@ class TableList extends React.Component<TableListProps, TableListState> {
}; };
}) })
} }
styles={{ dropdown: { width: 300 } }} onChange={this.selectOtherKeys}
/> />
</Stack> </Stack>
: :
......
...@@ -37,13 +37,21 @@ const convertDuration = (num: number): string => { ...@@ -37,13 +37,21 @@ const convertDuration = (num: number): string => {
return result.join(' '); return result.join(' ');
}; };
function parseMetrics(metricData: string): any {
if (metricData.includes('NaN')) {
return JSON5.parse(JSON5.parse(metricData));
} else {
return JSON.parse(JSON.parse(metricData));
}
}
// get final result value // get final result value
// draw Accuracy point graph // draw Accuracy point graph
const getFinalResult = (final?: MetricDataRecord[]): number => { const getFinalResult = (final?: MetricDataRecord[]): number => {
let acc; let acc;
let showDefault = 0; let showDefault = 0;
if (final) { if (final) {
acc = JSON.parse(final[final.length - 1].data); acc = parseMetrics(final[final.length - 1].data);
if (typeof (acc) === 'object') { if (typeof (acc) === 'object') {
if (acc.default) { if (acc.default) {
showDefault = acc.default; showDefault = acc.default;
...@@ -61,7 +69,7 @@ const getFinalResult = (final?: MetricDataRecord[]): number => { ...@@ -61,7 +69,7 @@ const getFinalResult = (final?: MetricDataRecord[]): number => {
const getFinal = (final?: MetricDataRecord[]): FinalType | undefined => { const getFinal = (final?: MetricDataRecord[]): FinalType | undefined => {
let showDefault: FinalType; let showDefault: FinalType;
if (final) { if (final) {
showDefault = JSON.parse(final[final.length - 1].data); showDefault = parseMetrics(final[final.length - 1].data);
if (typeof showDefault === 'number') { if (typeof showDefault === 'number') {
showDefault = { default: showDefault }; showDefault = { default: showDefault };
} }
...@@ -179,17 +187,14 @@ function formatTimestamp(timestamp?: number, placeholder?: string): string { ...@@ -179,17 +187,14 @@ function formatTimestamp(timestamp?: number, placeholder?: string): string {
return timestamp ? new Date(timestamp).toLocaleString('en-US') : placeholder; return timestamp ? new Date(timestamp).toLocaleString('en-US') : placeholder;
} }
function parseMetrics(metricData: string): any {
if (metricData.includes('NaN')) {
return JSON5.parse(metricData)
} else {
return JSON.parse(metricData)
}
}
function metricAccuracy(metric: MetricDataRecord): number { function metricAccuracy(metric: MetricDataRecord): number {
const data = parseMetrics(metric.data); const data = parseMetrics(metric.data);
return typeof data === 'number' ? data : NaN; // return typeof data === 'number' ? data : NaN;
if (typeof data === 'number') {
return data;
} else {
return data.default;
}
} }
function formatAccuracy(accuracy: number): string { function formatAccuracy(accuracy: number): string {
......
...@@ -23,7 +23,8 @@ interface TableRecord { ...@@ -23,7 +23,8 @@ interface TableRecord {
intermediateCount: number; intermediateCount: number;
accuracy?: number; accuracy?: number;
latestAccuracy: number | undefined; latestAccuracy: number | undefined;
formattedLatestAccuracy: string; // format (LATEST/FINAL) formattedLatestAccuracy: string; // format (LATEST/FINAL),
accDictionary: FinalType | undefined;
} }
interface SearchSpace { interface SearchSpace {
......
...@@ -53,10 +53,13 @@ class Trial implements TableObj { ...@@ -53,10 +53,13 @@ class Trial implements TableObj {
if (this.accuracy !== undefined) { if (this.accuracy !== undefined) {
return this.accuracy; return this.accuracy;
} else if (this.intermediates.length > 0) { } else if (this.intermediates.length > 0) {
// TODO: support intermeidate result is dict
const temp = this.intermediates[this.intermediates.length - 1]; const temp = this.intermediates[this.intermediates.length - 1];
if (temp !== undefined) { if (temp !== undefined) {
return parseMetrics(temp.data); if (typeof parseMetrics(temp.data) === 'object') {
return parseMetrics(temp.data).default;
} else {
return parseMetrics(temp.data);
}
} else { } else {
return undefined; return undefined;
} }
...@@ -82,9 +85,11 @@ class Trial implements TableObj { ...@@ -82,9 +85,11 @@ class Trial implements TableObj {
duration, duration,
status: this.info.status, status: this.info.status,
intermediateCount: this.intermediates.length, intermediateCount: this.intermediates.length,
accuracy: this.finalAcc, // eslint-disable-next-line @typescript-eslint/no-non-null-assertion
accuracy: this.acc !== undefined ? JSON.parse(this.acc!.default) : undefined,
latestAccuracy: this.latestAccuracy, latestAccuracy: this.latestAccuracy,
formattedLatestAccuracy: this.formatLatestAccuracy(), formattedLatestAccuracy: this.formatLatestAccuracy(),
accDictionary: this.acc
}; };
} }
......
...@@ -27,9 +27,10 @@ ...@@ -27,9 +27,10 @@
.selectKeys{ .selectKeys{
/* intermediate result is dict, select box for keys */ /* intermediate result is dict, select box for keys */
.select{ .select{
margin-right: 12%;
}
.ms-Dropdown{
width: 120px; width: 120px;
float: right;
margin-right: 10%;
} }
} }
...@@ -46,4 +46,8 @@ ...@@ -46,4 +46,8 @@
} }
.detail-table{ .detail-table{
padding: 5px 0 0 0; padding: 5px 0 0 0;
} }
\ No newline at end of file .columns-height{
max-height: 335px;
overflow-y: scroll;
}
...@@ -56,9 +56,9 @@ def get_metric_results(metrics): ...@@ -56,9 +56,9 @@ def get_metric_results(metrics):
final_result = [] final_result = []
for metric in metrics: for metric in metrics:
if metric['type'] == 'PERIODICAL': if metric['type'] == 'PERIODICAL':
intermediate_result.append(metric['data']) intermediate_result.append(json.loads(metric['data']))
elif metric['type'] == 'FINAL': elif metric['type'] == 'FINAL':
final_result.append(metric['data']) final_result.append(json.loads(metric['data']))
print(intermediate_result, final_result) print(intermediate_result, final_result)
return [round(float(x),6) for x in intermediate_result], [round(float(x), 6) for x in final_result] return [round(float(x),6) for x in intermediate_result], [round(float(x), 6) for x in final_result]
......
...@@ -59,7 +59,7 @@ jobs: ...@@ -59,7 +59,7 @@ jobs:
displayName: 'integration test' displayName: 'integration test'
- task: SSH@0 - task: SSH@0
inputs: inputs:
sshEndpoint: remote_nni-ci-gpu-01 sshEndpoint: $(end_point)
runOptions: commands runOptions: commands
commands: python3 /tmp/nnitest/$(Build.BuildId)/test/remote_docker.py --mode stop --name $(Build.BuildId) commands: python3 /tmp/nnitest/$(Build.BuildId)/test/remote_docker.py --mode stop --name $(Build.BuildId)
displayName: 'Stop docker' displayName: 'Stop docker'
...@@ -78,17 +78,17 @@ def get_nni_installation_path(): ...@@ -78,17 +78,17 @@ def get_nni_installation_path():
print_error('Fail to find nni under python library') print_error('Fail to find nni under python library')
exit(1) exit(1)
def start_rest_server(args, platform, mode, config_file_name, experiment_id=None, log_dir=None, log_level=None): def start_rest_server(port, platform, mode, config_file_name, foreground=False, experiment_id=None, log_dir=None, log_level=None):
'''Run nni manager process''' '''Run nni manager process'''
if detect_port(args.port): if detect_port(port):
print_error('Port %s is used by another process, please reset the port!\n' \ print_error('Port %s is used by another process, please reset the port!\n' \
'You could use \'nnictl create --help\' to get help information' % args.port) 'You could use \'nnictl create --help\' to get help information' % port)
exit(1) exit(1)
if (platform != 'local') and detect_port(int(args.port) + 1): if (platform != 'local') and detect_port(int(port) + 1):
print_error('PAI mode need an additional adjacent port %d, and the port %d is used by another process!\n' \ print_error('PAI mode need an additional adjacent port %d, and the port %d is used by another process!\n' \
'You could set another port to start experiment!\n' \ 'You could set another port to start experiment!\n' \
'You could use \'nnictl create --help\' to get help information' % ((int(args.port) + 1), (int(args.port) + 1))) 'You could use \'nnictl create --help\' to get help information' % ((int(port) + 1), (int(port) + 1)))
exit(1) exit(1)
print_normal('Starting restful server...') print_normal('Starting restful server...')
...@@ -99,7 +99,7 @@ def start_rest_server(args, platform, mode, config_file_name, experiment_id=None ...@@ -99,7 +99,7 @@ def start_rest_server(args, platform, mode, config_file_name, experiment_id=None
node_command = 'node' node_command = 'node'
if sys.platform == 'win32': if sys.platform == 'win32':
node_command = os.path.join(entry_dir[:-3], 'Scripts', 'node.exe') node_command = os.path.join(entry_dir[:-3], 'Scripts', 'node.exe')
cmds = [node_command, entry_file, '--port', str(args.port), '--mode', platform] cmds = [node_command, entry_file, '--port', str(port), '--mode', platform]
if mode == 'view': if mode == 'view':
cmds += ['--start_mode', 'resume'] cmds += ['--start_mode', 'resume']
cmds += ['--readonly', 'true'] cmds += ['--readonly', 'true']
...@@ -111,7 +111,7 @@ def start_rest_server(args, platform, mode, config_file_name, experiment_id=None ...@@ -111,7 +111,7 @@ def start_rest_server(args, platform, mode, config_file_name, experiment_id=None
cmds += ['--log_level', log_level] cmds += ['--log_level', log_level]
if mode in ['resume', 'view']: if mode in ['resume', 'view']:
cmds += ['--experiment_id', experiment_id] cmds += ['--experiment_id', experiment_id]
if args.foreground: if foreground:
cmds += ['--foreground', 'true'] cmds += ['--foreground', 'true']
stdout_full_path, stderr_full_path = get_log_path(config_file_name) stdout_full_path, stderr_full_path = get_log_path(config_file_name)
with open(stdout_full_path, 'a+') as stdout_file, open(stderr_full_path, 'a+') as stderr_file: with open(stdout_full_path, 'a+') as stdout_file, open(stderr_full_path, 'a+') as stderr_file:
...@@ -122,12 +122,12 @@ def start_rest_server(args, platform, mode, config_file_name, experiment_id=None ...@@ -122,12 +122,12 @@ def start_rest_server(args, platform, mode, config_file_name, experiment_id=None
stderr_file.write(log_header) stderr_file.write(log_header)
if sys.platform == 'win32': if sys.platform == 'win32':
from subprocess import CREATE_NEW_PROCESS_GROUP from subprocess import CREATE_NEW_PROCESS_GROUP
if args.foreground: if foreground:
process = Popen(cmds, cwd=entry_dir, stdout=PIPE, stderr=STDOUT, creationflags=CREATE_NEW_PROCESS_GROUP) process = Popen(cmds, cwd=entry_dir, stdout=PIPE, stderr=STDOUT, creationflags=CREATE_NEW_PROCESS_GROUP)
else: else:
process = Popen(cmds, cwd=entry_dir, stdout=stdout_file, stderr=stderr_file, creationflags=CREATE_NEW_PROCESS_GROUP) process = Popen(cmds, cwd=entry_dir, stdout=stdout_file, stderr=stderr_file, creationflags=CREATE_NEW_PROCESS_GROUP)
else: else:
if args.foreground: if foreground:
process = Popen(cmds, cwd=entry_dir, stdout=PIPE, stderr=PIPE) process = Popen(cmds, cwd=entry_dir, stdout=PIPE, stderr=PIPE)
else: else:
process = Popen(cmds, cwd=entry_dir, stdout=stdout_file, stderr=stderr_file) process = Popen(cmds, cwd=entry_dir, stdout=stdout_file, stderr=stderr_file)
...@@ -428,12 +428,14 @@ def launch_experiment(args, experiment_config, mode, config_file_name, experimen ...@@ -428,12 +428,14 @@ def launch_experiment(args, experiment_config, mode, config_file_name, experimen
log_dir = experiment_config['logDir'] if experiment_config.get('logDir') else None log_dir = experiment_config['logDir'] if experiment_config.get('logDir') else None
log_level = experiment_config['logLevel'] if experiment_config.get('logLevel') else None log_level = experiment_config['logLevel'] if experiment_config.get('logLevel') else None
#view experiment mode do not need debug function, when view an experiment, there will be no new logs created #view experiment mode do not need debug function, when view an experiment, there will be no new logs created
foreground = False
if mode != 'view': if mode != 'view':
foreground = args.foreground
if log_level not in ['trace', 'debug'] and (args.debug or experiment_config.get('debug') is True): if log_level not in ['trace', 'debug'] and (args.debug or experiment_config.get('debug') is True):
log_level = 'debug' log_level = 'debug'
# start rest server # start rest server
rest_process, start_time = start_rest_server(args, experiment_config['trainingServicePlatform'], \ rest_process, start_time = start_rest_server(args.port, experiment_config['trainingServicePlatform'], \
mode, config_file_name, experiment_id, log_dir, log_level) mode, config_file_name, foreground, experiment_id, log_dir, log_level)
nni_config.set_config('restServerPid', rest_process.pid) nni_config.set_config('restServerPid', rest_process.pid)
# Deal with annotation # Deal with annotation
if experiment_config.get('useAnnotation'): if experiment_config.get('useAnnotation'):
...@@ -501,7 +503,7 @@ def launch_experiment(args, experiment_config, mode, config_file_name, experimen ...@@ -501,7 +503,7 @@ def launch_experiment(args, experiment_config, mode, config_file_name, experimen
experiment_config['experimentName']) experiment_config['experimentName'])
print_normal(EXPERIMENT_SUCCESS_INFO % (experiment_id, ' '.join(web_ui_url_list))) print_normal(EXPERIMENT_SUCCESS_INFO % (experiment_id, ' '.join(web_ui_url_list)))
if args.foreground: if mode != 'view' and args.foreground:
try: try:
while True: while True:
log_content = rest_process.stdout.readline().strip().decode('utf-8') log_content = rest_process.stdout.readline().strip().decode('utf-8')
......
...@@ -63,10 +63,10 @@ def parse_args(): ...@@ -63,10 +63,10 @@ def parse_args():
parser_resume.set_defaults(func=resume_experiment) parser_resume.set_defaults(func=resume_experiment)
# parse view command # parse view command
parser_resume = subparsers.add_parser('view', help='view a stopped experiment') parser_view = subparsers.add_parser('view', help='view a stopped experiment')
parser_resume.add_argument('id', nargs='?', help='The id of the experiment you want to view') parser_view.add_argument('id', nargs='?', help='The id of the experiment you want to view')
parser_resume.add_argument('--port', '-p', default=DEFAULT_REST_PORT, dest='port', help='the port of restful server') parser_view.add_argument('--port', '-p', default=DEFAULT_REST_PORT, dest='port', help='the port of restful server')
parser_resume.set_defaults(func=view_experiment) parser_view.set_defaults(func=view_experiment)
# parse update command # parse update command
parser_updater = subparsers.add_parser('update', help='update the experiment') parser_updater = subparsers.add_parser('update', help='update the experiment')
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment