"tests/git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "537d37c26c619a9aea48bcef44c1d8f45d6d7b1a"
Unverified Commit a31d37e5 authored by Yuge Zhang's avatar Yuge Zhang Committed by GitHub
Browse files

Enable trial version check in Retiarii (#4738)

parent a801b5f8
...@@ -139,7 +139,7 @@ if __name__ == '__main__': ...@@ -139,7 +139,7 @@ if __name__ == '__main__':
# exp_config.execution_engine = 'base' # exp_config.execution_engine = 'base'
# export_formatter = 'code' # export_formatter = 'code'
exp.run(exp_config, 8081 + random.randint(0, 100)) exp.run(exp_config, 8080)
print('Final model:') print('Final model:')
for model_code in exp.export_top_models(formatter=export_formatter): for model_code in exp.export_top_models(formatter=export_formatter):
print(model_code) print(model_code)
# Copyright (c) Microsoft Corporation. # Copyright (c) Microsoft Corporation.
# Licensed under the MIT license. # Licensed under the MIT license.
from __future__ import annotations
import logging import logging
import sys
import warnings
import cloudpickle
import json_tricks
import numpy
import yaml
import nni
def _minor_version_tuple(version_str: str) -> tuple[int, int]:
# If not a number, returns -1 (e.g., 999.dev0 -> (999, -1))
return tuple(int(x) if x.isdigit() else -1 for x in version_str.split(".")[:2])
PYTHON_VERSION = sys.version_info[:2]
NUMPY_VERSION = _minor_version_tuple(numpy.__version__)
try: try:
import torch import torch
TORCH_VERSION = tuple(int(x) for x in torch.__version__.split(".")[:2]) TORCH_VERSION = _minor_version_tuple(torch.__version__)
except Exception: except ImportError:
logging.info("PyTorch is not installed.") logging.info("PyTorch is not installed.")
TORCH_VERSION = None TORCH_VERSION = None
try:
import pytorch_lightning
PYTORCH_LIGHTNING_VERSION = _minor_version_tuple(pytorch_lightning.__version__)
except ImportError:
logging.info("PyTorch Lightning is not installed.")
PYTORCH_LIGHTNING_VERSION = None
try:
import tensorflow
TENSORFLOW_VERSION = _minor_version_tuple(tensorflow.__version__)
except ImportError:
logging.info("Tensorflow is not installed.")
TENSORFLOW_VERSION = None
# Serialization version check are needed because they are prone to be inconsistent between versions
CLOUDPICKLE_VERSION = _minor_version_tuple(cloudpickle.__version__)
JSON_TRICKS_VERSION = _minor_version_tuple(json_tricks.__version__)
PYYAML_VERSION = _minor_version_tuple(yaml.__version__)
NNI_VERSION = _minor_version_tuple(nni.__version__)
def version_dump() -> dict[str, tuple[int, int] | None]:
return {
'python': PYTHON_VERSION,
'numpy': NUMPY_VERSION,
'torch': TORCH_VERSION,
'pytorch_lightning': PYTORCH_LIGHTNING_VERSION,
'tensorflow': TENSORFLOW_VERSION,
'cloudpickle': CLOUDPICKLE_VERSION,
'json_tricks': JSON_TRICKS_VERSION,
'pyyaml': PYYAML_VERSION,
'nni': NNI_VERSION
}
def version_check(expect: dict, raise_error: bool = False) -> None:
current_ver = version_dump()
for package in expect:
# version could be list due to serialization
exp_version: tuple | None = tuple(expect[package]) if expect[package] else None
if exp_version is None:
continue
err_message: str | None = None
if package not in current_ver:
err_message = f'{package} is missing in current environment'
elif current_ver[package] != exp_version:
err_message = f'Expect {package} to have version {exp_version}, but {current_ver[package]} found'
if err_message:
if raise_error:
raise RuntimeError('Version check failed: ' + err_message)
else:
warnings.warn('Version check with warning: ' + err_message)
...@@ -7,6 +7,7 @@ from typing import Any, Callable ...@@ -7,6 +7,7 @@ from typing import Any, Callable
import nni import nni
from nni.common.serializer import PayloadTooLarge from nni.common.serializer import PayloadTooLarge
from nni.common.version import version_dump
from nni.runtime.msg_dispatcher_base import MsgDispatcherBase from nni.runtime.msg_dispatcher_base import MsgDispatcherBase
from nni.runtime.protocol import CommandType, send from nni.runtime.protocol import CommandType, send
from nni.utils import MetricType from nni.utils import MetricType
...@@ -120,7 +121,8 @@ class RetiariiAdvisor(MsgDispatcherBase): ...@@ -120,7 +121,8 @@ class RetiariiAdvisor(MsgDispatcherBase):
'parameter_id': self.parameters_count, 'parameter_id': self.parameters_count,
'parameters': parameters, 'parameters': parameters,
'parameter_source': 'algorithm', 'parameter_source': 'algorithm',
'placement_constraint': placement_constraint 'placement_constraint': placement_constraint,
'version_info': version_dump()
} }
_logger.debug('New trial sent: %s', new_trial) _logger.debug('New trial sent: %s', new_trial)
......
# Copyright (c) Microsoft Corporation. # Copyright (c) Microsoft Corporation.
# Licensed under the MIT license. # Licensed under the MIT license.
import warnings
from typing import NewType, Any from typing import NewType, Any
import nni import nni
from nni.common.version import version_check
# NOTE: this is only for passing flake8, we cannot import RetiariiAdvisor # NOTE: this is only for passing flake8, we cannot import RetiariiAdvisor
# because it would induce cycled import # because it would induce cycled import
...@@ -37,6 +39,14 @@ def receive_trial_parameters() -> dict: ...@@ -37,6 +39,14 @@ def receive_trial_parameters() -> dict:
Reload with our json loads because NNI didn't use Retiarii serializer to load the data. Reload with our json loads because NNI didn't use Retiarii serializer to load the data.
""" """
params = nni.get_next_parameter() params = nni.get_next_parameter()
# version check, optional
raw_params = nni.trial._params
if raw_params is not None and 'version_info' in raw_params:
version_check(raw_params['version_info'])
else:
warnings.warn('Version check failed because `version_info` is not found.')
return params return params
......
import pytest
import sys
from nni.common.version import version_dump, version_check
def test_version_dump():
dump_ver = version_dump()
assert len(dump_ver) >= 9
print(dump_ver)
def test_version_check():
version_check(version_dump(), raise_error=True)
version_check({'python': sys.version_info[:2]}, raise_error=True)
with pytest.warns(UserWarning):
version_check({'nni': (99999, 99999)})
with pytest.raises(RuntimeError):
version_check({'python': (2, 7)}, raise_error=True)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment