Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
13dc0f8f
Commit
13dc0f8f
authored
Mar 21, 2022
by
liuzhe
Browse files
Merge branch 'master' into doc-refactor
parents
22165cea
3b27ac76
Changes
49
Show whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
381 additions
and
90 deletions
+381
-90
test/ut/retiarii/test_highlevel_apis.py
test/ut/retiarii/test_highlevel_apis.py
+82
-1
test/ut/sdk/test_serializer.py
test/ut/sdk/test_serializer.py
+83
-10
ts/jupyter_extension/yarn.lock
ts/jupyter_extension/yarn.lock
+3
-3
ts/nni_manager/common/globals/arguments.ts
ts/nni_manager/common/globals/arguments.ts
+96
-0
ts/nni_manager/main.ts
ts/nni_manager/main.ts
+18
-74
ts/nni_manager/package.json
ts/nni_manager/package.json
+3
-1
ts/nni_manager/test/common/globals/arguments.test.ts
ts/nni_manager/test/common/globals/arguments.test.ts
+69
-0
ts/nni_manager/training_service/reusable/trialDispatcher.ts
ts/nni_manager/training_service/reusable/trialDispatcher.ts
+1
-0
ts/nni_manager/yarn.lock
ts/nni_manager/yarn.lock
+26
-1
No files found.
test/ut/retiarii/test_highlevel_apis.py
View file @
13dc0f8f
...
@@ -17,7 +17,7 @@ from nni.retiarii.graph import Model
...
@@ -17,7 +17,7 @@ from nni.retiarii.graph import Model
from
nni.retiarii.nn.pytorch.api
import
ValueChoice
from
nni.retiarii.nn.pytorch.api
import
ValueChoice
from
nni.retiarii.nn.pytorch.mutator
import
process_evaluator_mutations
,
process_inline_mutation
,
extract_mutation_from_pt_module
from
nni.retiarii.nn.pytorch.mutator
import
process_evaluator_mutations
,
process_inline_mutation
,
extract_mutation_from_pt_module
from
nni.retiarii.serializer
import
model_wrapper
from
nni.retiarii.serializer
import
model_wrapper
from
nni.retiarii.utils
import
ContextStack
from
nni.retiarii.utils
import
ContextStack
,
original_state_dict_hooks
class
EnumerateSampler
(
Sampler
):
class
EnumerateSampler
(
Sampler
):
...
@@ -123,6 +123,29 @@ class GraphIR(unittest.TestCase):
...
@@ -123,6 +123,29 @@ class GraphIR(unittest.TestCase):
self
.
assertEqual
(
self
.
_get_converted_pytorch_model
(
model_new
)(
torch
.
randn
(
1
,
3
,
3
,
3
)).
size
(),
self
.
assertEqual
(
self
.
_get_converted_pytorch_model
(
model_new
)(
torch
.
randn
(
1
,
3
,
3
,
3
)).
size
(),
torch
.
Size
([
1
,
i
,
3
,
3
]))
torch
.
Size
([
1
,
i
,
3
,
3
]))
def
test_layer_choice_weight_inheritance
(
self
):
@
model_wrapper
class
Net
(
nn
.
Module
):
def
__init__
(
self
):
super
().
__init__
()
self
.
module
=
nn
.
LayerChoice
([
nn
.
Conv2d
(
3
,
i
,
kernel_size
=
1
)
for
i
in
range
(
1
,
11
)])
def
forward
(
self
,
x
):
return
self
.
module
(
x
)
orig_model
=
Net
()
model
,
mutators
=
self
.
_get_model_with_mutators
(
orig_model
)
mutator
=
mutators
[
0
].
bind_sampler
(
EnumerateSampler
())
for
i
in
range
(
1
,
11
):
model_new
=
mutator
.
apply
(
model
)
model_new
=
self
.
_get_converted_pytorch_model
(
model_new
)
with
original_state_dict_hooks
(
model_new
):
model_new
.
load_state_dict
(
orig_model
.
state_dict
(),
strict
=
False
)
inp
=
torch
.
randn
(
1
,
3
,
3
,
3
)
a
=
getattr
(
orig_model
.
module
,
str
(
i
-
1
))(
inp
)
b
=
model_new
(
inp
)
self
.
assertLess
((
a
-
b
).
abs
().
max
().
item
(),
1E-4
)
def
test_nested_layer_choice
(
self
):
def
test_nested_layer_choice
(
self
):
@
model_wrapper
@
model_wrapper
class
Net
(
nn
.
Module
):
class
Net
(
nn
.
Module
):
...
@@ -150,6 +173,40 @@ class GraphIR(unittest.TestCase):
...
@@ -150,6 +173,40 @@ class GraphIR(unittest.TestCase):
self
.
assertEqual
(
self
.
_get_converted_pytorch_model
(
mutators
[
1
].
apply
(
mutators
[
0
].
apply
(
model
)))(
input
).
size
(),
self
.
assertEqual
(
self
.
_get_converted_pytorch_model
(
mutators
[
1
].
apply
(
mutators
[
0
].
apply
(
model
)))(
input
).
size
(),
torch
.
Size
([
1
,
5
,
5
,
5
]))
torch
.
Size
([
1
,
5
,
5
,
5
]))
def
test_nested_layer_choice_weight_inheritance
(
self
):
@
model_wrapper
class
Net
(
nn
.
Module
):
def
__init__
(
self
):
super
().
__init__
()
self
.
module
=
nn
.
LayerChoice
([
nn
.
LayerChoice
([
nn
.
Conv2d
(
3
,
3
,
kernel_size
=
1
),
nn
.
Conv2d
(
3
,
4
,
kernel_size
=
1
),
nn
.
Conv2d
(
3
,
5
,
kernel_size
=
1
)]),
nn
.
Conv2d
(
3
,
1
,
kernel_size
=
1
)
])
def
forward
(
self
,
x
):
return
self
.
module
(
x
)
orig_model
=
Net
()
model
,
mutators
=
self
.
_get_model_with_mutators
(
orig_model
)
mutators
[
0
].
bind_sampler
(
EnumerateSampler
())
mutators
[
1
].
bind_sampler
(
EnumerateSampler
())
input
=
torch
.
randn
(
1
,
3
,
5
,
5
)
for
i
in
range
(
3
):
model_new
=
self
.
_get_converted_pytorch_model
(
mutators
[
1
].
apply
(
mutators
[
0
].
apply
(
model
)))
with
original_state_dict_hooks
(
model_new
):
model_new
.
load_state_dict
(
orig_model
.
state_dict
(),
strict
=
False
)
if
i
==
0
:
a
=
getattr
(
getattr
(
orig_model
.
module
,
'0'
),
'0'
)(
input
)
elif
i
==
1
:
a
=
getattr
(
orig_model
.
module
,
'1'
)(
input
)
elif
i
==
2
:
a
=
getattr
(
getattr
(
orig_model
.
module
,
'0'
),
'2'
)(
input
)
b
=
model_new
(
input
)
self
.
assertLess
((
a
-
b
).
abs
().
max
().
item
(),
1E-4
)
def
test_input_choice
(
self
):
def
test_input_choice
(
self
):
@
model_wrapper
@
model_wrapper
class
Net
(
nn
.
Module
):
class
Net
(
nn
.
Module
):
...
@@ -578,6 +635,30 @@ class GraphIR(unittest.TestCase):
...
@@ -578,6 +635,30 @@ class GraphIR(unittest.TestCase):
self
.
assertIn
(
1.
,
result
)
self
.
assertIn
(
1.
,
result
)
def
test_repeat_weight_inheritance
(
self
):
@
model_wrapper
class
Net
(
nn
.
Module
):
def
__init__
(
self
):
super
().
__init__
()
self
.
module
=
nn
.
Repeat
(
lambda
index
:
nn
.
Conv2d
(
3
,
3
,
1
),
(
2
,
5
))
def
forward
(
self
,
x
):
return
self
.
module
(
x
)
orig_model
=
Net
()
model
,
mutators
=
self
.
_get_model_with_mutators
(
orig_model
)
mutator
=
mutators
[
0
].
bind_sampler
(
EnumerateSampler
())
inp
=
torch
.
randn
(
1
,
3
,
5
,
5
)
for
i
in
range
(
4
):
model_new
=
self
.
_get_converted_pytorch_model
(
mutator
.
apply
(
model
))
with
original_state_dict_hooks
(
model_new
):
model_new
.
load_state_dict
(
orig_model
.
state_dict
(),
strict
=
False
)
a
=
nn
.
Sequential
(
*
orig_model
.
module
.
blocks
[:
i
+
2
])(
inp
)
b
=
model_new
(
inp
)
self
.
assertLess
((
a
-
b
).
abs
().
max
().
item
(),
1E-4
)
def
test_cell
(
self
):
def
test_cell
(
self
):
@
model_wrapper
@
model_wrapper
class
Net
(
nn
.
Module
):
class
Net
(
nn
.
Module
):
...
...
test/ut/sdk/test_serializer.py
View file @
13dc0f8f
import
math
import
math
import
pickle
import
sys
import
sys
from
pathlib
import
Path
from
pathlib
import
Path
...
@@ -27,6 +28,11 @@ class SimpleClass:
...
@@ -27,6 +28,11 @@ class SimpleClass:
self
.
_b
=
b
self
.
_b
=
b
@
nni
.
trace
class
EmptyClass
:
pass
class
UnserializableSimpleClass
:
class
UnserializableSimpleClass
:
def
__init__
(
self
):
def
__init__
(
self
):
self
.
_a
=
1
self
.
_a
=
1
...
@@ -124,7 +130,8 @@ def test_custom_class():
...
@@ -124,7 +130,8 @@ def test_custom_class():
module
=
nni
.
trace
(
Foo
)(
Foo
(
1
),
5
)
module
=
nni
.
trace
(
Foo
)(
Foo
(
1
),
5
)
dumped_module
=
nni
.
dump
(
module
)
dumped_module
=
nni
.
dump
(
module
)
assert
len
(
dumped_module
)
>
200
# should not be too longer if the serialization is correct
module
=
nni
.
load
(
dumped_module
)
assert
module
.
bb
[
0
]
==
module
.
bb
[
999
]
==
6
module
=
nni
.
trace
(
Foo
)(
nni
.
trace
(
Foo
)(
1
),
5
)
module
=
nni
.
trace
(
Foo
)(
nni
.
trace
(
Foo
)(
1
),
5
)
dumped_module
=
nni
.
dump
(
module
)
dumped_module
=
nni
.
dump
(
module
)
...
@@ -193,6 +200,20 @@ def test_dataset():
...
@@ -193,6 +200,20 @@ def test_dataset():
assert
y
.
size
()
==
torch
.
Size
([
10
])
assert
y
.
size
()
==
torch
.
Size
([
10
])
def
test_pickle
():
pickle
.
dumps
(
EmptyClass
())
obj
=
SimpleClass
(
1
)
obj
=
pickle
.
loads
(
pickle
.
dumps
(
obj
))
assert
obj
.
_a
==
1
assert
obj
.
_b
==
1
obj
=
SimpleClass
(
1
)
obj
.
xxx
=
3
obj
=
pickle
.
loads
(
pickle
.
dumps
(
obj
))
assert
obj
.
xxx
==
3
@
pytest
.
mark
.
skipif
(
sys
.
platform
!=
'linux'
,
reason
=
'https://github.com/microsoft/nni/issues/4434'
)
@
pytest
.
mark
.
skipif
(
sys
.
platform
!=
'linux'
,
reason
=
'https://github.com/microsoft/nni/issues/4434'
)
def
test_multiprocessing_dataloader
():
def
test_multiprocessing_dataloader
():
# check whether multi-processing works
# check whether multi-processing works
...
@@ -208,6 +229,28 @@ def test_multiprocessing_dataloader():
...
@@ -208,6 +229,28 @@ def test_multiprocessing_dataloader():
assert
y
.
size
()
==
torch
.
Size
([
10
])
assert
y
.
size
()
==
torch
.
Size
([
10
])
def
_test_multiprocessing_dataset_worker
(
dataset
):
if
sys
.
platform
==
'linux'
:
# on non-linux, the loaded object will become non-traceable
# due to an implementation limitation
assert
is_traceable
(
dataset
)
else
:
from
torch.utils.data
import
Dataset
assert
isinstance
(
dataset
,
Dataset
)
def
test_multiprocessing_dataset
():
from
torch.utils.data
import
Dataset
dataset
=
nni
.
trace
(
Dataset
)()
import
multiprocessing
process
=
multiprocessing
.
Process
(
target
=
_test_multiprocessing_dataset_worker
,
args
=
(
dataset
,
))
process
.
start
()
process
.
join
()
assert
process
.
exitcode
==
0
def
test_type
():
def
test_type
():
assert
nni
.
dump
(
torch
.
optim
.
Adam
)
==
'{"__nni_type__": "path:torch.optim.adam.Adam"}'
assert
nni
.
dump
(
torch
.
optim
.
Adam
)
==
'{"__nni_type__": "path:torch.optim.adam.Adam"}'
assert
nni
.
load
(
'{"__nni_type__": "path:torch.optim.adam.Adam"}'
)
==
torch
.
optim
.
Adam
assert
nni
.
load
(
'{"__nni_type__": "path:torch.optim.adam.Adam"}'
)
==
torch
.
optim
.
Adam
...
@@ -220,10 +263,20 @@ def test_lightning_earlystop():
...
@@ -220,10 +263,20 @@ def test_lightning_earlystop():
import
nni.retiarii.evaluator.pytorch.lightning
as
pl
import
nni.retiarii.evaluator.pytorch.lightning
as
pl
from
pytorch_lightning.callbacks.early_stopping
import
EarlyStopping
from
pytorch_lightning.callbacks.early_stopping
import
EarlyStopping
trainer
=
pl
.
Trainer
(
callbacks
=
[
nni
.
trace
(
EarlyStopping
)(
monitor
=
"val_loss"
)])
trainer
=
pl
.
Trainer
(
callbacks
=
[
nni
.
trace
(
EarlyStopping
)(
monitor
=
"val_loss"
)])
trainer
=
nni
.
load
(
nni
.
dump
(
trainer
))
pickle_size_limit
=
4096
if
sys
.
platform
==
'linux'
else
32768
trainer
=
nni
.
load
(
nni
.
dump
(
trainer
,
pickle_size_limit
=
pickle_size_limit
))
assert
any
(
isinstance
(
callback
,
EarlyStopping
)
for
callback
in
trainer
.
callbacks
)
assert
any
(
isinstance
(
callback
,
EarlyStopping
)
for
callback
in
trainer
.
callbacks
)
def
test_pickle_trainer
():
import
nni.retiarii.evaluator.pytorch.lightning
as
pl
from
pytorch_lightning
import
Trainer
trainer
=
pl
.
Trainer
(
max_epochs
=
1
)
data
=
pickle
.
dumps
(
trainer
)
trainer
=
pickle
.
loads
(
data
)
assert
isinstance
(
trainer
,
Trainer
)
def
test_generator
():
def
test_generator
():
import
torch.nn
as
nn
import
torch.nn
as
nn
import
torch.optim
as
optim
import
torch.optim
as
optim
...
@@ -272,11 +325,31 @@ def test_arguments_kind():
...
@@ -272,11 +325,31 @@ def test_arguments_kind():
assert
lstm
.
trace_kwargs
==
{
'input_size'
:
2
,
'hidden_size'
:
2
}
assert
lstm
.
trace_kwargs
==
{
'input_size'
:
2
,
'hidden_size'
:
2
}
if
__name__
==
'__main__'
:
def
test_subclass
():
# test_simple_class()
@
nni
.
trace
# test_external_class()
class
Super
:
# test_nested_class()
def
__init__
(
self
,
a
,
b
):
# test_unserializable()
self
.
_a
=
a
# test_basic_unit()
self
.
_b
=
b
# test_generator()
test_arguments_kind
()
class
Sub1
(
Super
):
def
__init__
(
self
,
c
,
d
):
super
().
__init__
(
3
,
4
)
self
.
_c
=
c
self
.
_d
=
d
@
nni
.
trace
class
Sub2
(
Super
):
def
__init__
(
self
,
c
,
d
):
super
().
__init__
(
3
,
4
)
self
.
_c
=
c
self
.
_d
=
d
obj
=
Sub1
(
1
,
2
)
# There could be trace_kwargs for obj. Behavior is undefined.
assert
obj
.
_a
==
3
and
obj
.
_c
==
1
assert
isinstance
(
obj
,
Super
)
obj
=
Sub2
(
1
,
2
)
assert
obj
.
trace_kwargs
==
{
'c'
:
1
,
'd'
:
2
}
assert
issubclass
(
type
(
obj
),
Super
)
assert
isinstance
(
obj
,
Super
)
ts/jupyter_extension/yarn.lock
View file @
13dc0f8f
...
@@ -2907,9 +2907,9 @@ url-parse-lax@^3.0.0:
...
@@ -2907,9 +2907,9 @@ url-parse-lax@^3.0.0:
prepend-http "^2.0.0"
prepend-http "^2.0.0"
url-parse@~1.5.1:
url-parse@~1.5.1:
version "1.5.
7
"
version "1.5.
10
"
resolved "https://registry.yarnpkg.com/url-parse/-/url-parse-1.5.
7
.tgz#
00780f60dbdae9018
1f
5
1e
d85fb24109422c932a
"
resolved "https://registry.yarnpkg.com/url-parse/-/url-parse-1.5.
10
.tgz#
9d3c2f736c1d75dd3bd2be507dcc11
1f1e
2ea9c1
"
integrity sha512-
HxWkieX+STA38EDk7CE9MEryFeHCKzgagxlGvsdS7WBImq9Mk+PGwiT56w82WI3aicwJA8REp42Cxo98c8FZMA
==
integrity sha512-
WypcfiRhfeUP9vvF0j6rw0J3hrWrw6iZv3+22h6iRMJ/8z1Tj6XfLP4DsUix5MhMPnXpiHDoKyoZ/bdCkwBCiQ
==
dependencies:
dependencies:
querystringify "^2.1.1"
querystringify "^2.1.1"
requires-port "^1.0.0"
requires-port "^1.0.0"
...
...
ts/nni_manager/common/globals/arguments.ts
0 → 100644
View file @
13dc0f8f
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
/**
* Parse NNI manager's command line arguments.
**/
import
assert
from
'
assert/strict
'
;
import
yargs
from
'
yargs/yargs
'
;
/**
* Command line arguments provided by "nni/experiment/launcher.py".
*
* Hyphen-separated words are automatically converted to camelCases by yargs lib, but snake_cases are not.
* So it supports "--log-level" but does not support "--log_level".
*
* Unfortunately I misunderstood "experiment_working_directory" config field when deciding the name.
* It defaults to "~/nni-experiments" rather than "~/nni-experiments/<experiment-id>",
* and further more the working directory is "site-packages/nni_node", not either.
* For compatibility concern we cannot change the public API, so there is an inconsistency here.
**/
export
interface
NniManagerArgs
{
readonly
port
:
number
;
readonly
experimentId
:
string
;
readonly
action
:
'
create
'
|
'
resume
'
|
'
view
'
;
readonly
experimentsDirectory
:
string
;
// renamed "config.experiment_working_directory", must be absolute
readonly
logLevel
:
'
critical
'
|
'
error
'
|
'
warning
'
|
'
info
'
|
'
debug
'
;
readonly
foreground
:
boolean
;
readonly
urlPrefix
:
string
;
// leading and trailing "/" must be stripped
// these are planned to be removed
readonly
mode
:
string
;
readonly
dispatcherPipe
:
string
|
undefined
;
}
export
function
parseArgs
(
rawArgs
:
string
[]):
NniManagerArgs
{
const
parser
=
yargs
(
rawArgs
).
options
(
yargsOptions
).
strict
().
fail
((
_msg
,
err
,
_yargs
)
=>
{
throw
err
;
});
const
parsedArgs
:
NniManagerArgs
=
parser
.
parseSync
();
// strip yargs leftovers
const
argsAsAny
:
any
=
{};
for
(
const
key
in
yargsOptions
)
{
argsAsAny
[
key
]
=
(
parsedArgs
as
any
)[
key
];
assert
(
!
Number
.
isNaN
(
argsAsAny
[
key
]),
`Command line arg --
${
key
}
is not a number`
);
}
if
(
argsAsAny
.
dispatcherPipe
===
''
)
{
argsAsAny
.
dispatcherPipe
=
undefined
;
}
const
args
:
NniManagerArgs
=
argsAsAny
;
const
prefixErrMsg
=
`Command line arg --url-prefix "
${
args
.
urlPrefix
}
" is not stripped`
;
assert
(
!
args
.
urlPrefix
.
startsWith
(
'
/
'
)
&&
!
args
.
urlPrefix
.
endsWith
(
'
/
'
),
prefixErrMsg
);
return
args
;
}
const
yargsOptions
=
{
port
:
{
demandOption
:
true
,
type
:
'
number
'
},
experimentId
:
{
demandOption
:
true
,
type
:
'
string
'
},
action
:
{
choices
:
[
'
create
'
,
'
resume
'
,
'
view
'
]
as
const
,
demandOption
:
true
},
experimentsDirectory
:
{
demandOption
:
true
,
type
:
'
string
'
},
logLevel
:
{
choices
:
[
'
critical
'
,
'
error
'
,
'
warning
'
,
'
info
'
,
'
debug
'
]
as
const
,
demandOption
:
true
},
foreground
:
{
default
:
false
,
type
:
'
boolean
'
},
urlPrefix
:
{
default
:
''
,
type
:
'
string
'
},
mode
:
{
default
:
''
,
type
:
'
string
'
},
dispatcherPipe
:
{
default
:
''
,
type
:
'
string
'
}
}
as
const
;
ts/nni_manager/main.ts
View file @
13dc0f8f
...
@@ -20,15 +20,11 @@ import { SqlDB } from './core/sqlDatabase';
...
@@ -20,15 +20,11 @@ import { SqlDB } from './core/sqlDatabase';
import
{
NNIExperimentsManager
}
from
'
./core/nniExperimentsManager
'
;
import
{
NNIExperimentsManager
}
from
'
./core/nniExperimentsManager
'
;
import
{
NNITensorboardManager
}
from
'
./core/nniTensorboardManager
'
;
import
{
NNITensorboardManager
}
from
'
./core/nniTensorboardManager
'
;
import
{
RestServer
}
from
'
./rest_server
'
;
import
{
RestServer
}
from
'
./rest_server
'
;
import
{
parseArgs
}
from
'
common/globals/arguments
'
;
function
initStartupInfo
(
const
args
=
parseArgs
(
process
.
argv
.
slice
(
2
));
startExpMode
:
string
,
experimentId
:
string
,
basePort
:
number
,
platform
:
string
,
logDirectory
:
string
,
experimentLogLevel
:
string
,
readonly
:
boolean
,
dispatcherPipe
:
string
,
urlprefix
:
string
):
void
{
const
createNew
:
boolean
=
(
startExpMode
===
ExperimentStartUpMode
.
NEW
);
setExperimentStartupInfo
(
createNew
,
experimentId
,
basePort
,
platform
,
logDirectory
,
experimentLogLevel
,
readonly
,
dispatcherPipe
,
urlprefix
);
}
async
function
initContainer
(
foreground
:
boolean
,
_platformMode
:
string
,
logFileName
?:
string
):
Promise
<
void
>
{
async
function
initContainer
():
Promise
<
void
>
{
Container
.
bind
(
Manager
)
Container
.
bind
(
Manager
)
.
to
(
NNIManager
)
.
to
(
NNIManager
)
.
scope
(
Scope
.
Singleton
);
.
scope
(
Scope
.
Singleton
);
...
@@ -45,84 +41,32 @@ async function initContainer(foreground: boolean, _platformMode: string, logFile
...
@@ -45,84 +41,32 @@ async function initContainer(foreground: boolean, _platformMode: string, logFile
.
to
(
NNITensorboardManager
)
.
to
(
NNITensorboardManager
)
.
scope
(
Scope
.
Singleton
);
.
scope
(
Scope
.
Singleton
);
const
DEFAULT_LOGFILE
:
string
=
path
.
join
(
getLogDir
(),
'
nnimanager.log
'
);
const
DEFAULT_LOGFILE
:
string
=
path
.
join
(
getLogDir
(),
'
nnimanager.log
'
);
if
(
!
foreground
)
{
if
(
!
args
.
foreground
)
{
if
(
logFileName
===
undefined
)
{
startLogging
(
DEFAULT_LOGFILE
);
startLogging
(
DEFAULT_LOGFILE
);
}
else
{
startLogging
(
logFileName
);
}
}
}
// eslint-disable-next-line @typescript-eslint/no-use-before-define
// eslint-disable-next-line @typescript-eslint/no-use-before-define
setLogLevel
(
logLevel
);
setLogLevel
(
args
.
logLevel
);
const
ds
:
DataStore
=
component
.
get
(
DataStore
);
const
ds
:
DataStore
=
component
.
get
(
DataStore
);
await
ds
.
init
();
await
ds
.
init
();
}
}
function
usage
():
void
{
setExperimentStartupInfo
(
console
.
info
(
'
usage: node main.js --port <port> --mode
\
args
.
action
===
'
create
'
,
<local/remote/pai/kubeflow/frameworkcontroller/aml/adl/hybrid/dlc> --start_mode <new/resume> --experiment_id <id> --foreground <true/false>
'
);
args
.
experimentId
,
}
args
.
port
,
args
.
mode
,
const
strPort
:
string
=
parseArg
([
'
--port
'
,
'
-p
'
]);
args
.
experimentsDirectory
,
if
(
!
strPort
||
strPort
.
length
===
0
)
{
args
.
logLevel
,
usage
();
args
.
action
===
'
view
'
,
process
.
exit
(
1
);
args
.
dispatcherPipe
??
''
,
}
args
.
urlPrefix
);
const
foregroundArg
:
string
=
parseArg
([
'
--foreground
'
,
'
-f
'
]);
if
(
foregroundArg
&&
!
[
'
true
'
,
'
false
'
].
includes
(
foregroundArg
.
toLowerCase
()))
{
console
.
log
(
`FATAL: foreground property should only be true or false`
);
usage
();
process
.
exit
(
1
);
}
const
foreground
:
boolean
=
(
foregroundArg
&&
foregroundArg
.
toLowerCase
()
===
'
true
'
)
?
true
:
false
;
const
port
:
number
=
parseInt
(
strPort
,
10
);
const
mode
:
string
=
parseArg
([
'
--mode
'
,
'
-m
'
]);
const
startMode
:
string
=
parseArg
([
'
--start_mode
'
,
'
-s
'
]);
if
(
!
[
ExperimentStartUpMode
.
NEW
,
ExperimentStartUpMode
.
RESUME
].
includes
(
startMode
))
{
console
.
log
(
`FATAL: unknown start_mode:
${
startMode
}
`
);
usage
();
process
.
exit
(
1
);
}
const
experimentId
:
string
=
parseArg
([
'
--experiment_id
'
,
'
-id
'
]);
if
(
experimentId
.
trim
().
length
<
1
)
{
console
.
log
(
`FATAL: cannot resume the experiment, invalid experiment_id:
${
experimentId
}
`
);
usage
();
process
.
exit
(
1
);
}
const
logDir
:
string
=
parseArg
([
'
--log_dir
'
,
'
-ld
'
]);
if
(
logDir
.
length
>
0
)
{
if
(
!
fs
.
existsSync
(
logDir
))
{
console
.
log
(
`FATAL: log_dir
${
logDir
}
does not exist`
);
}
}
const
logLevel
:
string
=
parseArg
([
'
--log_level
'
,
'
-ll
'
]);
const
readonlyArg
:
string
=
parseArg
([
'
--readonly
'
,
'
-r
'
]);
if
(
readonlyArg
&&
!
[
'
true
'
,
'
false
'
].
includes
(
readonlyArg
.
toLowerCase
()))
{
console
.
log
(
`FATAL: readonly property should only be true or false`
);
usage
();
process
.
exit
(
1
);
}
const
readonly
=
(
readonlyArg
&&
readonlyArg
.
toLowerCase
()
==
'
true
'
)
?
true
:
false
;
const
dispatcherPipe
:
string
=
parseArg
([
'
--dispatcher_pipe
'
]);
const
urlPrefix
:
string
=
parseArg
([
'
--url_prefix
'
]);
initStartupInfo
(
startMode
,
experimentId
,
port
,
mode
,
logDir
,
logLevel
,
readonly
,
dispatcherPipe
,
urlPrefix
);
mkDirP
(
getLogDir
())
mkDirP
(
getLogDir
())
.
then
(
async
()
=>
{
.
then
(
async
()
=>
{
try
{
try
{
await
initContainer
(
foreground
,
mode
);
await
initContainer
();
const
restServer
:
RestServer
=
component
.
get
(
RestServer
);
const
restServer
:
RestServer
=
component
.
get
(
RestServer
);
await
restServer
.
start
();
await
restServer
.
start
();
}
catch
(
err
)
{
}
catch
(
err
)
{
...
...
ts/nni_manager/package.json
View file @
13dc0f8f
...
@@ -33,7 +33,8 @@
...
@@ -33,7 +33,8 @@
"ts-deferred"
:
"^1.0.4"
,
"ts-deferred"
:
"^1.0.4"
,
"typescript-ioc"
:
"^1.2.6"
,
"typescript-ioc"
:
"^1.2.6"
,
"typescript-string-operations"
:
"^1.4.1"
,
"typescript-string-operations"
:
"^1.4.1"
,
"ws"
:
"^7.4.6"
"ws"
:
"^7.4.6"
,
"yargs"
:
"^17.3.1"
},
},
"devDependencies"
:
{
"devDependencies"
:
{
"@types/chai"
:
"^4.2.18"
,
"@types/chai"
:
"^4.2.18"
,
...
@@ -55,6 +56,7 @@
...
@@ -55,6 +56,7 @@
"@types/tar"
:
"^4.0.4"
,
"@types/tar"
:
"^4.0.4"
,
"@types/tmp"
:
"^0.2.0"
,
"@types/tmp"
:
"^0.2.0"
,
"@types/ws"
:
"^7.4.4"
,
"@types/ws"
:
"^7.4.4"
,
"@types/yargs"
:
"^17.0.8"
,
"@typescript-eslint/eslint-plugin"
:
"^2.10.0"
,
"@typescript-eslint/eslint-plugin"
:
"^2.10.0"
,
"@typescript-eslint/parser"
:
"^4.26.0"
,
"@typescript-eslint/parser"
:
"^4.26.0"
,
"chai"
:
"^4.3.4"
,
"chai"
:
"^4.3.4"
,
...
...
ts/nni_manager/test/common/globals/arguments.test.ts
0 → 100644
View file @
13dc0f8f
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
import
assert
from
'
assert/strict
'
;
import
{
parseArgs
}
from
'
common/globals/arguments
'
;
const
command
=
'
--port 80 --experiment-id ID --action resume --experiments-directory DIR --log-level error
'
;
const
expected
=
{
port
:
80
,
experimentId
:
'
ID
'
,
action
:
'
resume
'
,
experimentsDirectory
:
'
DIR
'
,
logLevel
:
'
error
'
,
foreground
:
false
,
urlPrefix
:
''
,
mode
:
''
,
dispatcherPipe
:
undefined
,
};
function
testGoodShort
():
void
{
const
args
=
parseArgs
(
command
.
split
(
'
'
));
assert
.
deepEqual
(
args
,
expected
);
}
function
testGoodLong
():
void
{
const
cmd
=
command
+
'
--url-prefix URL/prefix --foreground true
'
;
const
args
=
parseArgs
(
cmd
.
split
(
'
'
));
const
expectedLong
=
Object
.
assign
({},
expected
);
expectedLong
.
urlPrefix
=
'
URL/prefix
'
;
expectedLong
.
foreground
=
true
;
assert
.
deepEqual
(
args
,
expectedLong
);
}
function
testBadKey
():
void
{
const
cmd
=
command
+
'
--bad 1
'
;
assert
.
throws
(()
=>
parseArgs
(
cmd
.
split
(
'
'
)));
}
function
testBadPos
():
void
{
const
cmd
=
command
.
replace
(
'
--port
'
,
'
port
'
);
assert
.
throws
(()
=>
parseArgs
(
cmd
.
split
(
'
'
)));
}
function
testBadNum
():
void
{
const
cmd
=
command
.
replace
(
'
80
'
,
'
8o
'
);
assert
.
throws
(()
=>
parseArgs
(
cmd
.
split
(
'
'
)));
}
function
testBadBool
():
void
{
const
cmd
=
command
+
'
--foreground 1
'
;
assert
.
throws
(()
=>
parseArgs
(
cmd
.
split
(
'
'
)));
}
function
testBadChoice
():
void
{
const
cmd
=
command
.
replace
(
'
resume
'
,
'
new
'
);
assert
.
throws
(()
=>
parseArgs
(
cmd
.
split
(
'
'
)));
}
describe
(
'
## globals.arguments ##
'
,
()
=>
{
it
(
'
good short
'
,
()
=>
testGoodShort
());
it
(
'
good long
'
,
()
=>
testGoodLong
());
it
(
'
bad key arg
'
,
()
=>
testBadKey
());
it
(
'
bad positional arg
'
,
()
=>
testBadPos
());
it
(
'
bad number
'
,
()
=>
testBadNum
());
it
(
'
bad boolean
'
,
()
=>
testBadBool
());
it
(
'
bad choice
'
,
()
=>
testBadChoice
());
});
ts/nni_manager/training_service/reusable/trialDispatcher.ts
View file @
13dc0f8f
...
@@ -507,6 +507,7 @@ class TrialDispatcher implements TrainingService {
...
@@ -507,6 +507,7 @@ class TrialDispatcher implements TrainingService {
throw
new
Error
(
`
${
environment
.
id
}
does not has environment service!`
);
throw
new
Error
(
`
${
environment
.
id
}
does not has environment service!`
);
}
}
await
environment
.
environmentService
.
stopEnvironment
(
environment
);
await
environment
.
environmentService
.
stopEnvironment
(
environment
);
liveEnvironmentsCount
--
;
continue
;
continue
;
}
}
...
...
ts/nni_manager/yarn.lock
View file @
13dc0f8f
...
@@ -820,6 +820,18 @@
...
@@ -820,6 +820,18 @@
dependencies:
dependencies:
"@types/node" "*"
"@types/node" "*"
"@types/yargs-parser@*":
version "20.2.1"
resolved "https://registry.yarnpkg.com/@types/yargs-parser/-/yargs-parser-20.2.1.tgz#3b9ce2489919d9e4fea439b76916abc34b2df129"
integrity sha512-7tFImggNeNBVMsn0vLrpn1H1uPrUBdnARPTpZoitY37ZrdJREzf7I16tMrlK3hen349gr1NYh8CmZQa7CTG6Aw==
"@types/yargs@^17.0.8":
version "17.0.8"
resolved "https://registry.yarnpkg.com/@types/yargs/-/yargs-17.0.8.tgz#d23a3476fd3da8a0ea44b5494ca7fa677b9dad4c"
integrity sha512-wDeUwiUmem9FzsyysEwRukaEdDNcwbROvQ9QGRKaLI6t+IltNzbn4/i4asmB10auvZGQCzSQ6t0GSczEThlUXw==
dependencies:
"@types/yargs-parser" "*"
"@typescript-eslint/eslint-plugin@^2.10.0":
"@typescript-eslint/eslint-plugin@^2.10.0":
version "2.34.0"
version "2.34.0"
resolved "https://registry.yarnpkg.com/@typescript-eslint/eslint-plugin/-/eslint-plugin-2.34.0.tgz#6f8ce8a46c7dea4a6f1d171d2bb8fbae6dac2be9"
resolved "https://registry.yarnpkg.com/@typescript-eslint/eslint-plugin/-/eslint-plugin-2.34.0.tgz#6f8ce8a46c7dea4a6f1d171d2bb8fbae6dac2be9"
...
@@ -5752,7 +5764,7 @@ yallist@^4.0.0:
...
@@ -5752,7 +5764,7 @@ yallist@^4.0.0:
resolved "https://registry.yarnpkg.com/yallist/-/yallist-4.0.0.tgz#9bb92790d9c0effec63be73519e11a35019a3a72"
resolved "https://registry.yarnpkg.com/yallist/-/yallist-4.0.0.tgz#9bb92790d9c0effec63be73519e11a35019a3a72"
integrity sha512-3wdGidZyq5PB084XLES5TpOSRA3wjXAlIWMhum2kRcv/41Sn2emQ0dycQW4uZXLejwKvg6EsvbdlVL+FYEct7A==
integrity sha512-3wdGidZyq5PB084XLES5TpOSRA3wjXAlIWMhum2kRcv/41Sn2emQ0dycQW4uZXLejwKvg6EsvbdlVL+FYEct7A==
yargs-parser@20.2.4, yargs-parser@>=20.2.7, yargs-parser@^18.1.2, yargs-parser@^20.2.2:
yargs-parser@20.2.4, yargs-parser@>=20.2.7, yargs-parser@^18.1.2, yargs-parser@^20.2.2
, yargs-parser@^21.0.0
:
version "20.2.7"
version "20.2.7"
resolved "https://registry.yarnpkg.com/yargs-parser/-/yargs-parser-20.2.7.tgz#61df85c113edfb5a7a4e36eb8aa60ef423cbc90a"
resolved "https://registry.yarnpkg.com/yargs-parser/-/yargs-parser-20.2.7.tgz#61df85c113edfb5a7a4e36eb8aa60ef423cbc90a"
integrity sha512-FiNkvbeHzB/syOjIUxFDCnhSfzAL8R5vs40MgLFBorXACCOAEaWu0gRZl14vG8MR9AOJIZbmkjhusqBYZ3HTHw==
integrity sha512-FiNkvbeHzB/syOjIUxFDCnhSfzAL8R5vs40MgLFBorXACCOAEaWu0gRZl14vG8MR9AOJIZbmkjhusqBYZ3HTHw==
...
@@ -5797,6 +5809,19 @@ yargs@^15.0.2:
...
@@ -5797,6 +5809,19 @@ yargs@^15.0.2:
y18n "^4.0.0"
y18n "^4.0.0"
yargs-parser "^18.1.2"
yargs-parser "^18.1.2"
yargs@^17.3.1:
version "17.3.1"
resolved "https://registry.yarnpkg.com/yargs/-/yargs-17.3.1.tgz#da56b28f32e2fd45aefb402ed9c26f42be4c07b9"
integrity sha512-WUANQeVgjLbNsEmGk20f+nlHgOqzRFpiGWVaBrYGYIGANIIu3lWjoyi0fNlFmJkvfhCZ6BXINe7/W2O2bV4iaA==
dependencies:
cliui "^7.0.2"
escalade "^3.1.1"
get-caller-file "^2.0.5"
require-directory "^2.1.1"
string-width "^4.2.3"
y18n "^5.0.5"
yargs-parser "^21.0.0"
yn@3.1.1:
yn@3.1.1:
version "3.1.1"
version "3.1.1"
resolved "https://registry.yarnpkg.com/yn/-/yn-3.1.1.tgz#1e87401a09d767c1d5eab26a6e4c185182d2eb50"
resolved "https://registry.yarnpkg.com/yn/-/yn-3.1.1.tgz#1e87401a09d767c1d5eab26a6e4c185182d2eb50"
...
...
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment