"test/gemm/gemm_dl_fp32.cpp" did not exist on "3ba149328f2704e096b2eed7ffeacff0b54fdc8b"
Unverified Commit a31d37e5 authored by Yuge Zhang's avatar Yuge Zhang Committed by GitHub
Browse files

Enable trial version check in Retiarii (#4738)

parent a801b5f8
......@@ -139,7 +139,7 @@ if __name__ == '__main__':
# exp_config.execution_engine = 'base'
# export_formatter = 'code'
exp.run(exp_config, 8081 + random.randint(0, 100))
exp.run(exp_config, 8080)
print('Final model:')
for model_code in exp.export_top_models(formatter=export_formatter):
print(model_code)
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from __future__ import annotations
import logging
import sys
import warnings
import cloudpickle
import json_tricks
import numpy
import yaml
import nni
def _minor_version_tuple(version_str: str) -> tuple[int, int]:
# If not a number, returns -1 (e.g., 999.dev0 -> (999, -1))
return tuple(int(x) if x.isdigit() else -1 for x in version_str.split(".")[:2])
PYTHON_VERSION = sys.version_info[:2]
NUMPY_VERSION = _minor_version_tuple(numpy.__version__)
try:
import torch
TORCH_VERSION = tuple(int(x) for x in torch.__version__.split(".")[:2])
except Exception:
TORCH_VERSION = _minor_version_tuple(torch.__version__)
except ImportError:
logging.info("PyTorch is not installed.")
TORCH_VERSION = None
try:
import pytorch_lightning
PYTORCH_LIGHTNING_VERSION = _minor_version_tuple(pytorch_lightning.__version__)
except ImportError:
logging.info("PyTorch Lightning is not installed.")
PYTORCH_LIGHTNING_VERSION = None
try:
import tensorflow
TENSORFLOW_VERSION = _minor_version_tuple(tensorflow.__version__)
except ImportError:
logging.info("Tensorflow is not installed.")
TENSORFLOW_VERSION = None
# Serialization version check are needed because they are prone to be inconsistent between versions
CLOUDPICKLE_VERSION = _minor_version_tuple(cloudpickle.__version__)
JSON_TRICKS_VERSION = _minor_version_tuple(json_tricks.__version__)
PYYAML_VERSION = _minor_version_tuple(yaml.__version__)
NNI_VERSION = _minor_version_tuple(nni.__version__)
def version_dump() -> dict[str, tuple[int, int] | None]:
return {
'python': PYTHON_VERSION,
'numpy': NUMPY_VERSION,
'torch': TORCH_VERSION,
'pytorch_lightning': PYTORCH_LIGHTNING_VERSION,
'tensorflow': TENSORFLOW_VERSION,
'cloudpickle': CLOUDPICKLE_VERSION,
'json_tricks': JSON_TRICKS_VERSION,
'pyyaml': PYYAML_VERSION,
'nni': NNI_VERSION
}
def version_check(expect: dict, raise_error: bool = False) -> None:
current_ver = version_dump()
for package in expect:
# version could be list due to serialization
exp_version: tuple | None = tuple(expect[package]) if expect[package] else None
if exp_version is None:
continue
err_message: str | None = None
if package not in current_ver:
err_message = f'{package} is missing in current environment'
elif current_ver[package] != exp_version:
err_message = f'Expect {package} to have version {exp_version}, but {current_ver[package]} found'
if err_message:
if raise_error:
raise RuntimeError('Version check failed: ' + err_message)
else:
warnings.warn('Version check with warning: ' + err_message)
......@@ -7,6 +7,7 @@ from typing import Any, Callable
import nni
from nni.common.serializer import PayloadTooLarge
from nni.common.version import version_dump
from nni.runtime.msg_dispatcher_base import MsgDispatcherBase
from nni.runtime.protocol import CommandType, send
from nni.utils import MetricType
......@@ -120,7 +121,8 @@ class RetiariiAdvisor(MsgDispatcherBase):
'parameter_id': self.parameters_count,
'parameters': parameters,
'parameter_source': 'algorithm',
'placement_constraint': placement_constraint
'placement_constraint': placement_constraint,
'version_info': version_dump()
}
_logger.debug('New trial sent: %s', new_trial)
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import warnings
from typing import NewType, Any
import nni
from nni.common.version import version_check
# NOTE: this is only for passing flake8, we cannot import RetiariiAdvisor
# because it would induce cycled import
......@@ -37,6 +39,14 @@ def receive_trial_parameters() -> dict:
Reload with our json loads because NNI didn't use Retiarii serializer to load the data.
"""
params = nni.get_next_parameter()
# version check, optional
raw_params = nni.trial._params
if raw_params is not None and 'version_info' in raw_params:
version_check(raw_params['version_info'])
else:
warnings.warn('Version check failed because `version_info` is not found.')
return params
......
import pytest
import sys
from nni.common.version import version_dump, version_check
def test_version_dump():
dump_ver = version_dump()
assert len(dump_ver) >= 9
print(dump_ver)
def test_version_check():
version_check(version_dump(), raise_error=True)
version_check({'python': sys.version_info[:2]}, raise_error=True)
with pytest.warns(UserWarning):
version_check({'nni': (99999, 99999)})
with pytest.raises(RuntimeError):
version_check({'python': (2, 7)}, raise_error=True)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment