Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
18962129
"...gpu/git@developer.sourcefind.cn:gaoqiong/migraphx.git" did not exist on "37c3c4a9db71c37ee418cd2affadf78eff1cbf30"
Unverified
Commit
18962129
authored
Apr 25, 2022
by
Yuge Zhang
Committed by
GitHub
Apr 25, 2022
Browse files
Add license header and typehints for NAS (#4774)
parent
8c2f717d
Changes
96
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
109 additions
and
21 deletions
+109
-21
dependencies/required.txt
dependencies/required.txt
+1
-1
nni/common/serializer.py
nni/common/serializer.py
+52
-18
nni/nas/benchmarks/__init__.py
nni/nas/benchmarks/__init__.py
+3
-0
nni/nas/benchmarks/constants.py
nni/nas/benchmarks/constants.py
+3
-0
nni/nas/benchmarks/download.py
nni/nas/benchmarks/download.py
+3
-0
nni/nas/benchmarks/nasbench101/__init__.py
nni/nas/benchmarks/nasbench101/__init__.py
+3
-0
nni/nas/benchmarks/nasbench101/constants.py
nni/nas/benchmarks/nasbench101/constants.py
+3
-0
nni/nas/benchmarks/nasbench101/db_gen.py
nni/nas/benchmarks/nasbench101/db_gen.py
+3
-0
nni/nas/benchmarks/nasbench101/graph_util.py
nni/nas/benchmarks/nasbench101/graph_util.py
+4
-1
nni/nas/benchmarks/nasbench101/model.py
nni/nas/benchmarks/nasbench101/model.py
+3
-0
nni/nas/benchmarks/nasbench101/query.py
nni/nas/benchmarks/nasbench101/query.py
+3
-0
nni/nas/benchmarks/nasbench201/__init__.py
nni/nas/benchmarks/nasbench201/__init__.py
+3
-0
nni/nas/benchmarks/nasbench201/constants.py
nni/nas/benchmarks/nasbench201/constants.py
+3
-0
nni/nas/benchmarks/nasbench201/db_gen.py
nni/nas/benchmarks/nasbench201/db_gen.py
+4
-1
nni/nas/benchmarks/nasbench201/model.py
nni/nas/benchmarks/nasbench201/model.py
+3
-0
nni/nas/benchmarks/nasbench201/query.py
nni/nas/benchmarks/nasbench201/query.py
+3
-0
nni/nas/benchmarks/nds/__init__.py
nni/nas/benchmarks/nds/__init__.py
+3
-0
nni/nas/benchmarks/nds/constants.py
nni/nas/benchmarks/nds/constants.py
+3
-0
nni/nas/benchmarks/nds/db_gen.py
nni/nas/benchmarks/nds/db_gen.py
+3
-0
nni/nas/benchmarks/nds/model.py
nni/nas/benchmarks/nds/model.py
+3
-0
No files found.
dependencies/required.txt
View file @
18962129
...
...
@@ -19,5 +19,5 @@ scikit-learn >= 0.24.1
scipy < 1.8 ; python_version < "3.8"
scipy ; python_version >= "3.8"
typeguard
typing_extensions >= 4.0.0
; python_version < "3.8"
typing_extensions >= 4.0.0
websockets >= 10.1
nni/common/serializer.py
View file @
18962129
...
...
@@ -13,7 +13,7 @@ import sys
import
types
import
warnings
from
io
import
IOBase
from
typing
import
Any
,
Dict
,
List
,
Optional
,
TypeVar
,
Union
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Type
,
TypeVar
,
Union
,
cast
,
Generic
import
cloudpickle
# use cloudpickle as backend for unserializable types and instances
import
json_tricks
# use json_tricks as serializer backend
...
...
@@ -115,7 +115,7 @@ def is_wrapped_with_trace(cls_or_func: Any) -> bool:
)
class
SerializableObject
(
Traceable
):
class
SerializableObject
(
Generic
[
T
],
Traceable
):
"""
Serializable object is a wrapper of existing python objects, that supports dump and load easily.
Stores a symbol ``s`` and a dict of arguments ``args``, and the object can be restored with ``s(**args)``.
...
...
@@ -147,7 +147,7 @@ class SerializableObject(Traceable):
# Reinitialize
return
trace
(
self
.
trace_symbol
)(
*
self
.
trace_args
,
**
self
.
trace_kwargs
)
return
self
return
cast
(
T
,
self
)
@
property
def
trace_symbol
(
self
)
->
Any
:
...
...
@@ -187,7 +187,7 @@ class SerializableObject(Traceable):
')'
def
inject_trace_info
(
obj
:
Any
,
symbol
:
T
,
args
:
List
[
Any
],
kwargs
:
Dict
[
str
,
Any
])
->
Any
:
def
inject_trace_info
(
obj
:
Any
,
symbol
:
T
,
args
:
List
[
Any
],
kwargs
:
Dict
[
str
,
Any
])
->
T
:
# If an object is already created, this can be a fix so that the necessary info are re-injected into the object.
# Make obj complying with the interface of traceable, though we cannot change its base class.
obj
.
__dict__
.
update
(
_nni_symbol
=
symbol
,
_nni_args
=
args
,
_nni_kwargs
=
kwargs
)
...
...
@@ -233,11 +233,11 @@ def _make_class_traceable(cls: T, create_wrapper: bool = False) -> T:
else
:
# sometimes create_wrapper is mandatory, e.g., for built-in types like list/int.
# but I don't want to check here because it's unreliable.
wrapper
=
type
(
'wrapper'
,
(
Traceable
,
cls
),
attributes
)
return
wrapper
wrapper
=
type
(
'wrapper'
,
(
Traceable
,
cast
(
Type
,
cls
)
)
,
attributes
)
return
cast
(
T
,
wrapper
)
def
trace
(
cls_or_func
:
T
=
None
,
*
,
kw_only
:
bool
=
True
,
inheritable
:
bool
=
False
)
->
Union
[
T
,
Traceable
]
:
def
trace
(
cls_or_func
:
T
=
cast
(
T
,
None
)
,
*
,
kw_only
:
bool
=
True
,
inheritable
:
bool
=
False
)
->
T
:
"""
Annotate a function or a class if you want to preserve where it comes from.
This is usually used in the following scenarios:
...
...
@@ -283,7 +283,7 @@ def trace(cls_or_func: T = None, *, kw_only: bool = True, inheritable: bool = Fa
# Might be changed in future.
nni_trace_flag
=
os
.
environ
.
get
(
'NNI_TRACE_FLAG'
,
''
)
if
nni_trace_flag
.
lower
()
==
'disable'
:
return
cls_or_func
return
cast
(
T
,
cls_or_func
)
def
wrap
(
cls_or_func
):
# already annotated, do nothing
...
...
@@ -301,20 +301,22 @@ def trace(cls_or_func: T = None, *, kw_only: bool = True, inheritable: bool = Fa
# if we're being called as @trace()
if
cls_or_func
is
None
:
return
wrap
return
wrap
# type: ignore
# if we are called without parentheses
return
wrap
(
cls_or_func
)
return
wrap
(
cls_or_func
)
# type: ignore
def
dump
(
obj
:
Any
,
fp
:
Optional
[
Any
]
=
None
,
*
,
use_trace
:
bool
=
True
,
pickle_size_limit
:
int
=
4096
,
allow_nan
:
bool
=
True
,
**
json_tricks_kwargs
)
->
Union
[
str
,
bytes
]
:
allow_nan
:
bool
=
True
,
**
json_tricks_kwargs
)
->
str
:
"""
Convert a nested data structure to a json string. Save to file if fp is specified.
Use json-tricks as main backend. For unhandled cases in json-tricks, use cloudpickle.
The serializer is not designed for long-term storage use, but rather to copy data between processes.
The format is also subject to change between NNI releases.
To compress the payload, please use :func:`dump_bytes`.
Parameters
----------
obj : any
...
...
@@ -334,6 +336,39 @@ def dump(obj: Any, fp: Optional[Any] = None, *, use_trace: bool = True, pickle_s
Normally str. Sometimes bytes (if compressed).
"""
if
json_tricks_kwargs
.
get
(
'compression'
)
is
not
None
:
raise
ValueError
(
'If you meant to compress the dumped payload, please use `dump_bytes`.'
)
result
=
_dump
(
obj
=
obj
,
fp
=
fp
,
use_trace
=
use_trace
,
pickle_size_limit
=
pickle_size_limit
,
allow_nan
=
allow_nan
,
**
json_tricks_kwargs
)
return
cast
(
str
,
result
)
def
dump_bytes
(
obj
:
Any
,
fp
:
Optional
[
Any
]
=
None
,
*
,
compression
:
int
=
cast
(
int
,
None
),
use_trace
:
bool
=
True
,
pickle_size_limit
:
int
=
4096
,
allow_nan
:
bool
=
True
,
**
json_tricks_kwargs
)
->
bytes
:
"""
Same as :func:`dump`, but to comporess payload, with `compression <https://json-tricks.readthedocs.io/en/stable/#dump>`__.
"""
if
compression
is
None
:
raise
ValueError
(
'compression must be set.'
)
result
=
_dump
(
obj
=
obj
,
fp
=
fp
,
compression
=
compression
,
use_trace
=
use_trace
,
pickle_size_limit
=
pickle_size_limit
,
allow_nan
=
allow_nan
,
**
json_tricks_kwargs
)
return
cast
(
bytes
,
result
)
def
_dump
(
*
,
obj
:
Any
,
fp
:
Optional
[
Any
],
use_trace
:
bool
,
pickle_size_limit
:
int
,
allow_nan
:
bool
,
**
json_tricks_kwargs
)
->
Union
[
str
,
bytes
]:
encoders
=
[
# we don't need to check for dependency as many of those have already been required by NNI
json_tricks
.
pathlib_encode
,
# pathlib is a required dependency for NNI
...
...
@@ -456,7 +491,7 @@ def _trace_cls(base, kw_only, call_super=True, inheritable=False):
raise
TypeError
(
f
"
{
base
}
has a superclass already decorated with trace, and it's using a customized metaclass
{
type
(
base
)
}
. "
"Please either use the default metaclass, or remove trace from the super-class."
)
class
wrapper
(
SerializableObject
,
base
,
metaclass
=
metaclass
):
class
wrapper
(
SerializableObject
,
base
,
metaclass
=
metaclass
):
# type: ignore
def
__init__
(
self
,
*
args
,
**
kwargs
):
# store a copy of initial parameters
args
,
kwargs
=
_formulate_arguments
(
base
.
__init__
,
args
,
kwargs
,
kw_only
,
is_class_init
=
True
)
...
...
@@ -528,7 +563,8 @@ def _trace_func(func, kw_only):
# and thus not possible to restore the trace parameters after dump and reload.
# this is a known limitation.
new_type
=
_make_class_traceable
(
type
(
res
),
True
)
res
=
new_type
(
res
)
# re-creating the object
# re-creating the object
res
=
new_type
(
res
)
# type: ignore
res
=
inject_trace_info
(
res
,
func
,
args
,
kwargs
)
else
:
raise
TypeError
(
f
'Try to add trace info to
{
res
}
, but the type "
{
type
(
res
)
}
" is unknown. '
...
...
@@ -750,7 +786,7 @@ def import_cls_or_func_from_hybrid_name(s: str) -> Any:
return
_import_cls_or_func_from_name
(
s
)
def
_json_tricks_func_or_cls_encode
(
cls_or_func
:
Any
,
primitives
:
bool
=
False
,
pickle_size_limit
:
int
=
4096
)
->
str
:
def
_json_tricks_func_or_cls_encode
(
cls_or_func
:
Any
,
primitives
:
bool
=
False
,
pickle_size_limit
:
int
=
4096
)
->
Dict
[
str
,
str
]
:
if
not
isinstance
(
cls_or_func
,
type
)
and
not
_is_function
(
cls_or_func
):
# not a function or class, continue
return
cls_or_func
...
...
@@ -762,8 +798,7 @@ def _json_tricks_func_or_cls_encode(cls_or_func: Any, primitives: bool = False,
def
_json_tricks_func_or_cls_decode
(
s
:
Dict
[
str
,
Any
])
->
Any
:
if
isinstance
(
s
,
dict
)
and
'__nni_type__'
in
s
:
s
=
s
[
'__nni_type__'
]
return
import_cls_or_func_from_hybrid_name
(
s
)
return
import_cls_or_func_from_hybrid_name
(
s
[
'__nni_type__'
])
return
s
...
...
@@ -815,8 +850,7 @@ def _json_tricks_any_object_encode(obj: Any, primitives: bool = False, pickle_si
def
_json_tricks_any_object_decode
(
obj
:
Dict
[
str
,
Any
])
->
Any
:
if
isinstance
(
obj
,
dict
)
and
'__nni_obj__'
in
obj
:
obj
=
obj
[
'__nni_obj__'
]
b
=
base64
.
b64decode
(
obj
)
b
=
base64
.
b64decode
(
obj
[
'__nni_obj__'
])
return
_wrapped_cloudpickle_loads
(
b
)
return
obj
...
...
nni/nas/benchmarks/__init__.py
View file @
18962129
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from
.utils
import
load_benchmark
,
download_benchmark
nni/nas/benchmarks/constants.py
View file @
18962129
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import
os
...
...
nni/nas/benchmarks/download.py
View file @
18962129
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import
argparse
if
__name__
==
'__main__'
:
...
...
nni/nas/benchmarks/nasbench101/__init__.py
View file @
18962129
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from
.constants
import
INPUT
,
OUTPUT
,
CONV3X3_BN_RELU
,
CONV1X1_BN_RELU
,
MAXPOOL3X3
from
.model
import
Nb101TrialStats
,
Nb101IntermediateStats
,
Nb101TrialConfig
from
.query
import
query_nb101_trial_stats
nni/nas/benchmarks/nasbench101/constants.py
View file @
18962129
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
INPUT
=
'input'
OUTPUT
=
'output'
CONV3X3_BN_RELU
=
'conv3x3-bn-relu'
...
...
nni/nas/benchmarks/nasbench101/db_gen.py
View file @
18962129
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import
argparse
from
tqdm
import
tqdm
...
...
nni/nas/benchmarks/nasbench101/graph_util.py
View file @
18962129
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import
hashlib
import
numpy
as
np
...
...
@@ -10,7 +13,7 @@ def _labeling_from_architecture(architecture, vertices):
def
_adjancency_matrix_from_architecture
(
architecture
,
vertices
):
matrix
=
np
.
zeros
((
vertices
,
vertices
),
dtype
=
np
.
bool
)
matrix
=
np
.
zeros
((
vertices
,
vertices
),
dtype
=
np
.
bool
)
# type: ignore
for
i
in
range
(
1
,
vertices
):
for
k
in
architecture
[
'input{}'
.
format
(
i
)]:
matrix
[
k
,
i
]
=
1
...
...
nni/nas/benchmarks/nasbench101/model.py
View file @
18962129
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from
peewee
import
CharField
,
FloatField
,
ForeignKeyField
,
IntegerField
,
Model
,
Proxy
from
playhouse.sqlite_ext
import
JSONField
...
...
nni/nas/benchmarks/nasbench101/query.py
View file @
18962129
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import
functools
from
peewee
import
fn
...
...
nni/nas/benchmarks/nasbench201/__init__.py
View file @
18962129
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from
.constants
import
NONE
,
SKIP_CONNECT
,
CONV_1X1
,
CONV_3X3
,
AVG_POOL_3X3
from
.model
import
Nb201TrialStats
,
Nb201IntermediateStats
,
Nb201TrialConfig
from
.query
import
query_nb201_trial_stats
nni/nas/benchmarks/nasbench201/constants.py
View file @
18962129
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
NONE
=
'none'
SKIP_CONNECT
=
'skip_connect'
CONV_1X1
=
'conv_1x1'
...
...
nni/nas/benchmarks/nasbench201/db_gen.py
View file @
18962129
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import
argparse
import
re
...
...
@@ -17,7 +20,7 @@ def parse_arch_str(arch_str):
'nor_conv_3x3'
:
CONV_3X3
,
'avg_pool_3x3'
:
AVG_POOL_3X3
}
m
=
re
.
match
(
r
'\|(.*)~0\|\+\|(.*)~0\|(.*)~1\|\+\|(.*)~0\|(.*)~1\|(.*)~2\|'
,
arch_str
)
m
:
re
.
Match
=
re
.
match
(
r
'\|(.*)~0\|\+\|(.*)~0\|(.*)~1\|\+\|(.*)~0\|(.*)~1\|(.*)~2\|'
,
arch_str
)
# type: ignore
return
{
'0_1'
:
mp
[
m
.
group
(
1
)],
'0_2'
:
mp
[
m
.
group
(
2
)],
...
...
nni/nas/benchmarks/nasbench201/model.py
View file @
18962129
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from
peewee
import
CharField
,
FloatField
,
ForeignKeyField
,
IntegerField
,
Model
,
Proxy
from
playhouse.sqlite_ext
import
JSONField
...
...
nni/nas/benchmarks/nasbench201/query.py
View file @
18962129
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import
functools
from
peewee
import
fn
...
...
nni/nas/benchmarks/nds/__init__.py
View file @
18962129
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from
.constants
import
*
from
.model
import
NdsTrialConfig
,
NdsTrialStats
,
NdsIntermediateStats
from
.query
import
query_nds_trial_stats
nni/nas/benchmarks/nds/constants.py
View file @
18962129
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
NONE
=
'none'
SKIP_CONNECT
=
'skip_connect'
AVG_POOL_3X3
=
'avg_pool_3x3'
...
...
nni/nas/benchmarks/nds/db_gen.py
View file @
18962129
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import
json
import
argparse
import
os
...
...
nni/nas/benchmarks/nds/model.py
View file @
18962129
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from
peewee
import
CharField
,
FloatField
,
ForeignKeyField
,
IntegerField
,
Model
,
Proxy
from
playhouse.sqlite_ext
import
JSONField
...
...
Prev
1
2
3
4
5
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment