Unverified Commit 18962129 authored by Yuge Zhang's avatar Yuge Zhang Committed by GitHub
Browse files

Add license header and typehints for NAS (#4774)

parent 8c2f717d
......@@ -19,5 +19,5 @@ scikit-learn >= 0.24.1
scipy < 1.8 ; python_version < "3.8"
scipy ; python_version >= "3.8"
typeguard
typing_extensions >= 4.0.0 ; python_version < "3.8"
typing_extensions >= 4.0.0
websockets >= 10.1
......@@ -13,7 +13,7 @@ import sys
import types
import warnings
from io import IOBase
from typing import Any, Dict, List, Optional, TypeVar, Union
from typing import Any, Dict, List, Optional, Type, TypeVar, Union, cast, Generic
import cloudpickle # use cloudpickle as backend for unserializable types and instances
import json_tricks # use json_tricks as serializer backend
......@@ -115,7 +115,7 @@ def is_wrapped_with_trace(cls_or_func: Any) -> bool:
)
class SerializableObject(Traceable):
class SerializableObject(Generic[T], Traceable):
"""
Serializable object is a wrapper of existing python objects, that supports dump and load easily.
Stores a symbol ``s`` and a dict of arguments ``args``, and the object can be restored with ``s(**args)``.
......@@ -147,7 +147,7 @@ class SerializableObject(Traceable):
# Reinitialize
return trace(self.trace_symbol)(*self.trace_args, **self.trace_kwargs)
return self
return cast(T, self)
@property
def trace_symbol(self) -> Any:
......@@ -187,7 +187,7 @@ class SerializableObject(Traceable):
')'
def inject_trace_info(obj: Any, symbol: T, args: List[Any], kwargs: Dict[str, Any]) -> Any:
def inject_trace_info(obj: Any, symbol: T, args: List[Any], kwargs: Dict[str, Any]) -> T:
# If an object is already created, this can be a fix so that the necessary info are re-injected into the object.
# Make obj complying with the interface of traceable, though we cannot change its base class.
obj.__dict__.update(_nni_symbol=symbol, _nni_args=args, _nni_kwargs=kwargs)
......@@ -233,11 +233,11 @@ def _make_class_traceable(cls: T, create_wrapper: bool = False) -> T:
else:
# sometimes create_wrapper is mandatory, e.g., for built-in types like list/int.
# but I don't want to check here because it's unreliable.
wrapper = type('wrapper', (Traceable, cls), attributes)
return wrapper
wrapper = type('wrapper', (Traceable, cast(Type, cls)), attributes)
return cast(T, wrapper)
def trace(cls_or_func: T = None, *, kw_only: bool = True, inheritable: bool = False) -> Union[T, Traceable]:
def trace(cls_or_func: T = cast(T, None), *, kw_only: bool = True, inheritable: bool = False) -> T:
"""
Annotate a function or a class if you want to preserve where it comes from.
This is usually used in the following scenarios:
......@@ -283,7 +283,7 @@ def trace(cls_or_func: T = None, *, kw_only: bool = True, inheritable: bool = Fa
# Might be changed in future.
nni_trace_flag = os.environ.get('NNI_TRACE_FLAG', '')
if nni_trace_flag.lower() == 'disable':
return cls_or_func
return cast(T, cls_or_func)
def wrap(cls_or_func):
# already annotated, do nothing
......@@ -301,20 +301,22 @@ def trace(cls_or_func: T = None, *, kw_only: bool = True, inheritable: bool = Fa
# if we're being called as @trace()
if cls_or_func is None:
return wrap
return wrap # type: ignore
# if we are called without parentheses
return wrap(cls_or_func)
return wrap(cls_or_func) # type: ignore
def dump(obj: Any, fp: Optional[Any] = None, *, use_trace: bool = True, pickle_size_limit: int = 4096,
allow_nan: bool = True, **json_tricks_kwargs) -> Union[str, bytes]:
allow_nan: bool = True, **json_tricks_kwargs) -> str:
"""
Convert a nested data structure to a json string. Save to file if fp is specified.
Use json-tricks as main backend. For unhandled cases in json-tricks, use cloudpickle.
The serializer is not designed for long-term storage use, but rather to copy data between processes.
The format is also subject to change between NNI releases.
To compress the payload, please use :func:`dump_bytes`.
Parameters
----------
obj : any
......@@ -334,6 +336,39 @@ def dump(obj: Any, fp: Optional[Any] = None, *, use_trace: bool = True, pickle_s
Normally str. Sometimes bytes (if compressed).
"""
if json_tricks_kwargs.get('compression') is not None:
raise ValueError('If you meant to compress the dumped payload, please use `dump_bytes`.')
result = _dump(
obj=obj,
fp=fp,
use_trace=use_trace,
pickle_size_limit=pickle_size_limit,
allow_nan=allow_nan,
**json_tricks_kwargs)
return cast(str, result)
def dump_bytes(obj: Any, fp: Optional[Any] = None, *, compression: int = cast(int, None),
use_trace: bool = True, pickle_size_limit: int = 4096,
allow_nan: bool = True, **json_tricks_kwargs) -> bytes:
"""
Same as :func:`dump`, but to comporess payload, with `compression <https://json-tricks.readthedocs.io/en/stable/#dump>`__.
"""
if compression is None:
raise ValueError('compression must be set.')
result = _dump(
obj=obj,
fp=fp,
compression=compression,
use_trace=use_trace,
pickle_size_limit=pickle_size_limit,
allow_nan=allow_nan,
**json_tricks_kwargs)
return cast(bytes, result)
def _dump(*, obj: Any, fp: Optional[Any], use_trace: bool, pickle_size_limit: int,
allow_nan: bool, **json_tricks_kwargs) -> Union[str, bytes]:
encoders = [
# we don't need to check for dependency as many of those have already been required by NNI
json_tricks.pathlib_encode, # pathlib is a required dependency for NNI
......@@ -456,7 +491,7 @@ def _trace_cls(base, kw_only, call_super=True, inheritable=False):
raise TypeError(f"{base} has a superclass already decorated with trace, and it's using a customized metaclass {type(base)}. "
"Please either use the default metaclass, or remove trace from the super-class.")
class wrapper(SerializableObject, base, metaclass=metaclass):
class wrapper(SerializableObject, base, metaclass=metaclass): # type: ignore
def __init__(self, *args, **kwargs):
# store a copy of initial parameters
args, kwargs = _formulate_arguments(base.__init__, args, kwargs, kw_only, is_class_init=True)
......@@ -528,7 +563,8 @@ def _trace_func(func, kw_only):
# and thus not possible to restore the trace parameters after dump and reload.
# this is a known limitation.
new_type = _make_class_traceable(type(res), True)
res = new_type(res) # re-creating the object
# re-creating the object
res = new_type(res) # type: ignore
res = inject_trace_info(res, func, args, kwargs)
else:
raise TypeError(f'Try to add trace info to {res}, but the type "{type(res)}" is unknown. '
......@@ -750,7 +786,7 @@ def import_cls_or_func_from_hybrid_name(s: str) -> Any:
return _import_cls_or_func_from_name(s)
def _json_tricks_func_or_cls_encode(cls_or_func: Any, primitives: bool = False, pickle_size_limit: int = 4096) -> str:
def _json_tricks_func_or_cls_encode(cls_or_func: Any, primitives: bool = False, pickle_size_limit: int = 4096) -> Dict[str, str]:
if not isinstance(cls_or_func, type) and not _is_function(cls_or_func):
# not a function or class, continue
return cls_or_func
......@@ -762,8 +798,7 @@ def _json_tricks_func_or_cls_encode(cls_or_func: Any, primitives: bool = False,
def _json_tricks_func_or_cls_decode(s: Dict[str, Any]) -> Any:
if isinstance(s, dict) and '__nni_type__' in s:
s = s['__nni_type__']
return import_cls_or_func_from_hybrid_name(s)
return import_cls_or_func_from_hybrid_name(s['__nni_type__'])
return s
......@@ -815,8 +850,7 @@ def _json_tricks_any_object_encode(obj: Any, primitives: bool = False, pickle_si
def _json_tricks_any_object_decode(obj: Dict[str, Any]) -> Any:
if isinstance(obj, dict) and '__nni_obj__' in obj:
obj = obj['__nni_obj__']
b = base64.b64decode(obj)
b = base64.b64decode(obj['__nni_obj__'])
return _wrapped_cloudpickle_loads(b)
return obj
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .utils import load_benchmark, download_benchmark
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import os
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import argparse
if __name__ == '__main__':
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .constants import INPUT, OUTPUT, CONV3X3_BN_RELU, CONV1X1_BN_RELU, MAXPOOL3X3
from .model import Nb101TrialStats, Nb101IntermediateStats, Nb101TrialConfig
from .query import query_nb101_trial_stats
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
INPUT = 'input'
OUTPUT = 'output'
CONV3X3_BN_RELU = 'conv3x3-bn-relu'
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import argparse
from tqdm import tqdm
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import hashlib
import numpy as np
......@@ -10,7 +13,7 @@ def _labeling_from_architecture(architecture, vertices):
def _adjancency_matrix_from_architecture(architecture, vertices):
matrix = np.zeros((vertices, vertices), dtype=np.bool)
matrix = np.zeros((vertices, vertices), dtype=np.bool) # type: ignore
for i in range(1, vertices):
for k in architecture['input{}'.format(i)]:
matrix[k, i] = 1
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from peewee import CharField, FloatField, ForeignKeyField, IntegerField, Model, Proxy
from playhouse.sqlite_ext import JSONField
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import functools
from peewee import fn
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .constants import NONE, SKIP_CONNECT, CONV_1X1, CONV_3X3, AVG_POOL_3X3
from .model import Nb201TrialStats, Nb201IntermediateStats, Nb201TrialConfig
from .query import query_nb201_trial_stats
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
NONE = 'none'
SKIP_CONNECT = 'skip_connect'
CONV_1X1 = 'conv_1x1'
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import argparse
import re
......@@ -17,7 +20,7 @@ def parse_arch_str(arch_str):
'nor_conv_3x3': CONV_3X3,
'avg_pool_3x3': AVG_POOL_3X3
}
m = re.match(r'\|(.*)~0\|\+\|(.*)~0\|(.*)~1\|\+\|(.*)~0\|(.*)~1\|(.*)~2\|', arch_str)
m: re.Match = re.match(r'\|(.*)~0\|\+\|(.*)~0\|(.*)~1\|\+\|(.*)~0\|(.*)~1\|(.*)~2\|', arch_str) # type: ignore
return {
'0_1': mp[m.group(1)],
'0_2': mp[m.group(2)],
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from peewee import CharField, FloatField, ForeignKeyField, IntegerField, Model, Proxy
from playhouse.sqlite_ext import JSONField
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import functools
from peewee import fn
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .constants import *
from .model import NdsTrialConfig, NdsTrialStats, NdsIntermediateStats
from .query import query_nds_trial_stats
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
NONE = 'none'
SKIP_CONNECT = 'skip_connect'
AVG_POOL_3X3 = 'avg_pool_3x3'
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import json
import argparse
import os
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from peewee import CharField, FloatField, ForeignKeyField, IntegerField, Model, Proxy
from playhouse.sqlite_ext import JSONField
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment