Unverified Commit 7003603c authored by liuzhe-lz's avatar liuzhe-lz Committed by GitHub
Browse files

Use correct Python interpreter in NNI manager (#4840)

parent 314255d9
...@@ -41,6 +41,7 @@ class NniManagerArgs: ...@@ -41,6 +41,7 @@ class NniManagerArgs:
foreground: bool = False foreground: bool = False
url_prefix: str | None = None # leading and trailing "/" must be stripped url_prefix: str | None = None # leading and trailing "/" must be stripped
tuner_command_channel: str | None = None tuner_command_channel: str | None = None
python_interpreter: str
def __init__(self, def __init__(self,
action: Literal['create', 'resume', 'view'], action: Literal['create', 'resume', 'view'],
...@@ -60,6 +61,7 @@ class NniManagerArgs: ...@@ -60,6 +61,7 @@ class NniManagerArgs:
# see "ts/nni_manager/common/globals/arguments.ts" for details # see "ts/nni_manager/common/globals/arguments.ts" for details
self.experiments_directory = cast(str, config.experiment_working_directory) self.experiments_directory = cast(str, config.experiment_working_directory)
self.tuner_command_channel = tuner_command_channel self.tuner_command_channel = tuner_command_channel
self.python_interpreter = sys.executable
if isinstance(config.training_service, list): if isinstance(config.training_service, list):
self.mode = 'hybrid' self.mode = 'hybrid'
......
...@@ -29,6 +29,7 @@ export interface NniManagerArgs { ...@@ -29,6 +29,7 @@ export interface NniManagerArgs {
readonly foreground: boolean; readonly foreground: boolean;
readonly urlPrefix: string; // leading and trailing "/" must be stripped readonly urlPrefix: string; // leading and trailing "/" must be stripped
readonly tunerCommandChannel: string | null; readonly tunerCommandChannel: string | null;
readonly pythonInterpreter: string;
// these are planned to be removed // these are planned to be removed
readonly mode: string; readonly mode: string;
...@@ -85,6 +86,10 @@ const yargsOptions = { ...@@ -85,6 +86,10 @@ const yargsOptions = {
default: null, default: null,
type: 'string' type: 'string'
}, },
pythonInterpreter: {
demandOption: true,
type: 'string'
},
mode: { mode: {
default: '', default: '',
......
...@@ -49,6 +49,7 @@ export function resetGlobals(): void { ...@@ -49,6 +49,7 @@ export function resetGlobals(): void {
foreground: false, foreground: false,
urlPrefix: '', urlPrefix: '',
tunerCommandChannel: null, tunerCommandChannel: null,
pythonInterpreter: 'python',
mode: 'unittest' mode: 'unittest'
}; };
const paths = createPaths(args); const paths = createPaths(args);
......
...@@ -2,14 +2,13 @@ ...@@ -2,14 +2,13 @@
// Licensed under the MIT license. // Licensed under the MIT license.
import { spawn } from 'child_process'; import { spawn } from 'child_process';
import globals from './globals';
import { Logger, getLogger } from './log'; import { Logger, getLogger } from './log';
const logger: Logger = getLogger('pythonScript'); const logger: Logger = getLogger('pythonScript');
const python: string = process.platform === 'win32' ? 'python.exe' : 'python3';
export async function runPythonScript(script: string, logTag?: string): Promise<string> { export async function runPythonScript(script: string, logTag?: string): Promise<string> {
const proc = spawn(python, [ '-c', script ]); const proc = spawn(globals.args.pythonInterpreter, [ '-c', script ]);
let stdout: string = ''; let stdout: string = '';
let stderr: string = ''; let stderr: string = '';
......
...@@ -98,14 +98,6 @@ function randomSelect<T>(a: T[]): T { ...@@ -98,14 +98,6 @@ function randomSelect<T>(a: T[]): T {
return a[Math.floor(Math.random() * a.length)]; return a[Math.floor(Math.random() * a.length)];
} }
function getCmdPy(): string {
let cmd = 'python3';
if (process.platform === 'win32') {
cmd = 'python';
}
return cmd;
}
/** /**
* Generate command line to start automl algorithm(s), * Generate command line to start automl algorithm(s),
* either start advisor or start a process which runs tuner and assessor * either start advisor or start a process which runs tuner and assessor
...@@ -116,7 +108,7 @@ function getCmdPy(): string { ...@@ -116,7 +108,7 @@ function getCmdPy(): string {
function getMsgDispatcherCommand(expParams: ExperimentConfig): string { function getMsgDispatcherCommand(expParams: ExperimentConfig): string {
const clonedParams = Object.assign({}, expParams); const clonedParams = Object.assign({}, expParams);
delete clonedParams.searchSpace; delete clonedParams.searchSpace;
return `${getCmdPy()} -m nni --exp_params ${Buffer.from(JSON.stringify(clonedParams)).toString('base64')}`; return `${globals.args.pythonInterpreter} -m nni --exp_params ${Buffer.from(JSON.stringify(clonedParams)).toString('base64')}`;
} }
/** /**
...@@ -388,5 +380,5 @@ export function importModule(modulePath: string): any { ...@@ -388,5 +380,5 @@ export function importModule(modulePath: string): any {
export { export {
countFilesRecursively, generateParamFileName, getMsgDispatcherCommand, getCheckpointDir, countFilesRecursively, generateParamFileName, getMsgDispatcherCommand, getCheckpointDir,
getLogDir, getExperimentRootDir, getJobCancelStatus, getDefaultDatabaseDir, getIPV4Address, unixPathJoin, getFreePort, isPortOpen, getLogDir, getExperimentRootDir, getJobCancelStatus, getDefaultDatabaseDir, getIPV4Address, unixPathJoin, getFreePort, isPortOpen,
mkDirP, mkDirPSync, delay, prepareUnitTest, cleanupUnitTest, uniqueString, randomInt, randomSelect, getLogLevel, getVersion, getCmdPy, getTunerProc, isAlive, killPid, getNewLine mkDirP, mkDirPSync, delay, prepareUnitTest, cleanupUnitTest, uniqueString, randomInt, randomSelect, getLogLevel, getVersion, getTunerProc, isAlive, killPid, getNewLine
}; };
...@@ -11,6 +11,7 @@ import { getLogger, Logger } from '../common/log'; ...@@ -11,6 +11,7 @@ import { getLogger, Logger } from '../common/log';
import { getTunerProc, isAlive, uniqueString, mkDirPSync, getFreePort } from '../common/utils'; import { getTunerProc, isAlive, uniqueString, mkDirPSync, getFreePort } from '../common/utils';
import { Manager } from '../common/manager'; import { Manager } from '../common/manager';
import { TensorboardParams, TensorboardTaskStatus, TensorboardTaskInfo, TensorboardManager } from '../common/tensorboardManager'; import { TensorboardParams, TensorboardTaskStatus, TensorboardTaskInfo, TensorboardManager } from '../common/tensorboardManager';
import globals from 'common/globals';
class TensorboardTaskDetail implements TensorboardTaskInfo { class TensorboardTaskDetail implements TensorboardTaskInfo {
public id: string; public id: string;
...@@ -112,7 +113,7 @@ class NNITensorboardManager implements TensorboardManager { ...@@ -112,7 +113,7 @@ class NNITensorboardManager implements TensorboardManager {
} }
private setTensorboardVersion(): void { private setTensorboardVersion(): void {
let command = `python3 -c 'import tensorboard ; print(tensorboard.__version__)' 2>&1`; let command = `${globals.args.pythonInterpreter} -c 'import tensorboard ; print(tensorboard.__version__)' 2>&1`;
if (process.platform === 'win32') { if (process.platform === 'win32') {
command = `python -c "import tensorboard ; print(tensorboard.__version__)" 2>&1`; command = `python -c "import tensorboard ; print(tensorboard.__version__)" 2>&1`;
} }
......
...@@ -5,7 +5,15 @@ import assert from 'assert/strict'; ...@@ -5,7 +5,15 @@ import assert from 'assert/strict';
import { parseArgs } from 'common/globals/arguments'; import { parseArgs } from 'common/globals/arguments';
const command = '--port 80 --experiment-id ID --action resume --experiments-directory DIR --log-level error'; const command = [
'--port 80',
'--experiment-id ID',
'--action resume',
'--experiments-directory DIR',
'--log-level error',
'--python-interpreter python',
].join(' ');
const expected = { const expected = {
port: 80, port: 80,
experimentId: 'ID', experimentId: 'ID',
...@@ -15,6 +23,7 @@ const expected = { ...@@ -15,6 +23,7 @@ const expected = {
foreground: false, foreground: false,
urlPrefix: '', urlPrefix: '',
tunerCommandChannel: null, tunerCommandChannel: null,
pythonInterpreter: 'python',
mode: '', mode: '',
}; };
......
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
import * as assert from 'assert'; import * as assert from 'assert';
import { ChildProcess, spawn, StdioOptions } from 'child_process'; import { ChildProcess, spawn, StdioOptions } from 'child_process';
import { Deferred } from 'ts-deferred'; import { Deferred } from 'ts-deferred';
import { cleanupUnitTest, prepareUnitTest, getTunerProc, getCmdPy } from '../../common/utils'; import { cleanupUnitTest, prepareUnitTest, getTunerProc } from '../../common/utils';
import * as CommandType from '../../core/commands'; import * as CommandType from '../../core/commands';
import { createDispatcherInterface, IpcInterface } from '../../core/ipcInterface'; import { createDispatcherInterface, IpcInterface } from '../../core/ipcInterface';
import { NNIError } from '../../common/errors'; import { NNIError } from '../../common/errors';
...@@ -22,7 +22,7 @@ async function runProcess(): Promise<Error | null> { ...@@ -22,7 +22,7 @@ async function runProcess(): Promise<Error | null> {
// create fake assessor process // create fake assessor process
const stdio: StdioOptions = ['ignore', 'pipe', process.stderr, 'pipe', 'pipe']; const stdio: StdioOptions = ['ignore', 'pipe', process.stderr, 'pipe', 'pipe'];
const command: string = getCmdPy() + ' assessor.py'; const command: string = 'python assessor.py';
const proc: ChildProcess = getTunerProc(command, stdio, 'core/test', process.env); const proc: ChildProcess = getTunerProc(command, stdio, 'core/test', process.env);
// record its sent/received commands on exit // record its sent/received commands on exit
proc.on('error', (error: Error): void => { deferred.resolve(error); }); proc.on('error', (error: Error): void => { deferred.resolve(error); });
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment