Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
a31d37e5
"vscode:/vscode.git/clone" did not exist on "5fe29b063f7877036e54e26d85c21e2d1d51a2c8"
Unverified
Commit
a31d37e5
authored
Apr 07, 2022
by
Yuge Zhang
Committed by
GitHub
Apr 07, 2022
Browse files
Enable trial version check in Retiarii (#4738)
parent
a801b5f8
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
110 additions
and
4 deletions
+110
-4
examples/nas/multi-trial/mnist/search.py
examples/nas/multi-trial/mnist/search.py
+1
-1
nni/common/version.py
nni/common/version.py
+75
-2
nni/retiarii/integration.py
nni/retiarii/integration.py
+3
-1
nni/retiarii/integration_api.py
nni/retiarii/integration_api.py
+10
-0
test/ut/sdk/test_version.py
test/ut/sdk/test_version.py
+21
-0
No files found.
examples/nas/multi-trial/mnist/search.py
View file @
a31d37e5
...
...
@@ -139,7 +139,7 @@ if __name__ == '__main__':
# exp_config.execution_engine = 'base'
# export_formatter = 'code'
exp
.
run
(
exp_config
,
808
1
+
random
.
randint
(
0
,
100
)
)
exp
.
run
(
exp_config
,
808
0
)
print
(
'Final model:'
)
for
model_code
in
exp
.
export_top_models
(
formatter
=
export_formatter
):
print
(
model_code
)
nni/common/version.py
View file @
a31d37e5
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from
__future__
import
annotations
import
logging
import
sys
import
warnings
import
cloudpickle
import
json_tricks
import
numpy
import
yaml
import
nni
def
_minor_version_tuple
(
version_str
:
str
)
->
tuple
[
int
,
int
]:
# If not a number, returns -1 (e.g., 999.dev0 -> (999, -1))
return
tuple
(
int
(
x
)
if
x
.
isdigit
()
else
-
1
for
x
in
version_str
.
split
(
"."
)[:
2
])
PYTHON_VERSION
=
sys
.
version_info
[:
2
]
NUMPY_VERSION
=
_minor_version_tuple
(
numpy
.
__version__
)
try
:
import
torch
TORCH_VERSION
=
tuple
(
int
(
x
)
for
x
in
torch
.
__version__
.
split
(
"."
)[:
2
]
)
except
Exception
:
TORCH_VERSION
=
_minor_version_tuple
(
torch
.
__version__
)
except
ImportError
:
logging
.
info
(
"PyTorch is not installed."
)
TORCH_VERSION
=
None
try
:
import
pytorch_lightning
PYTORCH_LIGHTNING_VERSION
=
_minor_version_tuple
(
pytorch_lightning
.
__version__
)
except
ImportError
:
logging
.
info
(
"PyTorch Lightning is not installed."
)
PYTORCH_LIGHTNING_VERSION
=
None
try
:
import
tensorflow
TENSORFLOW_VERSION
=
_minor_version_tuple
(
tensorflow
.
__version__
)
except
ImportError
:
logging
.
info
(
"Tensorflow is not installed."
)
TENSORFLOW_VERSION
=
None
# Serialization version check are needed because they are prone to be inconsistent between versions
CLOUDPICKLE_VERSION
=
_minor_version_tuple
(
cloudpickle
.
__version__
)
JSON_TRICKS_VERSION
=
_minor_version_tuple
(
json_tricks
.
__version__
)
PYYAML_VERSION
=
_minor_version_tuple
(
yaml
.
__version__
)
NNI_VERSION
=
_minor_version_tuple
(
nni
.
__version__
)
def
version_dump
()
->
dict
[
str
,
tuple
[
int
,
int
]
|
None
]:
return
{
'python'
:
PYTHON_VERSION
,
'numpy'
:
NUMPY_VERSION
,
'torch'
:
TORCH_VERSION
,
'pytorch_lightning'
:
PYTORCH_LIGHTNING_VERSION
,
'tensorflow'
:
TENSORFLOW_VERSION
,
'cloudpickle'
:
CLOUDPICKLE_VERSION
,
'json_tricks'
:
JSON_TRICKS_VERSION
,
'pyyaml'
:
PYYAML_VERSION
,
'nni'
:
NNI_VERSION
}
def
version_check
(
expect
:
dict
,
raise_error
:
bool
=
False
)
->
None
:
current_ver
=
version_dump
()
for
package
in
expect
:
# version could be list due to serialization
exp_version
:
tuple
|
None
=
tuple
(
expect
[
package
])
if
expect
[
package
]
else
None
if
exp_version
is
None
:
continue
err_message
:
str
|
None
=
None
if
package
not
in
current_ver
:
err_message
=
f
'
{
package
}
is missing in current environment'
elif
current_ver
[
package
]
!=
exp_version
:
err_message
=
f
'Expect
{
package
}
to have version
{
exp_version
}
, but
{
current_ver
[
package
]
}
found'
if
err_message
:
if
raise_error
:
raise
RuntimeError
(
'Version check failed: '
+
err_message
)
else
:
warnings
.
warn
(
'Version check with warning: '
+
err_message
)
nni/retiarii/integration.py
View file @
a31d37e5
...
...
@@ -7,6 +7,7 @@ from typing import Any, Callable
import
nni
from
nni.common.serializer
import
PayloadTooLarge
from
nni.common.version
import
version_dump
from
nni.runtime.msg_dispatcher_base
import
MsgDispatcherBase
from
nni.runtime.protocol
import
CommandType
,
send
from
nni.utils
import
MetricType
...
...
@@ -120,7 +121,8 @@ class RetiariiAdvisor(MsgDispatcherBase):
'parameter_id'
:
self
.
parameters_count
,
'parameters'
:
parameters
,
'parameter_source'
:
'algorithm'
,
'placement_constraint'
:
placement_constraint
'placement_constraint'
:
placement_constraint
,
'version_info'
:
version_dump
()
}
_logger
.
debug
(
'New trial sent: %s'
,
new_trial
)
...
...
nni/retiarii/integration_api.py
View file @
a31d37e5
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import
warnings
from
typing
import
NewType
,
Any
import
nni
from
nni.common.version
import
version_check
# NOTE: this is only for passing flake8, we cannot import RetiariiAdvisor
# because it would induce cycled import
...
...
@@ -37,6 +39,14 @@ def receive_trial_parameters() -> dict:
Reload with our json loads because NNI didn't use Retiarii serializer to load the data.
"""
params
=
nni
.
get_next_parameter
()
# version check, optional
raw_params
=
nni
.
trial
.
_params
if
raw_params
is
not
None
and
'version_info'
in
raw_params
:
version_check
(
raw_params
[
'version_info'
])
else
:
warnings
.
warn
(
'Version check failed because `version_info` is not found.'
)
return
params
...
...
test/ut/sdk/test_version.py
0 → 100644
View file @
a31d37e5
import
pytest
import
sys
from
nni.common.version
import
version_dump
,
version_check
def
test_version_dump
():
dump_ver
=
version_dump
()
assert
len
(
dump_ver
)
>=
9
print
(
dump_ver
)
def
test_version_check
():
version_check
(
version_dump
(),
raise_error
=
True
)
version_check
({
'python'
:
sys
.
version_info
[:
2
]},
raise_error
=
True
)
with
pytest
.
warns
(
UserWarning
):
version_check
({
'nni'
:
(
99999
,
99999
)})
with
pytest
.
raises
(
RuntimeError
):
version_check
({
'python'
:
(
2
,
7
)},
raise_error
=
True
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment