Unverified Commit 18962129 authored by Yuge Zhang's avatar Yuge Zhang Committed by GitHub
Browse files

Add license header and typehints for NAS (#4774)

parent 8c2f717d
...@@ -19,5 +19,5 @@ scikit-learn >= 0.24.1 ...@@ -19,5 +19,5 @@ scikit-learn >= 0.24.1
scipy < 1.8 ; python_version < "3.8" scipy < 1.8 ; python_version < "3.8"
scipy ; python_version >= "3.8" scipy ; python_version >= "3.8"
typeguard typeguard
typing_extensions >= 4.0.0 ; python_version < "3.8" typing_extensions >= 4.0.0
websockets >= 10.1 websockets >= 10.1
...@@ -13,7 +13,7 @@ import sys ...@@ -13,7 +13,7 @@ import sys
import types import types
import warnings import warnings
from io import IOBase from io import IOBase
from typing import Any, Dict, List, Optional, TypeVar, Union from typing import Any, Dict, List, Optional, Type, TypeVar, Union, cast, Generic
import cloudpickle # use cloudpickle as backend for unserializable types and instances import cloudpickle # use cloudpickle as backend for unserializable types and instances
import json_tricks # use json_tricks as serializer backend import json_tricks # use json_tricks as serializer backend
...@@ -115,7 +115,7 @@ def is_wrapped_with_trace(cls_or_func: Any) -> bool: ...@@ -115,7 +115,7 @@ def is_wrapped_with_trace(cls_or_func: Any) -> bool:
) )
class SerializableObject(Traceable): class SerializableObject(Generic[T], Traceable):
""" """
Serializable object is a wrapper of existing python objects, that supports dump and load easily. Serializable object is a wrapper of existing python objects, that supports dump and load easily.
Stores a symbol ``s`` and a dict of arguments ``args``, and the object can be restored with ``s(**args)``. Stores a symbol ``s`` and a dict of arguments ``args``, and the object can be restored with ``s(**args)``.
...@@ -147,7 +147,7 @@ class SerializableObject(Traceable): ...@@ -147,7 +147,7 @@ class SerializableObject(Traceable):
# Reinitialize # Reinitialize
return trace(self.trace_symbol)(*self.trace_args, **self.trace_kwargs) return trace(self.trace_symbol)(*self.trace_args, **self.trace_kwargs)
return self return cast(T, self)
@property @property
def trace_symbol(self) -> Any: def trace_symbol(self) -> Any:
...@@ -187,7 +187,7 @@ class SerializableObject(Traceable): ...@@ -187,7 +187,7 @@ class SerializableObject(Traceable):
')' ')'
def inject_trace_info(obj: Any, symbol: T, args: List[Any], kwargs: Dict[str, Any]) -> Any: def inject_trace_info(obj: Any, symbol: T, args: List[Any], kwargs: Dict[str, Any]) -> T:
# If an object is already created, this can be a fix so that the necessary info are re-injected into the object. # If an object is already created, this can be a fix so that the necessary info are re-injected into the object.
# Make obj complying with the interface of traceable, though we cannot change its base class. # Make obj complying with the interface of traceable, though we cannot change its base class.
obj.__dict__.update(_nni_symbol=symbol, _nni_args=args, _nni_kwargs=kwargs) obj.__dict__.update(_nni_symbol=symbol, _nni_args=args, _nni_kwargs=kwargs)
...@@ -233,11 +233,11 @@ def _make_class_traceable(cls: T, create_wrapper: bool = False) -> T: ...@@ -233,11 +233,11 @@ def _make_class_traceable(cls: T, create_wrapper: bool = False) -> T:
else: else:
# sometimes create_wrapper is mandatory, e.g., for built-in types like list/int. # sometimes create_wrapper is mandatory, e.g., for built-in types like list/int.
# but I don't want to check here because it's unreliable. # but I don't want to check here because it's unreliable.
wrapper = type('wrapper', (Traceable, cls), attributes) wrapper = type('wrapper', (Traceable, cast(Type, cls)), attributes)
return wrapper return cast(T, wrapper)
def trace(cls_or_func: T = None, *, kw_only: bool = True, inheritable: bool = False) -> Union[T, Traceable]: def trace(cls_or_func: T = cast(T, None), *, kw_only: bool = True, inheritable: bool = False) -> T:
""" """
Annotate a function or a class if you want to preserve where it comes from. Annotate a function or a class if you want to preserve where it comes from.
This is usually used in the following scenarios: This is usually used in the following scenarios:
...@@ -283,7 +283,7 @@ def trace(cls_or_func: T = None, *, kw_only: bool = True, inheritable: bool = Fa ...@@ -283,7 +283,7 @@ def trace(cls_or_func: T = None, *, kw_only: bool = True, inheritable: bool = Fa
# Might be changed in future. # Might be changed in future.
nni_trace_flag = os.environ.get('NNI_TRACE_FLAG', '') nni_trace_flag = os.environ.get('NNI_TRACE_FLAG', '')
if nni_trace_flag.lower() == 'disable': if nni_trace_flag.lower() == 'disable':
return cls_or_func return cast(T, cls_or_func)
def wrap(cls_or_func): def wrap(cls_or_func):
# already annotated, do nothing # already annotated, do nothing
...@@ -301,20 +301,22 @@ def trace(cls_or_func: T = None, *, kw_only: bool = True, inheritable: bool = Fa ...@@ -301,20 +301,22 @@ def trace(cls_or_func: T = None, *, kw_only: bool = True, inheritable: bool = Fa
# if we're being called as @trace() # if we're being called as @trace()
if cls_or_func is None: if cls_or_func is None:
return wrap return wrap # type: ignore
# if we are called without parentheses # if we are called without parentheses
return wrap(cls_or_func) return wrap(cls_or_func) # type: ignore
def dump(obj: Any, fp: Optional[Any] = None, *, use_trace: bool = True, pickle_size_limit: int = 4096, def dump(obj: Any, fp: Optional[Any] = None, *, use_trace: bool = True, pickle_size_limit: int = 4096,
allow_nan: bool = True, **json_tricks_kwargs) -> Union[str, bytes]: allow_nan: bool = True, **json_tricks_kwargs) -> str:
""" """
Convert a nested data structure to a json string. Save to file if fp is specified. Convert a nested data structure to a json string. Save to file if fp is specified.
Use json-tricks as main backend. For unhandled cases in json-tricks, use cloudpickle. Use json-tricks as main backend. For unhandled cases in json-tricks, use cloudpickle.
The serializer is not designed for long-term storage use, but rather to copy data between processes. The serializer is not designed for long-term storage use, but rather to copy data between processes.
The format is also subject to change between NNI releases. The format is also subject to change between NNI releases.
To compress the payload, please use :func:`dump_bytes`.
Parameters Parameters
---------- ----------
obj : any obj : any
...@@ -334,6 +336,39 @@ def dump(obj: Any, fp: Optional[Any] = None, *, use_trace: bool = True, pickle_s ...@@ -334,6 +336,39 @@ def dump(obj: Any, fp: Optional[Any] = None, *, use_trace: bool = True, pickle_s
Normally str. Sometimes bytes (if compressed). Normally str. Sometimes bytes (if compressed).
""" """
if json_tricks_kwargs.get('compression') is not None:
raise ValueError('If you meant to compress the dumped payload, please use `dump_bytes`.')
result = _dump(
obj=obj,
fp=fp,
use_trace=use_trace,
pickle_size_limit=pickle_size_limit,
allow_nan=allow_nan,
**json_tricks_kwargs)
return cast(str, result)
def dump_bytes(obj: Any, fp: Optional[Any] = None, *, compression: int = cast(int, None),
use_trace: bool = True, pickle_size_limit: int = 4096,
allow_nan: bool = True, **json_tricks_kwargs) -> bytes:
"""
Same as :func:`dump`, but to comporess payload, with `compression <https://json-tricks.readthedocs.io/en/stable/#dump>`__.
"""
if compression is None:
raise ValueError('compression must be set.')
result = _dump(
obj=obj,
fp=fp,
compression=compression,
use_trace=use_trace,
pickle_size_limit=pickle_size_limit,
allow_nan=allow_nan,
**json_tricks_kwargs)
return cast(bytes, result)
def _dump(*, obj: Any, fp: Optional[Any], use_trace: bool, pickle_size_limit: int,
allow_nan: bool, **json_tricks_kwargs) -> Union[str, bytes]:
encoders = [ encoders = [
# we don't need to check for dependency as many of those have already been required by NNI # we don't need to check for dependency as many of those have already been required by NNI
json_tricks.pathlib_encode, # pathlib is a required dependency for NNI json_tricks.pathlib_encode, # pathlib is a required dependency for NNI
...@@ -456,7 +491,7 @@ def _trace_cls(base, kw_only, call_super=True, inheritable=False): ...@@ -456,7 +491,7 @@ def _trace_cls(base, kw_only, call_super=True, inheritable=False):
raise TypeError(f"{base} has a superclass already decorated with trace, and it's using a customized metaclass {type(base)}. " raise TypeError(f"{base} has a superclass already decorated with trace, and it's using a customized metaclass {type(base)}. "
"Please either use the default metaclass, or remove trace from the super-class.") "Please either use the default metaclass, or remove trace from the super-class.")
class wrapper(SerializableObject, base, metaclass=metaclass): class wrapper(SerializableObject, base, metaclass=metaclass): # type: ignore
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
# store a copy of initial parameters # store a copy of initial parameters
args, kwargs = _formulate_arguments(base.__init__, args, kwargs, kw_only, is_class_init=True) args, kwargs = _formulate_arguments(base.__init__, args, kwargs, kw_only, is_class_init=True)
...@@ -528,7 +563,8 @@ def _trace_func(func, kw_only): ...@@ -528,7 +563,8 @@ def _trace_func(func, kw_only):
# and thus not possible to restore the trace parameters after dump and reload. # and thus not possible to restore the trace parameters after dump and reload.
# this is a known limitation. # this is a known limitation.
new_type = _make_class_traceable(type(res), True) new_type = _make_class_traceable(type(res), True)
res = new_type(res) # re-creating the object # re-creating the object
res = new_type(res) # type: ignore
res = inject_trace_info(res, func, args, kwargs) res = inject_trace_info(res, func, args, kwargs)
else: else:
raise TypeError(f'Try to add trace info to {res}, but the type "{type(res)}" is unknown. ' raise TypeError(f'Try to add trace info to {res}, but the type "{type(res)}" is unknown. '
...@@ -750,7 +786,7 @@ def import_cls_or_func_from_hybrid_name(s: str) -> Any: ...@@ -750,7 +786,7 @@ def import_cls_or_func_from_hybrid_name(s: str) -> Any:
return _import_cls_or_func_from_name(s) return _import_cls_or_func_from_name(s)
def _json_tricks_func_or_cls_encode(cls_or_func: Any, primitives: bool = False, pickle_size_limit: int = 4096) -> str: def _json_tricks_func_or_cls_encode(cls_or_func: Any, primitives: bool = False, pickle_size_limit: int = 4096) -> Dict[str, str]:
if not isinstance(cls_or_func, type) and not _is_function(cls_or_func): if not isinstance(cls_or_func, type) and not _is_function(cls_or_func):
# not a function or class, continue # not a function or class, continue
return cls_or_func return cls_or_func
...@@ -762,8 +798,7 @@ def _json_tricks_func_or_cls_encode(cls_or_func: Any, primitives: bool = False, ...@@ -762,8 +798,7 @@ def _json_tricks_func_or_cls_encode(cls_or_func: Any, primitives: bool = False,
def _json_tricks_func_or_cls_decode(s: Dict[str, Any]) -> Any: def _json_tricks_func_or_cls_decode(s: Dict[str, Any]) -> Any:
if isinstance(s, dict) and '__nni_type__' in s: if isinstance(s, dict) and '__nni_type__' in s:
s = s['__nni_type__'] return import_cls_or_func_from_hybrid_name(s['__nni_type__'])
return import_cls_or_func_from_hybrid_name(s)
return s return s
...@@ -815,8 +850,7 @@ def _json_tricks_any_object_encode(obj: Any, primitives: bool = False, pickle_si ...@@ -815,8 +850,7 @@ def _json_tricks_any_object_encode(obj: Any, primitives: bool = False, pickle_si
def _json_tricks_any_object_decode(obj: Dict[str, Any]) -> Any: def _json_tricks_any_object_decode(obj: Dict[str, Any]) -> Any:
if isinstance(obj, dict) and '__nni_obj__' in obj: if isinstance(obj, dict) and '__nni_obj__' in obj:
obj = obj['__nni_obj__'] b = base64.b64decode(obj['__nni_obj__'])
b = base64.b64decode(obj)
return _wrapped_cloudpickle_loads(b) return _wrapped_cloudpickle_loads(b)
return obj return obj
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .utils import load_benchmark, download_benchmark from .utils import load_benchmark, download_benchmark
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import os import os
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import argparse import argparse
if __name__ == '__main__': if __name__ == '__main__':
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .constants import INPUT, OUTPUT, CONV3X3_BN_RELU, CONV1X1_BN_RELU, MAXPOOL3X3 from .constants import INPUT, OUTPUT, CONV3X3_BN_RELU, CONV1X1_BN_RELU, MAXPOOL3X3
from .model import Nb101TrialStats, Nb101IntermediateStats, Nb101TrialConfig from .model import Nb101TrialStats, Nb101IntermediateStats, Nb101TrialConfig
from .query import query_nb101_trial_stats from .query import query_nb101_trial_stats
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
INPUT = 'input' INPUT = 'input'
OUTPUT = 'output' OUTPUT = 'output'
CONV3X3_BN_RELU = 'conv3x3-bn-relu' CONV3X3_BN_RELU = 'conv3x3-bn-relu'
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import argparse import argparse
from tqdm import tqdm from tqdm import tqdm
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import hashlib import hashlib
import numpy as np import numpy as np
...@@ -10,7 +13,7 @@ def _labeling_from_architecture(architecture, vertices): ...@@ -10,7 +13,7 @@ def _labeling_from_architecture(architecture, vertices):
def _adjancency_matrix_from_architecture(architecture, vertices): def _adjancency_matrix_from_architecture(architecture, vertices):
matrix = np.zeros((vertices, vertices), dtype=np.bool) matrix = np.zeros((vertices, vertices), dtype=np.bool) # type: ignore
for i in range(1, vertices): for i in range(1, vertices):
for k in architecture['input{}'.format(i)]: for k in architecture['input{}'.format(i)]:
matrix[k, i] = 1 matrix[k, i] = 1
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from peewee import CharField, FloatField, ForeignKeyField, IntegerField, Model, Proxy from peewee import CharField, FloatField, ForeignKeyField, IntegerField, Model, Proxy
from playhouse.sqlite_ext import JSONField from playhouse.sqlite_ext import JSONField
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import functools import functools
from peewee import fn from peewee import fn
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .constants import NONE, SKIP_CONNECT, CONV_1X1, CONV_3X3, AVG_POOL_3X3 from .constants import NONE, SKIP_CONNECT, CONV_1X1, CONV_3X3, AVG_POOL_3X3
from .model import Nb201TrialStats, Nb201IntermediateStats, Nb201TrialConfig from .model import Nb201TrialStats, Nb201IntermediateStats, Nb201TrialConfig
from .query import query_nb201_trial_stats from .query import query_nb201_trial_stats
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
NONE = 'none' NONE = 'none'
SKIP_CONNECT = 'skip_connect' SKIP_CONNECT = 'skip_connect'
CONV_1X1 = 'conv_1x1' CONV_1X1 = 'conv_1x1'
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import argparse import argparse
import re import re
...@@ -17,7 +20,7 @@ def parse_arch_str(arch_str): ...@@ -17,7 +20,7 @@ def parse_arch_str(arch_str):
'nor_conv_3x3': CONV_3X3, 'nor_conv_3x3': CONV_3X3,
'avg_pool_3x3': AVG_POOL_3X3 'avg_pool_3x3': AVG_POOL_3X3
} }
m = re.match(r'\|(.*)~0\|\+\|(.*)~0\|(.*)~1\|\+\|(.*)~0\|(.*)~1\|(.*)~2\|', arch_str) m: re.Match = re.match(r'\|(.*)~0\|\+\|(.*)~0\|(.*)~1\|\+\|(.*)~0\|(.*)~1\|(.*)~2\|', arch_str) # type: ignore
return { return {
'0_1': mp[m.group(1)], '0_1': mp[m.group(1)],
'0_2': mp[m.group(2)], '0_2': mp[m.group(2)],
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from peewee import CharField, FloatField, ForeignKeyField, IntegerField, Model, Proxy from peewee import CharField, FloatField, ForeignKeyField, IntegerField, Model, Proxy
from playhouse.sqlite_ext import JSONField from playhouse.sqlite_ext import JSONField
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import functools import functools
from peewee import fn from peewee import fn
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .constants import * from .constants import *
from .model import NdsTrialConfig, NdsTrialStats, NdsIntermediateStats from .model import NdsTrialConfig, NdsTrialStats, NdsIntermediateStats
from .query import query_nds_trial_stats from .query import query_nds_trial_stats
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
NONE = 'none' NONE = 'none'
SKIP_CONNECT = 'skip_connect' SKIP_CONNECT = 'skip_connect'
AVG_POOL_3X3 = 'avg_pool_3x3' AVG_POOL_3X3 = 'avg_pool_3x3'
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import json import json
import argparse import argparse
import os import os
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from peewee import CharField, FloatField, ForeignKeyField, IntegerField, Model, Proxy from peewee import CharField, FloatField, ForeignKeyField, IntegerField, Model, Proxy
from playhouse.sqlite_ext import JSONField from playhouse.sqlite_ext import JSONField
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment