Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
a31d37e5
"vscode:/vscode.git/clone" did not exist on "12f4cfce96a8dab6ff0e790ae9028d39ee88e303"
Unverified
Commit
a31d37e5
authored
Apr 07, 2022
by
Yuge Zhang
Committed by
GitHub
Apr 07, 2022
Browse files
Enable trial version check in Retiarii (#4738)
parent
a801b5f8
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
110 additions
and
4 deletions
+110
-4
examples/nas/multi-trial/mnist/search.py
examples/nas/multi-trial/mnist/search.py
+1
-1
nni/common/version.py
nni/common/version.py
+75
-2
nni/retiarii/integration.py
nni/retiarii/integration.py
+3
-1
nni/retiarii/integration_api.py
nni/retiarii/integration_api.py
+10
-0
test/ut/sdk/test_version.py
test/ut/sdk/test_version.py
+21
-0
No files found.
examples/nas/multi-trial/mnist/search.py
View file @
a31d37e5
...
@@ -139,7 +139,7 @@ if __name__ == '__main__':
...
@@ -139,7 +139,7 @@ if __name__ == '__main__':
# exp_config.execution_engine = 'base'
# exp_config.execution_engine = 'base'
# export_formatter = 'code'
# export_formatter = 'code'
exp
.
run
(
exp_config
,
808
1
+
random
.
randint
(
0
,
100
)
)
exp
.
run
(
exp_config
,
808
0
)
print
(
'Final model:'
)
print
(
'Final model:'
)
for
model_code
in
exp
.
export_top_models
(
formatter
=
export_formatter
):
for
model_code
in
exp
.
export_top_models
(
formatter
=
export_formatter
):
print
(
model_code
)
print
(
model_code
)
nni/common/version.py
View file @
a31d37e5
# Copyright (c) Microsoft Corporation.
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# Licensed under the MIT license.
from
__future__
import
annotations
import
logging
import
logging
import
sys
import
warnings
import
cloudpickle
import
json_tricks
import
numpy
import
yaml
import
nni
def
_minor_version_tuple
(
version_str
:
str
)
->
tuple
[
int
,
int
]:
# If not a number, returns -1 (e.g., 999.dev0 -> (999, -1))
return
tuple
(
int
(
x
)
if
x
.
isdigit
()
else
-
1
for
x
in
version_str
.
split
(
"."
)[:
2
])
PYTHON_VERSION
=
sys
.
version_info
[:
2
]
NUMPY_VERSION
=
_minor_version_tuple
(
numpy
.
__version__
)
try
:
try
:
import
torch
import
torch
TORCH_VERSION
=
tuple
(
int
(
x
)
for
x
in
torch
.
__version__
.
split
(
"."
)[:
2
]
)
TORCH_VERSION
=
_minor_version_tuple
(
torch
.
__version__
)
except
Exception
:
except
ImportError
:
logging
.
info
(
"PyTorch is not installed."
)
logging
.
info
(
"PyTorch is not installed."
)
TORCH_VERSION
=
None
TORCH_VERSION
=
None
try
:
import
pytorch_lightning
PYTORCH_LIGHTNING_VERSION
=
_minor_version_tuple
(
pytorch_lightning
.
__version__
)
except
ImportError
:
logging
.
info
(
"PyTorch Lightning is not installed."
)
PYTORCH_LIGHTNING_VERSION
=
None
try
:
import
tensorflow
TENSORFLOW_VERSION
=
_minor_version_tuple
(
tensorflow
.
__version__
)
except
ImportError
:
logging
.
info
(
"Tensorflow is not installed."
)
TENSORFLOW_VERSION
=
None
# Serialization version check are needed because they are prone to be inconsistent between versions
CLOUDPICKLE_VERSION
=
_minor_version_tuple
(
cloudpickle
.
__version__
)
JSON_TRICKS_VERSION
=
_minor_version_tuple
(
json_tricks
.
__version__
)
PYYAML_VERSION
=
_minor_version_tuple
(
yaml
.
__version__
)
NNI_VERSION
=
_minor_version_tuple
(
nni
.
__version__
)
def
version_dump
()
->
dict
[
str
,
tuple
[
int
,
int
]
|
None
]:
return
{
'python'
:
PYTHON_VERSION
,
'numpy'
:
NUMPY_VERSION
,
'torch'
:
TORCH_VERSION
,
'pytorch_lightning'
:
PYTORCH_LIGHTNING_VERSION
,
'tensorflow'
:
TENSORFLOW_VERSION
,
'cloudpickle'
:
CLOUDPICKLE_VERSION
,
'json_tricks'
:
JSON_TRICKS_VERSION
,
'pyyaml'
:
PYYAML_VERSION
,
'nni'
:
NNI_VERSION
}
def
version_check
(
expect
:
dict
,
raise_error
:
bool
=
False
)
->
None
:
current_ver
=
version_dump
()
for
package
in
expect
:
# version could be list due to serialization
exp_version
:
tuple
|
None
=
tuple
(
expect
[
package
])
if
expect
[
package
]
else
None
if
exp_version
is
None
:
continue
err_message
:
str
|
None
=
None
if
package
not
in
current_ver
:
err_message
=
f
'
{
package
}
is missing in current environment'
elif
current_ver
[
package
]
!=
exp_version
:
err_message
=
f
'Expect
{
package
}
to have version
{
exp_version
}
, but
{
current_ver
[
package
]
}
found'
if
err_message
:
if
raise_error
:
raise
RuntimeError
(
'Version check failed: '
+
err_message
)
else
:
warnings
.
warn
(
'Version check with warning: '
+
err_message
)
nni/retiarii/integration.py
View file @
a31d37e5
...
@@ -7,6 +7,7 @@ from typing import Any, Callable
...
@@ -7,6 +7,7 @@ from typing import Any, Callable
import
nni
import
nni
from
nni.common.serializer
import
PayloadTooLarge
from
nni.common.serializer
import
PayloadTooLarge
from
nni.common.version
import
version_dump
from
nni.runtime.msg_dispatcher_base
import
MsgDispatcherBase
from
nni.runtime.msg_dispatcher_base
import
MsgDispatcherBase
from
nni.runtime.protocol
import
CommandType
,
send
from
nni.runtime.protocol
import
CommandType
,
send
from
nni.utils
import
MetricType
from
nni.utils
import
MetricType
...
@@ -120,7 +121,8 @@ class RetiariiAdvisor(MsgDispatcherBase):
...
@@ -120,7 +121,8 @@ class RetiariiAdvisor(MsgDispatcherBase):
'parameter_id'
:
self
.
parameters_count
,
'parameter_id'
:
self
.
parameters_count
,
'parameters'
:
parameters
,
'parameters'
:
parameters
,
'parameter_source'
:
'algorithm'
,
'parameter_source'
:
'algorithm'
,
'placement_constraint'
:
placement_constraint
'placement_constraint'
:
placement_constraint
,
'version_info'
:
version_dump
()
}
}
_logger
.
debug
(
'New trial sent: %s'
,
new_trial
)
_logger
.
debug
(
'New trial sent: %s'
,
new_trial
)
...
...
nni/retiarii/integration_api.py
View file @
a31d37e5
# Copyright (c) Microsoft Corporation.
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# Licensed under the MIT license.
import
warnings
from
typing
import
NewType
,
Any
from
typing
import
NewType
,
Any
import
nni
import
nni
from
nni.common.version
import
version_check
# NOTE: this is only for passing flake8, we cannot import RetiariiAdvisor
# NOTE: this is only for passing flake8, we cannot import RetiariiAdvisor
# because it would induce cycled import
# because it would induce cycled import
...
@@ -37,6 +39,14 @@ def receive_trial_parameters() -> dict:
...
@@ -37,6 +39,14 @@ def receive_trial_parameters() -> dict:
Reload with our json loads because NNI didn't use Retiarii serializer to load the data.
Reload with our json loads because NNI didn't use Retiarii serializer to load the data.
"""
"""
params
=
nni
.
get_next_parameter
()
params
=
nni
.
get_next_parameter
()
# version check, optional
raw_params
=
nni
.
trial
.
_params
if
raw_params
is
not
None
and
'version_info'
in
raw_params
:
version_check
(
raw_params
[
'version_info'
])
else
:
warnings
.
warn
(
'Version check failed because `version_info` is not found.'
)
return
params
return
params
...
...
test/ut/sdk/test_version.py
0 → 100644
View file @
a31d37e5
import
pytest
import
sys
from
nni.common.version
import
version_dump
,
version_check
def
test_version_dump
():
dump_ver
=
version_dump
()
assert
len
(
dump_ver
)
>=
9
print
(
dump_ver
)
def
test_version_check
():
version_check
(
version_dump
(),
raise_error
=
True
)
version_check
({
'python'
:
sys
.
version_info
[:
2
]},
raise_error
=
True
)
with
pytest
.
warns
(
UserWarning
):
version_check
({
'nni'
:
(
99999
,
99999
)})
with
pytest
.
raises
(
RuntimeError
):
version_check
({
'python'
:
(
2
,
7
)},
raise_error
=
True
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment