Unverified Commit 143c6615 authored by Chi Song's avatar Chi Song Committed by GitHub
Browse files

Reusable environment support GPU scheduler, add test cases and refactoring. (#2627)

parent 8a20c348
...@@ -25,7 +25,7 @@ export class TrialDetail implements TrialJobDetail { ...@@ -25,7 +25,7 @@ export class TrialDetail implements TrialJobDetail {
// it's used to aggregate node status for multiple node trial // it's used to aggregate node status for multiple node trial
public nodes: Map<string, NodeInfomation>; public nodes: Map<string, NodeInfomation>;
// assigned GPUs for multi-trial scheduled. // assigned GPUs for multi-trial scheduled.
public assignedGpus: GPUInfo[] = []; public assignedGpus: GPUInfo[] | undefined;
public readonly TRIAL_METADATA_DIR = ".nni"; public readonly TRIAL_METADATA_DIR = ".nni";
......
...@@ -262,6 +262,10 @@ ...@@ -262,6 +262,10 @@
version "2.3.1" version "2.3.1"
resolved "https://registry.yarnpkg.com/@types/js-base64/-/js-base64-2.3.1.tgz#c39f14f129408a3d96a1105a650d8b2b6eeb4168" resolved "https://registry.yarnpkg.com/@types/js-base64/-/js-base64-2.3.1.tgz#c39f14f129408a3d96a1105a650d8b2b6eeb4168"
"@types/js-yaml@^3.12.5":
version "3.12.5"
resolved "https://registry.yarnpkg.com/@types/js-yaml/-/js-yaml-3.12.5.tgz#136d5e6a57a931e1cce6f9d8126aa98a9c92a6bb"
"@types/json-schema@^7.0.3": "@types/json-schema@^7.0.3":
version "7.0.3" version "7.0.3"
resolved "https://registry.yarnpkg.com/@types/json-schema/-/json-schema-7.0.3.tgz#bdfd69d61e464dcc81b25159c270d75a73c1a636" resolved "https://registry.yarnpkg.com/@types/json-schema/-/json-schema-7.0.3.tgz#bdfd69d61e464dcc81b25159c270d75a73c1a636"
...@@ -277,7 +281,6 @@ ...@@ -277,7 +281,6 @@
"@types/minipass@*": "@types/minipass@*":
version "2.2.0" version "2.2.0"
resolved "https://registry.yarnpkg.com/@types/minipass/-/minipass-2.2.0.tgz#51ad404e8eb1fa961f75ec61205796807b6f9651" resolved "https://registry.yarnpkg.com/@types/minipass/-/minipass-2.2.0.tgz#51ad404e8eb1fa961f75ec61205796807b6f9651"
integrity sha512-wuzZksN4w4kyfoOv/dlpov4NOunwutLA/q7uc00xU02ZyUY+aoM5PWIXEKBMnm0NHd4a+N71BMjq+x7+2Af1fg==
dependencies: dependencies:
"@types/node" "*" "@types/node" "*"
...@@ -430,7 +433,6 @@ ...@@ -430,7 +433,6 @@
"@types/tar@^4.0.3": "@types/tar@^4.0.3":
version "4.0.3" version "4.0.3"
resolved "https://registry.yarnpkg.com/@types/tar/-/tar-4.0.3.tgz#e2cce0b8ff4f285293243f5971bd7199176ac489" resolved "https://registry.yarnpkg.com/@types/tar/-/tar-4.0.3.tgz#e2cce0b8ff4f285293243f5971bd7199176ac489"
integrity sha512-Z7AVMMlkI8NTWF0qGhC4QIX0zkV/+y0J8x7b/RsHrN0310+YNjoJd8UrApCiGBCWtKjxS9QhNqLi2UJNToh5hA==
dependencies: dependencies:
"@types/minipass" "*" "@types/minipass" "*"
"@types/node" "*" "@types/node" "*"
...@@ -1017,7 +1019,6 @@ chownr@^1.1.2, chownr@^1.1.3: ...@@ -1017,7 +1019,6 @@ chownr@^1.1.2, chownr@^1.1.3:
chownr@^2.0.0: chownr@^2.0.0:
version "2.0.0" version "2.0.0"
resolved "https://registry.yarnpkg.com/chownr/-/chownr-2.0.0.tgz#15bfbe53d2eab4cf70f18a8cd68ebe5b3cb1dece" resolved "https://registry.yarnpkg.com/chownr/-/chownr-2.0.0.tgz#15bfbe53d2eab4cf70f18a8cd68ebe5b3cb1dece"
integrity sha512-bIomtDF5KGpdogkLd9VspvFzk9KfpyyGlS8YFVZl7TGPBHL5snIOnxeshwVgPteQ9b4Eydl+pVbIyE1DcvCWgQ==
ci-info@^1.5.0: ci-info@^1.5.0:
version "1.6.0" version "1.6.0"
...@@ -1912,7 +1913,6 @@ fs-minipass@^1.2.5: ...@@ -1912,7 +1913,6 @@ fs-minipass@^1.2.5:
fs-minipass@^2.0.0: fs-minipass@^2.0.0:
version "2.1.0" version "2.1.0"
resolved "https://registry.yarnpkg.com/fs-minipass/-/fs-minipass-2.1.0.tgz#7f5036fdbf12c63c169190cbe4199c852271f9fb" resolved "https://registry.yarnpkg.com/fs-minipass/-/fs-minipass-2.1.0.tgz#7f5036fdbf12c63c169190cbe4199c852271f9fb"
integrity sha512-V/JgOLFCS+R6Vcq0slCuaeWEdNC3ouDlJMNIsacH2VtALiu9mV4LPrHc5cDl8k5aw6J8jwgWWpiTo5RYhmIzvg==
dependencies: dependencies:
minipass "^3.0.0" minipass "^3.0.0"
...@@ -2331,7 +2331,6 @@ ignore@^4.0.6: ...@@ -2331,7 +2331,6 @@ ignore@^4.0.6:
ignore@^5.1.4: ignore@^5.1.4:
version "5.1.4" version "5.1.4"
resolved "https://registry.yarnpkg.com/ignore/-/ignore-5.1.4.tgz#84b7b3dbe64552b6ef0eca99f6743dbec6d97adf" resolved "https://registry.yarnpkg.com/ignore/-/ignore-5.1.4.tgz#84b7b3dbe64552b6ef0eca99f6743dbec6d97adf"
integrity sha512-MzbUSahkTW1u7JpKKjY7LCARd1fU5W2rLdxlM4kdkayuCwZImjkpluF9CM1aLewYJguPDqewLam18Y6AU69A8A==
import-fresh@^3.0.0: import-fresh@^3.0.0:
version "3.2.1" version "3.2.1"
...@@ -2650,7 +2649,6 @@ istanbul-lib-source-maps@^4.0.0: ...@@ -2650,7 +2649,6 @@ istanbul-lib-source-maps@^4.0.0:
istanbul-reports@^3.0.2: istanbul-reports@^3.0.2:
version "3.0.2" version "3.0.2"
resolved "https://registry.yarnpkg.com/istanbul-reports/-/istanbul-reports-3.0.2.tgz#d593210e5000683750cb09fc0644e4b6e27fd53b" resolved "https://registry.yarnpkg.com/istanbul-reports/-/istanbul-reports-3.0.2.tgz#d593210e5000683750cb09fc0644e4b6e27fd53b"
integrity sha512-9tZvz7AiR3PEDNGiV9vIouQ/EAcqMXFmkcA1CDFTwOB98OZVDL0PH9glHotf5Ugp6GCOTypfzGWI/OqjWNCRUw==
dependencies: dependencies:
html-escaper "^2.0.0" html-escaper "^2.0.0"
istanbul-lib-report "^3.0.0" istanbul-lib-report "^3.0.0"
...@@ -3193,7 +3191,6 @@ minipass@^2.3.5, minipass@^2.8.6, minipass@^2.9.0: ...@@ -3193,7 +3191,6 @@ minipass@^2.3.5, minipass@^2.8.6, minipass@^2.9.0:
minipass@^3.0.0: minipass@^3.0.0:
version "3.1.3" version "3.1.3"
resolved "https://registry.yarnpkg.com/minipass/-/minipass-3.1.3.tgz#7d42ff1f39635482e15f9cdb53184deebd5815fd" resolved "https://registry.yarnpkg.com/minipass/-/minipass-3.1.3.tgz#7d42ff1f39635482e15f9cdb53184deebd5815fd"
integrity sha512-Mgd2GdMVzY+x3IJ+oHnVM+KG3lA5c8tnabyJKmHSaG2kAGpudxuOf8ToDkhumF7UzME7DecbQE9uOZhNm7PuJg==
dependencies: dependencies:
yallist "^4.0.0" yallist "^4.0.0"
...@@ -3212,7 +3209,6 @@ minizlib@^1.2.1: ...@@ -3212,7 +3209,6 @@ minizlib@^1.2.1:
minizlib@^2.1.0: minizlib@^2.1.0:
version "2.1.0" version "2.1.0"
resolved "https://registry.yarnpkg.com/minizlib/-/minizlib-2.1.0.tgz#fd52c645301ef09a63a2c209697c294c6ce02cf3" resolved "https://registry.yarnpkg.com/minizlib/-/minizlib-2.1.0.tgz#fd52c645301ef09a63a2c209697c294c6ce02cf3"
integrity sha512-EzTZN/fjSvifSX0SlqUERCN39o6T40AMarPbv0MrarSFtIITCBh7bi+dU8nxGFHuqs9jdIAeoYoKuQAAASsPPA==
dependencies: dependencies:
minipass "^3.0.0" minipass "^3.0.0"
yallist "^4.0.0" yallist "^4.0.0"
...@@ -3249,7 +3245,6 @@ mkdirp@^0.5.1: ...@@ -3249,7 +3245,6 @@ mkdirp@^0.5.1:
mkdirp@^1.0.3: mkdirp@^1.0.3:
version "1.0.4" version "1.0.4"
resolved "https://registry.yarnpkg.com/mkdirp/-/mkdirp-1.0.4.tgz#3eb5ed62622756d79a5f0e2a221dfebad75c2f7e" resolved "https://registry.yarnpkg.com/mkdirp/-/mkdirp-1.0.4.tgz#3eb5ed62622756d79a5f0e2a221dfebad75c2f7e"
integrity sha512-vVqVZQyf3WLx2Shd0qJ9xuvqgAyKPLAiqITEtqW0oIUjzo3PePDd6fW9iFz30ef7Ysp/oiWqbhszeGWW2T6Gzw==
mocha@^7.1.1: mocha@^7.1.1:
version "7.1.1" version "7.1.1"
...@@ -3707,7 +3702,6 @@ number-is-nan@^1.0.0: ...@@ -3707,7 +3702,6 @@ number-is-nan@^1.0.0:
nyc@^15.0.0: nyc@^15.0.0:
version "15.0.1" version "15.0.1"
resolved "https://registry.yarnpkg.com/nyc/-/nyc-15.0.1.tgz#bd4d5c2b17f2ec04370365a5ca1fc0ed26f9f93d" resolved "https://registry.yarnpkg.com/nyc/-/nyc-15.0.1.tgz#bd4d5c2b17f2ec04370365a5ca1fc0ed26f9f93d"
integrity sha512-n0MBXYBYRqa67IVt62qW1r/d9UH/Qtr7SF1w/nQLJ9KxvWF6b2xCHImRAixHN9tnMMYHC2P14uo6KddNGwMgGg==
dependencies: dependencies:
"@istanbuljs/load-nyc-config" "^1.0.0" "@istanbuljs/load-nyc-config" "^1.0.0"
"@istanbuljs/schema" "^0.1.2" "@istanbuljs/schema" "^0.1.2"
...@@ -5065,7 +5059,6 @@ tar@^4.4.10, tar@^4.4.12, tar@^4.4.13: ...@@ -5065,7 +5059,6 @@ tar@^4.4.10, tar@^4.4.12, tar@^4.4.13:
tar@^6.0.2: tar@^6.0.2:
version "6.0.2" version "6.0.2"
resolved "https://registry.yarnpkg.com/tar/-/tar-6.0.2.tgz#5df17813468a6264ff14f766886c622b84ae2f39" resolved "https://registry.yarnpkg.com/tar/-/tar-6.0.2.tgz#5df17813468a6264ff14f766886c622b84ae2f39"
integrity sha512-Glo3jkRtPcvpDlAs/0+hozav78yoXKFr+c4wgw62NNMO3oo4AaJdCo21Uu7lcwr55h39W2XD1LMERc64wtbItg==
dependencies: dependencies:
chownr "^2.0.0" chownr "^2.0.0"
fs-minipass "^2.0.0" fs-minipass "^2.0.0"
...@@ -5541,7 +5534,6 @@ yallist@^3.0.2, yallist@^3.0.3: ...@@ -5541,7 +5534,6 @@ yallist@^3.0.2, yallist@^3.0.3:
yallist@^4.0.0: yallist@^4.0.0:
version "4.0.0" version "4.0.0"
resolved "https://registry.yarnpkg.com/yallist/-/yallist-4.0.0.tgz#9bb92790d9c0effec63be73519e11a35019a3a72" resolved "https://registry.yarnpkg.com/yallist/-/yallist-4.0.0.tgz#9bb92790d9c0effec63be73519e11a35019a3a72"
integrity sha512-3wdGidZyq5PB084XLES5TpOSRA3wjXAlIWMhum2kRcv/41Sn2emQ0dycQW4uZXLejwKvg6EsvbdlVL+FYEct7A==
yargs-parser@13.1.2, yargs-parser@^13.1.2: yargs-parser@13.1.2, yargs-parser@^13.1.2:
version "13.1.2" version "13.1.2"
......
...@@ -35,6 +35,8 @@ def update_training_service_config(args): ...@@ -35,6 +35,8 @@ def update_training_service_config(args):
config[args.ts]['paiConfig']['host'] = args.pai_host config[args.ts]['paiConfig']['host'] = args.pai_host
if args.pai_token is not None: if args.pai_token is not None:
config[args.ts]['paiConfig']['token'] = args.pai_token config[args.ts]['paiConfig']['token'] = args.pai_token
if args.pai_reuse is not None:
config[args.ts]['paiConfig']['reuse'] = args.pai_reuse.lower() == 'true'
if args.nni_docker_image is not None: if args.nni_docker_image is not None:
config[args.ts]['trial']['image'] = args.nni_docker_image config[args.ts]['trial']['image'] = args.nni_docker_image
if args.nni_manager_nfs_mount_path is not None: if args.nni_manager_nfs_mount_path is not None:
...@@ -101,6 +103,7 @@ if __name__ == '__main__': ...@@ -101,6 +103,7 @@ if __name__ == '__main__':
parser.add_argument("--output_dir", type=str) parser.add_argument("--output_dir", type=str)
parser.add_argument("--vc", type=str) parser.add_argument("--vc", type=str)
parser.add_argument("--pai_token", type=str) parser.add_argument("--pai_token", type=str)
parser.add_argument("--pai_reuse", type=str)
parser.add_argument("--pai_storage_config_name", type=str) parser.add_argument("--pai_storage_config_name", type=str)
parser.add_argument("--nni_manager_nfs_mount_path", type=str) parser.add_argument("--nni_manager_nfs_mount_path", type=str)
parser.add_argument("--container_nfs_mount_path", type=str) parser.add_argument("--container_nfs_mount_path", type=str)
......
...@@ -57,7 +57,7 @@ jobs: ...@@ -57,7 +57,7 @@ jobs:
echo "TEST_IMG:$TEST_IMG" echo "TEST_IMG:$TEST_IMG"
cd test cd test
python3 nni_test/nnitest/generate_ts_config.py --ts pai --pai_host $(pai_host) --pai_user $(pai_user) --nni_docker_image $TEST_IMG --pai_storage_config_name $(pai_storage_config_name)\ python3 nni_test/nnitest/generate_ts_config.py --ts pai --pai_reuse $(pai_reuse) --pai_host $(pai_host) --pai_user $(pai_user) --nni_docker_image $TEST_IMG --pai_storage_config_name $(pai_storage_config_name)\
--pai_token $(pai_token) --nni_manager_nfs_mount_path $(nni_manager_nfs_mount_path) --container_nfs_mount_path $(container_nfs_mount_path) --nni_manager_ip $(nni_manager_ip) --vc $(virtual_cluster) --pai_token $(pai_token) --nni_manager_nfs_mount_path $(nni_manager_nfs_mount_path) --container_nfs_mount_path $(container_nfs_mount_path) --nni_manager_ip $(nni_manager_ip) --vc $(virtual_cluster)
PATH=$HOME/.local/bin:$PATH python3 nni_test/nnitest/run_tests.py --config config/integration_tests.yml --ts pai PATH=$HOME/.local/bin:$PATH python3 nni_test/nnitest/run_tests.py --config config/integration_tests.yml --ts pai
displayName: 'integration test' displayName: 'integration test'
...@@ -14,10 +14,12 @@ def setType(key, valueType): ...@@ -14,10 +14,12 @@ def setType(key, valueType):
'''check key type''' '''check key type'''
return And(valueType, error=SCHEMA_TYPE_ERROR % (key, valueType.__name__)) return And(valueType, error=SCHEMA_TYPE_ERROR % (key, valueType.__name__))
def setChoice(key, *args): def setChoice(key, *args):
'''check choice''' '''check choice'''
return And(lambda n: n in args, error=SCHEMA_RANGE_ERROR % (key, str(args))) return And(lambda n: n in args, error=SCHEMA_RANGE_ERROR % (key, str(args)))
def setNumberRange(key, keyType, start, end): def setNumberRange(key, keyType, start, end):
'''check number range''' '''check number range'''
return And( return And(
...@@ -25,16 +27,19 @@ def setNumberRange(key, keyType, start, end): ...@@ -25,16 +27,19 @@ def setNumberRange(key, keyType, start, end):
And(lambda n: start <= n <= end, error=SCHEMA_RANGE_ERROR % (key, '(%s,%s)' % (start, end))), And(lambda n: start <= n <= end, error=SCHEMA_RANGE_ERROR % (key, '(%s,%s)' % (start, end))),
) )
def setPathCheck(key): def setPathCheck(key):
'''check if path exist''' '''check if path exist'''
return And(os.path.exists, error=SCHEMA_PATH_ERROR % key) return And(os.path.exists, error=SCHEMA_PATH_ERROR % key)
class AlgoSchema: class AlgoSchema:
""" """
This class is the schema of 'tuner', 'assessor' and 'advisor' sections of experiment configuraion file. This class is the schema of 'tuner', 'assessor' and 'advisor' sections of experiment configuraion file.
For example: For example:
AlgoSchema('tuner') creates the schema of tuner section. AlgoSchema('tuner') creates the schema of tuner section.
""" """
def __init__(self, algo_type): def __init__(self, algo_type):
""" """
Parameters: Parameters:
...@@ -108,6 +113,7 @@ class AlgoSchema: ...@@ -108,6 +113,7 @@ class AlgoSchema:
Schema(self.algo_schema).validate(data) Schema(self.algo_schema).validate(data)
self.validate_extras(data, self.algo_type) self.validate_extras(data, self.algo_type)
common_schema = { common_schema = {
'authorName': setType('authorName', str), 'authorName': setType('authorName', str),
'experimentName': setType('experimentName', str), 'experimentName': setType('experimentName', str),
...@@ -138,7 +144,7 @@ common_schema = { ...@@ -138,7 +144,7 @@ common_schema = {
} }
common_trial_schema = { common_trial_schema = {
'trial':{ 'trial': {
'command': setType('command', str), 'command': setType('command', str),
'codeDir': setPathCheck('codeDir'), 'codeDir': setPathCheck('codeDir'),
Optional('gpuNum'): setNumberRange('gpuNum', int, 0, 99999), Optional('gpuNum'): setNumberRange('gpuNum', int, 0, 99999),
...@@ -147,7 +153,7 @@ common_trial_schema = { ...@@ -147,7 +153,7 @@ common_trial_schema = {
} }
pai_yarn_trial_schema = { pai_yarn_trial_schema = {
'trial':{ 'trial': {
'command': setType('command', str), 'command': setType('command', str),
'codeDir': setPathCheck('codeDir'), 'codeDir': setPathCheck('codeDir'),
'gpuNum': setNumberRange('gpuNum', int, 0, 99999), 'gpuNum': setNumberRange('gpuNum', int, 0, 99999),
...@@ -156,9 +162,9 @@ pai_yarn_trial_schema = { ...@@ -156,9 +162,9 @@ pai_yarn_trial_schema = {
'image': setType('image', str), 'image': setType('image', str),
Optional('authFile'): And(os.path.exists, error=SCHEMA_PATH_ERROR % 'authFile'), Optional('authFile'): And(os.path.exists, error=SCHEMA_PATH_ERROR % 'authFile'),
Optional('shmMB'): setType('shmMB', int), Optional('shmMB'): setType('shmMB', int),
Optional('dataDir'): And(Regex(r'hdfs://(([0-9]{1,3}.){3}[0-9]{1,3})(:[0-9]{2,5})?(/.*)?'),\ Optional('dataDir'): And(Regex(r'hdfs://(([0-9]{1,3}.){3}[0-9]{1,3})(:[0-9]{2,5})?(/.*)?'),
error='ERROR: dataDir format error, dataDir format is hdfs://xxx.xxx.xxx.xxx:xxx'), error='ERROR: dataDir format error, dataDir format is hdfs://xxx.xxx.xxx.xxx:xxx'),
Optional('outputDir'): And(Regex(r'hdfs://(([0-9]{1,3}.){3}[0-9]{1,3})(:[0-9]{2,5})?(/.*)?'),\ Optional('outputDir'): And(Regex(r'hdfs://(([0-9]{1,3}.){3}[0-9]{1,3})(:[0-9]{2,5})?(/.*)?'),
error='ERROR: outputDir format error, outputDir format is hdfs://xxx.xxx.xxx.xxx:xxx'), error='ERROR: outputDir format error, outputDir format is hdfs://xxx.xxx.xxx.xxx:xxx'),
Optional('virtualCluster'): setType('virtualCluster', str), Optional('virtualCluster'): setType('virtualCluster', str),
Optional('nasMode'): setChoice('nasMode', 'classic_mode', 'enas_mode', 'oneshot_mode', 'darts_mode'), Optional('nasMode'): setChoice('nasMode', 'classic_mode', 'enas_mode', 'oneshot_mode', 'darts_mode'),
...@@ -184,7 +190,7 @@ pai_yarn_config_schema = { ...@@ -184,7 +190,7 @@ pai_yarn_config_schema = {
pai_trial_schema = { pai_trial_schema = {
'trial':{ 'trial': {
'codeDir': setPathCheck('codeDir'), 'codeDir': setPathCheck('codeDir'),
'nniManagerNFSMountPath': setPathCheck('nniManagerNFSMountPath'), 'nniManagerNFSMountPath': setPathCheck('nniManagerNFSMountPath'),
'containerNFSMountPath': setType('containerNFSMountPath', str), 'containerNFSMountPath': setType('containerNFSMountPath', str),
...@@ -200,21 +206,21 @@ pai_trial_schema = { ...@@ -200,21 +206,21 @@ pai_trial_schema = {
} }
pai_config_schema = { pai_config_schema = {
'paiConfig': Or({ 'paiConfig': {
'userName': setType('userName', str),
'passWord': setType('passWord', str),
'host': setType('host', str),
Optional('reuse'): setType('reuse', bool)
}, {
'userName': setType('userName', str), 'userName': setType('userName', str),
'token': setType('token', str), Or('passWord', 'token', only_one=True): str,
'host': setType('host', str), 'host': setType('host', str),
Optional('reuse'): setType('reuse', bool) Optional('reuse'): setType('reuse', bool),
}) Optional('gpuNum'): setNumberRange('gpuNum', int, 0, 99999),
Optional('cpuNum'): setNumberRange('cpuNum', int, 0, 99999),
Optional('memoryMB'): setType('memoryMB', int),
Optional('maxTrialNumPerGpu'): setType('maxTrialNumPerGpu', int),
Optional('useActiveGpu'): setType('useActiveGpu', bool),
}
} }
dlts_trial_schema = { dlts_trial_schema = {
'trial':{ 'trial': {
'command': setType('command', str), 'command': setType('command', str),
'codeDir': setPathCheck('codeDir'), 'codeDir': setPathCheck('codeDir'),
'gpuNum': setNumberRange('gpuNum', int, 0, 99999), 'gpuNum': setNumberRange('gpuNum', int, 0, 99999),
...@@ -235,7 +241,7 @@ dlts_config_schema = { ...@@ -235,7 +241,7 @@ dlts_config_schema = {
} }
aml_trial_schema = { aml_trial_schema = {
'trial':{ 'trial': {
'codeDir': setPathCheck('codeDir'), 'codeDir': setPathCheck('codeDir'),
'command': setType('command', str), 'command': setType('command', str),
'image': setType('image', str), 'image': setType('image', str),
...@@ -252,7 +258,7 @@ aml_config_schema = { ...@@ -252,7 +258,7 @@ aml_config_schema = {
} }
kubeflow_trial_schema = { kubeflow_trial_schema = {
'trial':{ 'trial': {
'codeDir': setPathCheck('codeDir'), 'codeDir': setPathCheck('codeDir'),
Optional('nasMode'): setChoice('nasMode', 'classic_mode', 'enas_mode', 'oneshot_mode', 'darts_mode'), Optional('nasMode'): setChoice('nasMode', 'classic_mode', 'enas_mode', 'oneshot_mode', 'darts_mode'),
Optional('ps'): { Optional('ps'): {
...@@ -273,7 +279,7 @@ kubeflow_trial_schema = { ...@@ -273,7 +279,7 @@ kubeflow_trial_schema = {
'image': setType('image', str), 'image': setType('image', str),
Optional('privateRegistryAuthPath'): And(os.path.exists, error=SCHEMA_PATH_ERROR % 'privateRegistryAuthPath') Optional('privateRegistryAuthPath'): And(os.path.exists, error=SCHEMA_PATH_ERROR % 'privateRegistryAuthPath')
}, },
Optional('worker'):{ Optional('worker'): {
'replicas': setType('replicas', int), 'replicas': setType('replicas', int),
'command': setType('command', str), 'command': setType('command', str),
'gpuNum': setNumberRange('gpuNum', int, 0, 99999), 'gpuNum': setNumberRange('gpuNum', int, 0, 99999),
...@@ -286,7 +292,7 @@ kubeflow_trial_schema = { ...@@ -286,7 +292,7 @@ kubeflow_trial_schema = {
} }
kubeflow_config_schema = { kubeflow_config_schema = {
'kubeflowConfig':Or({ 'kubeflowConfig': Or({
'operator': setChoice('operator', 'tf-operator', 'pytorch-operator'), 'operator': setChoice('operator', 'tf-operator', 'pytorch-operator'),
'apiVersion': setType('apiVersion', str), 'apiVersion': setType('apiVersion', str),
Optional('storage'): setChoice('storage', 'nfs', 'azureStorage'), Optional('storage'): setChoice('storage', 'nfs', 'azureStorage'),
...@@ -299,15 +305,15 @@ kubeflow_config_schema = { ...@@ -299,15 +305,15 @@ kubeflow_config_schema = {
'apiVersion': setType('apiVersion', str), 'apiVersion': setType('apiVersion', str),
Optional('storage'): setChoice('storage', 'nfs', 'azureStorage'), Optional('storage'): setChoice('storage', 'nfs', 'azureStorage'),
'keyVault': { 'keyVault': {
'vaultName': And(Regex('([0-9]|[a-z]|[A-Z]|-){1,127}'),\ 'vaultName': And(Regex('([0-9]|[a-z]|[A-Z]|-){1,127}'),
error='ERROR: vaultName format error, vaultName support using (0-9|a-z|A-Z|-)'), error='ERROR: vaultName format error, vaultName support using (0-9|a-z|A-Z|-)'),
'name': And(Regex('([0-9]|[a-z]|[A-Z]|-){1,127}'),\ 'name': And(Regex('([0-9]|[a-z]|[A-Z]|-){1,127}'),
error='ERROR: name format error, name support using (0-9|a-z|A-Z|-)') error='ERROR: name format error, name support using (0-9|a-z|A-Z|-)')
}, },
'azureStorage': { 'azureStorage': {
'accountName': And(Regex('([0-9]|[a-z]|[A-Z]|-){3,31}'),\ 'accountName': And(Regex('([0-9]|[a-z]|[A-Z]|-){3,31}'),
error='ERROR: accountName format error, accountName support using (0-9|a-z|A-Z|-)'), error='ERROR: accountName format error, accountName support using (0-9|a-z|A-Z|-)'),
'azureShare': And(Regex('([0-9]|[a-z]|[A-Z]|-){3,63}'),\ 'azureShare': And(Regex('([0-9]|[a-z]|[A-Z]|-){3,63}'),
error='ERROR: azureShare format error, azureShare support using (0-9|a-z|A-Z|-)') error='ERROR: azureShare format error, azureShare support using (0-9|a-z|A-Z|-)')
}, },
Optional('uploadRetryCount'): setNumberRange('uploadRetryCount', int, 1, 99999) Optional('uploadRetryCount'): setNumberRange('uploadRetryCount', int, 1, 99999)
...@@ -315,7 +321,7 @@ kubeflow_config_schema = { ...@@ -315,7 +321,7 @@ kubeflow_config_schema = {
} }
frameworkcontroller_trial_schema = { frameworkcontroller_trial_schema = {
'trial':{ 'trial': {
'codeDir': setPathCheck('codeDir'), 'codeDir': setPathCheck('codeDir'),
'taskRoles': [{ 'taskRoles': [{
'name': setType('name', str), 'name': setType('name', str),
...@@ -335,7 +341,7 @@ frameworkcontroller_trial_schema = { ...@@ -335,7 +341,7 @@ frameworkcontroller_trial_schema = {
} }
frameworkcontroller_config_schema = { frameworkcontroller_config_schema = {
'frameworkcontrollerConfig':Or({ 'frameworkcontrollerConfig': Or({
Optional('storage'): setChoice('storage', 'nfs', 'azureStorage'), Optional('storage'): setChoice('storage', 'nfs', 'azureStorage'),
Optional('serviceAccountName'): setType('serviceAccountName', str), Optional('serviceAccountName'): setType('serviceAccountName', str),
'nfs': { 'nfs': {
...@@ -346,15 +352,15 @@ frameworkcontroller_config_schema = { ...@@ -346,15 +352,15 @@ frameworkcontroller_config_schema = {
Optional('storage'): setChoice('storage', 'nfs', 'azureStorage'), Optional('storage'): setChoice('storage', 'nfs', 'azureStorage'),
Optional('serviceAccountName'): setType('serviceAccountName', str), Optional('serviceAccountName'): setType('serviceAccountName', str),
'keyVault': { 'keyVault': {
'vaultName': And(Regex('([0-9]|[a-z]|[A-Z]|-){1,127}'),\ 'vaultName': And(Regex('([0-9]|[a-z]|[A-Z]|-){1,127}'),
error='ERROR: vaultName format error, vaultName support using (0-9|a-z|A-Z|-)'), error='ERROR: vaultName format error, vaultName support using (0-9|a-z|A-Z|-)'),
'name': And(Regex('([0-9]|[a-z]|[A-Z]|-){1,127}'),\ 'name': And(Regex('([0-9]|[a-z]|[A-Z]|-){1,127}'),
error='ERROR: name format error, name support using (0-9|a-z|A-Z|-)') error='ERROR: name format error, name support using (0-9|a-z|A-Z|-)')
}, },
'azureStorage': { 'azureStorage': {
'accountName': And(Regex('([0-9]|[a-z]|[A-Z]|-){3,31}'),\ 'accountName': And(Regex('([0-9]|[a-z]|[A-Z]|-){3,31}'),
error='ERROR: accountName format error, accountName support using (0-9|a-z|A-Z|-)'), error='ERROR: accountName format error, accountName support using (0-9|a-z|A-Z|-)'),
'azureShare': And(Regex('([0-9]|[a-z]|[A-Z]|-){3,63}'),\ 'azureShare': And(Regex('([0-9]|[a-z]|[A-Z]|-){3,63}'),
error='ERROR: azureShare format error, azureShare support using (0-9|a-z|A-Z|-)') error='ERROR: azureShare format error, azureShare support using (0-9|a-z|A-Z|-)')
}, },
Optional('uploadRetryCount'): setNumberRange('uploadRetryCount', int, 1, 99999) Optional('uploadRetryCount'): setNumberRange('uploadRetryCount', int, 1, 99999)
...@@ -362,7 +368,7 @@ frameworkcontroller_config_schema = { ...@@ -362,7 +368,7 @@ frameworkcontroller_config_schema = {
} }
machine_list_schema = { machine_list_schema = {
'machineList':[Or( 'machineList': [Or(
{ {
'ip': setType('ip', str), 'ip': setType('ip', str),
Optional('port'): setNumberRange('port', int, 1, 65535), Optional('port'): setNumberRange('port', int, 1, 65535),
...@@ -395,6 +401,7 @@ training_service_schema_dict = { ...@@ -395,6 +401,7 @@ training_service_schema_dict = {
'dlts': Schema({**common_schema, **dlts_trial_schema, **dlts_config_schema}), 'dlts': Schema({**common_schema, **dlts_trial_schema, **dlts_config_schema}),
} }
class NNIConfigSchema: class NNIConfigSchema:
def validate(self, data): def validate(self, data):
train_service = data['trainingServicePlatform'] train_service = data['trainingServicePlatform']
...@@ -483,11 +490,17 @@ class NNIConfigSchema: ...@@ -483,11 +490,17 @@ class NNIConfigSchema:
if not taskRoles_dict: if not taskRoles_dict:
raise SchemaError('Please set taskRoles in paiConfigPath config file!') raise SchemaError('Please set taskRoles in paiConfigPath config file!')
else: else:
pai_trial_fields_required_list = ['image', 'gpuNum', 'cpuNum', 'memoryMB', 'paiStorageConfigName', 'command'] pai_trial_fields_required_list = ['image', 'paiStorageConfigName', 'command']
for trial_field in pai_trial_fields_required_list: for trial_field in pai_trial_fields_required_list:
if experiment_config['trial'].get(trial_field) is None: if experiment_config['trial'].get(trial_field) is None:
raise SchemaError('Please set {0} in trial configuration,\ raise SchemaError('Please set {0} in trial configuration,\
or set additional pai configuration file path in paiConfigPath!'.format(trial_field)) or set additional pai configuration file path in paiConfigPath!'.format(trial_field))
pai_resource_fields_required_list = ['gpuNum', 'cpuNum', 'memoryMB']
for required_field in pai_resource_fields_required_list:
if experiment_config['trial'].get(required_field) is None and \
experiment_config['paiConfig'].get(required_field) is None:
raise SchemaError('Please set {0} in trial or paiConfig configuration,\
or set additional pai configuration file path in paiConfigPath!'.format(required_field))
def validate_pai_trial_conifg(self, experiment_config): def validate_pai_trial_conifg(self, experiment_config):
'''validate the trial config in pai platform''' '''validate the trial config in pai platform'''
...@@ -495,7 +508,7 @@ class NNIConfigSchema: ...@@ -495,7 +508,7 @@ class NNIConfigSchema:
if experiment_config.get('trial').get('shmMB') and \ if experiment_config.get('trial').get('shmMB') and \
experiment_config['trial']['shmMB'] > experiment_config['trial']['memoryMB']: experiment_config['trial']['shmMB'] > experiment_config['trial']['memoryMB']:
raise SchemaError('shmMB should be no more than memoryMB!') raise SchemaError('shmMB should be no more than memoryMB!')
#backward compatibility # backward compatibility
warning_information = '{0} is not supported in NNI anymore, please remove the field in config file!\ warning_information = '{0} is not supported in NNI anymore, please remove the field in config file!\
please refer https://github.com/microsoft/nni/blob/master/docs/en_US/TrainingService/PaiMode.md#run-an-experiment\ please refer https://github.com/microsoft/nni/blob/master/docs/en_US/TrainingService/PaiMode.md#run-an-experiment\
for the practices of how to get data and output model in trial code' for the practices of how to get data and output model in trial code'
......
...@@ -57,7 +57,11 @@ class BaseChannel(ABC): ...@@ -57,7 +57,11 @@ class BaseChannel(ABC):
def close(self): def close(self):
self.is_running = False self.is_running = False
try:
self._inner_close() self._inner_close()
except Exception as err:
# ignore any error on closing
print("error on closing channel: %s" % err)
def send(self, command, data): def send(self, command, data):
"""Send command to Training Service. """Send command to Training Service.
......
...@@ -82,7 +82,11 @@ class RemoteLogger(object): ...@@ -82,7 +82,11 @@ class RemoteLogger(object):
''' '''
constructor constructor
''' '''
self.logger = logging.getLogger('nni_syslog_{}'.format(tag)) logger_name = 'nni_syslog_{}'.format(tag)
# to prevent multiple trial logged in same logger
if trial_id is not None:
logger_name = '{}_{}'.format(logger_name, trial_id)
self.logger = logging.getLogger(logger_name)
self.log_level = log_level self.log_level = log_level
self.logger.setLevel(self.log_level) self.logger.setLevel(self.log_level)
self.pipeReader = None self.pipeReader = None
......
...@@ -86,11 +86,17 @@ class Trial: ...@@ -86,11 +86,17 @@ class Trial:
break break
time.sleep(0.1) time.sleep(0.1)
trial_command = self.args.trial_command
gpuIndices = self.data.get('gpuIndices')
if (gpuIndices is not None):
trial_command = 'CUDA_VISIBLE_DEVICES="%s " %s' % (gpuIndices, trial_command)
self.log_pipe_stdout = self.trial_syslogger_stdout.get_pipelog_reader() self.log_pipe_stdout = self.trial_syslogger_stdout.get_pipelog_reader()
self.process = Popen(self.args.trial_command, shell=True, stdout=self.log_pipe_stdout, self.process = Popen(trial_command, shell=True, stdout=self.log_pipe_stdout,
stderr=self.log_pipe_stdout, cwd=trial_code_dir, env=dict(environ)) stderr=self.log_pipe_stdout, cwd=trial_code_dir, env=dict(environ))
nni_log(LogType.Info, '{0}: spawns a subprocess (pid {1}) to run command: {2}'. nni_log(LogType.Info, '{0}: spawns a subprocess (pid {1}) to run command: {2}'.
format(self.name, self.process.pid, shlex.split(self.args.trial_command))) format(self.name, self.process.pid, shlex.split(trial_command)))
def save_parameter_file(self, command_data): def save_parameter_file(self, command_data):
parameters = command_data["parameters"] parameters = command_data["parameters"]
......
...@@ -37,9 +37,9 @@ class WebChannel(BaseChannel): ...@@ -37,9 +37,9 @@ class WebChannel(BaseChannel):
def _inner_close(self): def _inner_close(self):
if self.client is not None: if self.client is not None:
self.client.close() self.client.close()
if self._event_loop.is_running():
self._event_loop.close()
self.client = None self.client = None
if self._event_loop.is_running():
self._event_loop.stop()
self._event_loop = None self._event_loop = None
def _inner_send(self, message): def _inner_send(self, message):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment