Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
18962129
Unverified
Commit
18962129
authored
Apr 25, 2022
by
Yuge Zhang
Committed by
GitHub
Apr 25, 2022
Browse files
Add license header and typehints for NAS (#4774)
parent
8c2f717d
Changes
96
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
109 additions
and
21 deletions
+109
-21
dependencies/required.txt
dependencies/required.txt
+1
-1
nni/common/serializer.py
nni/common/serializer.py
+52
-18
nni/nas/benchmarks/__init__.py
nni/nas/benchmarks/__init__.py
+3
-0
nni/nas/benchmarks/constants.py
nni/nas/benchmarks/constants.py
+3
-0
nni/nas/benchmarks/download.py
nni/nas/benchmarks/download.py
+3
-0
nni/nas/benchmarks/nasbench101/__init__.py
nni/nas/benchmarks/nasbench101/__init__.py
+3
-0
nni/nas/benchmarks/nasbench101/constants.py
nni/nas/benchmarks/nasbench101/constants.py
+3
-0
nni/nas/benchmarks/nasbench101/db_gen.py
nni/nas/benchmarks/nasbench101/db_gen.py
+3
-0
nni/nas/benchmarks/nasbench101/graph_util.py
nni/nas/benchmarks/nasbench101/graph_util.py
+4
-1
nni/nas/benchmarks/nasbench101/model.py
nni/nas/benchmarks/nasbench101/model.py
+3
-0
nni/nas/benchmarks/nasbench101/query.py
nni/nas/benchmarks/nasbench101/query.py
+3
-0
nni/nas/benchmarks/nasbench201/__init__.py
nni/nas/benchmarks/nasbench201/__init__.py
+3
-0
nni/nas/benchmarks/nasbench201/constants.py
nni/nas/benchmarks/nasbench201/constants.py
+3
-0
nni/nas/benchmarks/nasbench201/db_gen.py
nni/nas/benchmarks/nasbench201/db_gen.py
+4
-1
nni/nas/benchmarks/nasbench201/model.py
nni/nas/benchmarks/nasbench201/model.py
+3
-0
nni/nas/benchmarks/nasbench201/query.py
nni/nas/benchmarks/nasbench201/query.py
+3
-0
nni/nas/benchmarks/nds/__init__.py
nni/nas/benchmarks/nds/__init__.py
+3
-0
nni/nas/benchmarks/nds/constants.py
nni/nas/benchmarks/nds/constants.py
+3
-0
nni/nas/benchmarks/nds/db_gen.py
nni/nas/benchmarks/nds/db_gen.py
+3
-0
nni/nas/benchmarks/nds/model.py
nni/nas/benchmarks/nds/model.py
+3
-0
No files found.
dependencies/required.txt
View file @
18962129
...
...
@@ -19,5 +19,5 @@ scikit-learn >= 0.24.1
scipy < 1.8 ; python_version < "3.8"
scipy ; python_version >= "3.8"
typeguard
typing_extensions >= 4.0.0
; python_version < "3.8"
typing_extensions >= 4.0.0
websockets >= 10.1
nni/common/serializer.py
View file @
18962129
...
...
@@ -13,7 +13,7 @@ import sys
import
types
import
warnings
from
io
import
IOBase
from
typing
import
Any
,
Dict
,
List
,
Optional
,
TypeVar
,
Union
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Type
,
TypeVar
,
Union
,
cast
,
Generic
import
cloudpickle
# use cloudpickle as backend for unserializable types and instances
import
json_tricks
# use json_tricks as serializer backend
...
...
@@ -115,7 +115,7 @@ def is_wrapped_with_trace(cls_or_func: Any) -> bool:
)
class
SerializableObject
(
Traceable
):
class
SerializableObject
(
Generic
[
T
],
Traceable
):
"""
Serializable object is a wrapper of existing python objects, that supports dump and load easily.
Stores a symbol ``s`` and a dict of arguments ``args``, and the object can be restored with ``s(**args)``.
...
...
@@ -147,7 +147,7 @@ class SerializableObject(Traceable):
# Reinitialize
return
trace
(
self
.
trace_symbol
)(
*
self
.
trace_args
,
**
self
.
trace_kwargs
)
return
self
return
cast
(
T
,
self
)
@
property
def
trace_symbol
(
self
)
->
Any
:
...
...
@@ -187,7 +187,7 @@ class SerializableObject(Traceable):
')'
def
inject_trace_info
(
obj
:
Any
,
symbol
:
T
,
args
:
List
[
Any
],
kwargs
:
Dict
[
str
,
Any
])
->
Any
:
def
inject_trace_info
(
obj
:
Any
,
symbol
:
T
,
args
:
List
[
Any
],
kwargs
:
Dict
[
str
,
Any
])
->
T
:
# If an object is already created, this can be a fix so that the necessary info are re-injected into the object.
# Make obj complying with the interface of traceable, though we cannot change its base class.
obj
.
__dict__
.
update
(
_nni_symbol
=
symbol
,
_nni_args
=
args
,
_nni_kwargs
=
kwargs
)
...
...
@@ -233,11 +233,11 @@ def _make_class_traceable(cls: T, create_wrapper: bool = False) -> T:
else
:
# sometimes create_wrapper is mandatory, e.g., for built-in types like list/int.
# but I don't want to check here because it's unreliable.
wrapper
=
type
(
'wrapper'
,
(
Traceable
,
cls
),
attributes
)
return
wrapper
wrapper
=
type
(
'wrapper'
,
(
Traceable
,
cast
(
Type
,
cls
)
)
,
attributes
)
return
cast
(
T
,
wrapper
)
def
trace
(
cls_or_func
:
T
=
None
,
*
,
kw_only
:
bool
=
True
,
inheritable
:
bool
=
False
)
->
Union
[
T
,
Traceable
]
:
def
trace
(
cls_or_func
:
T
=
cast
(
T
,
None
)
,
*
,
kw_only
:
bool
=
True
,
inheritable
:
bool
=
False
)
->
T
:
"""
Annotate a function or a class if you want to preserve where it comes from.
This is usually used in the following scenarios:
...
...
@@ -283,7 +283,7 @@ def trace(cls_or_func: T = None, *, kw_only: bool = True, inheritable: bool = Fa
# Might be changed in future.
nni_trace_flag
=
os
.
environ
.
get
(
'NNI_TRACE_FLAG'
,
''
)
if
nni_trace_flag
.
lower
()
==
'disable'
:
return
cls_or_func
return
cast
(
T
,
cls_or_func
)
def
wrap
(
cls_or_func
):
# already annotated, do nothing
...
...
@@ -301,20 +301,22 @@ def trace(cls_or_func: T = None, *, kw_only: bool = True, inheritable: bool = Fa
# if we're being called as @trace()
if
cls_or_func
is
None
:
return
wrap
return
wrap
# type: ignore
# if we are called without parentheses
return
wrap
(
cls_or_func
)
return
wrap
(
cls_or_func
)
# type: ignore
def
dump
(
obj
:
Any
,
fp
:
Optional
[
Any
]
=
None
,
*
,
use_trace
:
bool
=
True
,
pickle_size_limit
:
int
=
4096
,
allow_nan
:
bool
=
True
,
**
json_tricks_kwargs
)
->
Union
[
str
,
bytes
]
:
allow_nan
:
bool
=
True
,
**
json_tricks_kwargs
)
->
str
:
"""
Convert a nested data structure to a json string. Save to file if fp is specified.
Use json-tricks as main backend. For unhandled cases in json-tricks, use cloudpickle.
The serializer is not designed for long-term storage use, but rather to copy data between processes.
The format is also subject to change between NNI releases.
To compress the payload, please use :func:`dump_bytes`.
Parameters
----------
obj : any
...
...
@@ -334,6 +336,39 @@ def dump(obj: Any, fp: Optional[Any] = None, *, use_trace: bool = True, pickle_s
Normally str. Sometimes bytes (if compressed).
"""
if
json_tricks_kwargs
.
get
(
'compression'
)
is
not
None
:
raise
ValueError
(
'If you meant to compress the dumped payload, please use `dump_bytes`.'
)
result
=
_dump
(
obj
=
obj
,
fp
=
fp
,
use_trace
=
use_trace
,
pickle_size_limit
=
pickle_size_limit
,
allow_nan
=
allow_nan
,
**
json_tricks_kwargs
)
return
cast
(
str
,
result
)
def
dump_bytes
(
obj
:
Any
,
fp
:
Optional
[
Any
]
=
None
,
*
,
compression
:
int
=
cast
(
int
,
None
),
use_trace
:
bool
=
True
,
pickle_size_limit
:
int
=
4096
,
allow_nan
:
bool
=
True
,
**
json_tricks_kwargs
)
->
bytes
:
"""
Same as :func:`dump`, but to comporess payload, with `compression <https://json-tricks.readthedocs.io/en/stable/#dump>`__.
"""
if
compression
is
None
:
raise
ValueError
(
'compression must be set.'
)
result
=
_dump
(
obj
=
obj
,
fp
=
fp
,
compression
=
compression
,
use_trace
=
use_trace
,
pickle_size_limit
=
pickle_size_limit
,
allow_nan
=
allow_nan
,
**
json_tricks_kwargs
)
return
cast
(
bytes
,
result
)
def
_dump
(
*
,
obj
:
Any
,
fp
:
Optional
[
Any
],
use_trace
:
bool
,
pickle_size_limit
:
int
,
allow_nan
:
bool
,
**
json_tricks_kwargs
)
->
Union
[
str
,
bytes
]:
encoders
=
[
# we don't need to check for dependency as many of those have already been required by NNI
json_tricks
.
pathlib_encode
,
# pathlib is a required dependency for NNI
...
...
@@ -456,7 +491,7 @@ def _trace_cls(base, kw_only, call_super=True, inheritable=False):
raise
TypeError
(
f
"
{
base
}
has a superclass already decorated with trace, and it's using a customized metaclass
{
type
(
base
)
}
. "
"Please either use the default metaclass, or remove trace from the super-class."
)
class
wrapper
(
SerializableObject
,
base
,
metaclass
=
metaclass
):
class
wrapper
(
SerializableObject
,
base
,
metaclass
=
metaclass
):
# type: ignore
def
__init__
(
self
,
*
args
,
**
kwargs
):
# store a copy of initial parameters
args
,
kwargs
=
_formulate_arguments
(
base
.
__init__
,
args
,
kwargs
,
kw_only
,
is_class_init
=
True
)
...
...
@@ -528,7 +563,8 @@ def _trace_func(func, kw_only):
# and thus not possible to restore the trace parameters after dump and reload.
# this is a known limitation.
new_type
=
_make_class_traceable
(
type
(
res
),
True
)
res
=
new_type
(
res
)
# re-creating the object
# re-creating the object
res
=
new_type
(
res
)
# type: ignore
res
=
inject_trace_info
(
res
,
func
,
args
,
kwargs
)
else
:
raise
TypeError
(
f
'Try to add trace info to
{
res
}
, but the type "
{
type
(
res
)
}
" is unknown. '
...
...
@@ -750,7 +786,7 @@ def import_cls_or_func_from_hybrid_name(s: str) -> Any:
return
_import_cls_or_func_from_name
(
s
)
def
_json_tricks_func_or_cls_encode
(
cls_or_func
:
Any
,
primitives
:
bool
=
False
,
pickle_size_limit
:
int
=
4096
)
->
str
:
def
_json_tricks_func_or_cls_encode
(
cls_or_func
:
Any
,
primitives
:
bool
=
False
,
pickle_size_limit
:
int
=
4096
)
->
Dict
[
str
,
str
]
:
if
not
isinstance
(
cls_or_func
,
type
)
and
not
_is_function
(
cls_or_func
):
# not a function or class, continue
return
cls_or_func
...
...
@@ -762,8 +798,7 @@ def _json_tricks_func_or_cls_encode(cls_or_func: Any, primitives: bool = False,
def
_json_tricks_func_or_cls_decode
(
s
:
Dict
[
str
,
Any
])
->
Any
:
if
isinstance
(
s
,
dict
)
and
'__nni_type__'
in
s
:
s
=
s
[
'__nni_type__'
]
return
import_cls_or_func_from_hybrid_name
(
s
)
return
import_cls_or_func_from_hybrid_name
(
s
[
'__nni_type__'
])
return
s
...
...
@@ -815,8 +850,7 @@ def _json_tricks_any_object_encode(obj: Any, primitives: bool = False, pickle_si
def
_json_tricks_any_object_decode
(
obj
:
Dict
[
str
,
Any
])
->
Any
:
if
isinstance
(
obj
,
dict
)
and
'__nni_obj__'
in
obj
:
obj
=
obj
[
'__nni_obj__'
]
b
=
base64
.
b64decode
(
obj
)
b
=
base64
.
b64decode
(
obj
[
'__nni_obj__'
])
return
_wrapped_cloudpickle_loads
(
b
)
return
obj
...
...
nni/nas/benchmarks/__init__.py
View file @
18962129
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from
.utils
import
load_benchmark
,
download_benchmark
nni/nas/benchmarks/constants.py
View file @
18962129
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import
os
...
...
nni/nas/benchmarks/download.py
View file @
18962129
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import
argparse
if
__name__
==
'__main__'
:
...
...
nni/nas/benchmarks/nasbench101/__init__.py
View file @
18962129
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from
.constants
import
INPUT
,
OUTPUT
,
CONV3X3_BN_RELU
,
CONV1X1_BN_RELU
,
MAXPOOL3X3
from
.model
import
Nb101TrialStats
,
Nb101IntermediateStats
,
Nb101TrialConfig
from
.query
import
query_nb101_trial_stats
nni/nas/benchmarks/nasbench101/constants.py
View file @
18962129
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
INPUT
=
'input'
OUTPUT
=
'output'
CONV3X3_BN_RELU
=
'conv3x3-bn-relu'
...
...
nni/nas/benchmarks/nasbench101/db_gen.py
View file @
18962129
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import
argparse
from
tqdm
import
tqdm
...
...
nni/nas/benchmarks/nasbench101/graph_util.py
View file @
18962129
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import
hashlib
import
numpy
as
np
...
...
@@ -10,7 +13,7 @@ def _labeling_from_architecture(architecture, vertices):
def
_adjancency_matrix_from_architecture
(
architecture
,
vertices
):
matrix
=
np
.
zeros
((
vertices
,
vertices
),
dtype
=
np
.
bool
)
matrix
=
np
.
zeros
((
vertices
,
vertices
),
dtype
=
np
.
bool
)
# type: ignore
for
i
in
range
(
1
,
vertices
):
for
k
in
architecture
[
'input{}'
.
format
(
i
)]:
matrix
[
k
,
i
]
=
1
...
...
nni/nas/benchmarks/nasbench101/model.py
View file @
18962129
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from
peewee
import
CharField
,
FloatField
,
ForeignKeyField
,
IntegerField
,
Model
,
Proxy
from
playhouse.sqlite_ext
import
JSONField
...
...
nni/nas/benchmarks/nasbench101/query.py
View file @
18962129
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import
functools
from
peewee
import
fn
...
...
nni/nas/benchmarks/nasbench201/__init__.py
View file @
18962129
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from
.constants
import
NONE
,
SKIP_CONNECT
,
CONV_1X1
,
CONV_3X3
,
AVG_POOL_3X3
from
.model
import
Nb201TrialStats
,
Nb201IntermediateStats
,
Nb201TrialConfig
from
.query
import
query_nb201_trial_stats
nni/nas/benchmarks/nasbench201/constants.py
View file @
18962129
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
NONE
=
'none'
SKIP_CONNECT
=
'skip_connect'
CONV_1X1
=
'conv_1x1'
...
...
nni/nas/benchmarks/nasbench201/db_gen.py
View file @
18962129
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import
argparse
import
re
...
...
@@ -17,7 +20,7 @@ def parse_arch_str(arch_str):
'nor_conv_3x3'
:
CONV_3X3
,
'avg_pool_3x3'
:
AVG_POOL_3X3
}
m
=
re
.
match
(
r
'\|(.*)~0\|\+\|(.*)~0\|(.*)~1\|\+\|(.*)~0\|(.*)~1\|(.*)~2\|'
,
arch_str
)
m
:
re
.
Match
=
re
.
match
(
r
'\|(.*)~0\|\+\|(.*)~0\|(.*)~1\|\+\|(.*)~0\|(.*)~1\|(.*)~2\|'
,
arch_str
)
# type: ignore
return
{
'0_1'
:
mp
[
m
.
group
(
1
)],
'0_2'
:
mp
[
m
.
group
(
2
)],
...
...
nni/nas/benchmarks/nasbench201/model.py
View file @
18962129
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from
peewee
import
CharField
,
FloatField
,
ForeignKeyField
,
IntegerField
,
Model
,
Proxy
from
playhouse.sqlite_ext
import
JSONField
...
...
nni/nas/benchmarks/nasbench201/query.py
View file @
18962129
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import
functools
from
peewee
import
fn
...
...
nni/nas/benchmarks/nds/__init__.py
View file @
18962129
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from
.constants
import
*
from
.model
import
NdsTrialConfig
,
NdsTrialStats
,
NdsIntermediateStats
from
.query
import
query_nds_trial_stats
nni/nas/benchmarks/nds/constants.py
View file @
18962129
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
NONE
=
'none'
SKIP_CONNECT
=
'skip_connect'
AVG_POOL_3X3
=
'avg_pool_3x3'
...
...
nni/nas/benchmarks/nds/db_gen.py
View file @
18962129
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import
json
import
argparse
import
os
...
...
nni/nas/benchmarks/nds/model.py
View file @
18962129
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from
peewee
import
CharField
,
FloatField
,
ForeignKeyField
,
IntegerField
,
Model
,
Proxy
from
playhouse.sqlite_ext
import
JSONField
...
...
Prev
1
2
3
4
5
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment