Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
d6791c2b
Commit
d6791c2b
authored
Nov 05, 2020
by
Yuge Zhang
Browse files
Merge branch 'master' into dev-retiarii
parents
19726d4d
16dc45b1
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
36 additions
and
41 deletions
+36
-41
nni/algorithms/nas/pytorch/spos/trainer.py
nni/algorithms/nas/pytorch/spos/trainer.py
+2
-0
nni/nas/pytorch/fixed.py
nni/nas/pytorch/fixed.py
+11
-5
setup.py
setup.py
+20
-1
setup_ts.py
setup_ts.py
+1
-1
ts/nni_manager/yarn.lock
ts/nni_manager/yarn.lock
+2
-34
No files found.
nni/algorithms/nas/pytorch/spos/trainer.py
View file @
d6791c2b
...
...
@@ -63,6 +63,7 @@ class SPOSSupernetTrainer(Trainer):
self
.
model
.
train
()
meters
=
AverageMeterGroup
()
for
step
,
(
x
,
y
)
in
enumerate
(
self
.
train_loader
):
x
,
y
=
x
.
to
(
self
.
device
),
y
.
to
(
self
.
device
)
self
.
optimizer
.
zero_grad
()
self
.
mutator
.
reset
()
logits
=
self
.
model
(
x
)
...
...
@@ -82,6 +83,7 @@ class SPOSSupernetTrainer(Trainer):
meters
=
AverageMeterGroup
()
with
torch
.
no_grad
():
for
step
,
(
x
,
y
)
in
enumerate
(
self
.
valid_loader
):
x
,
y
=
x
.
to
(
self
.
device
),
y
.
to
(
self
.
device
)
self
.
mutator
.
reset
()
logits
=
self
.
model
(
x
)
loss
=
self
.
loss
(
logits
,
y
)
...
...
nni/nas/pytorch/fixed.py
View file @
d6791c2b
...
...
@@ -24,11 +24,14 @@ class FixedArchitecture(Mutator):
Preloaded architecture object.
strict : bool
Force everything that appears in ``fixed_arc`` to be used at least once.
verbose : bool
Print log messages if set to True
"""
def
__init__
(
self
,
model
,
fixed_arc
,
strict
=
True
):
def
__init__
(
self
,
model
,
fixed_arc
,
strict
=
True
,
verbose
=
True
):
super
().
__init__
(
model
)
self
.
_fixed_arc
=
fixed_arc
self
.
verbose
=
verbose
mutable_keys
=
set
([
mutable
.
key
for
mutable
in
self
.
mutables
if
not
isinstance
(
mutable
,
MutableScope
)])
fixed_arc_keys
=
set
(
self
.
_fixed_arc
.
keys
())
...
...
@@ -99,10 +102,11 @@ class FixedArchitecture(Mutator):
if
sum
(
chosen
)
==
1
and
max
(
chosen
)
==
1
and
not
mutable
.
return_mask
:
# sum is one, max is one, there has to be an only one
# this is compatible with both integer arrays, boolean arrays and float arrays
if
self
.
verbose
:
_logger
.
info
(
"Replacing %s with candidate number %d."
,
global_name
,
chosen
.
index
(
1
))
setattr
(
module
,
name
,
mutable
[
chosen
.
index
(
1
)])
else
:
if
mutable
.
return_mask
:
if
mutable
.
return_mask
and
self
.
verbose
:
_logger
.
info
(
"`return_mask` flag of %s is true. As it relies on the behavior of LayerChoice, "
\
"LayerChoice will not be replaced."
)
# remove unused parameters
...
...
@@ -113,7 +117,7 @@ class FixedArchitecture(Mutator):
self
.
replace_layer_choice
(
mutable
,
global_name
)
def
apply_fixed_architecture
(
model
,
fixed_arc
):
def
apply_fixed_architecture
(
model
,
fixed_arc
,
verbose
=
True
):
"""
Load architecture from `fixed_arc` and apply to model.
...
...
@@ -123,6 +127,8 @@ def apply_fixed_architecture(model, fixed_arc):
Model with mutables.
fixed_arc : str or dict
Path to the JSON that stores the architecture, or dict that stores the exported architecture.
verbose : bool
Print log messages if set to True
Returns
-------
...
...
@@ -133,7 +139,7 @@ def apply_fixed_architecture(model, fixed_arc):
if
isinstance
(
fixed_arc
,
str
):
with
open
(
fixed_arc
)
as
f
:
fixed_arc
=
json
.
load
(
f
)
architecture
=
FixedArchitecture
(
model
,
fixed_arc
)
architecture
=
FixedArchitecture
(
model
,
fixed_arc
,
verbose
)
architecture
.
reset
()
# for the convenience of parameters counting
...
...
setup.py
View file @
d6791c2b
...
...
@@ -48,6 +48,7 @@ from distutils.command.clean import clean
import
glob
import
os
import
shutil
import
sys
import
setuptools
from
setuptools.command.develop
import
develop
...
...
@@ -131,6 +132,8 @@ def _find_python_packages():
def
_find_node_files
():
if
not
os
.
path
.
exists
(
'nni_node'
):
if
release
and
'built_ts'
not
in
sys
.
argv
:
sys
.
exit
(
'ERROR: To build a release version, run "python setup.py built_ts" first'
)
return
[]
files
=
[]
for
dirpath
,
dirnames
,
filenames
in
os
.
walk
(
'nni_node'
):
...
...
@@ -140,6 +143,9 @@ def _find_node_files():
files
.
remove
(
'__init__.py'
)
return
sorted
(
files
)
def
_using_conda_or_virtual_environment
():
return
sys
.
prefix
!=
sys
.
base_prefix
or
os
.
path
.
isdir
(
os
.
path
.
join
(
sys
.
prefix
,
'conda-meta'
))
class
BuildTs
(
Command
):
description
=
'build TypeScript modules'
...
...
@@ -163,8 +169,21 @@ class Build(build):
super
().
run
()
class
Develop
(
develop
):
user_options
=
develop
.
user_options
+
[
(
'no-user'
,
None
,
'Prevent automatically adding "--user"'
)
]
boolean_options
=
develop
.
boolean_options
+
[
'no-user'
]
def
initialize_options
(
self
):
super
().
initialize_options
()
self
.
no_user
=
None
def
finalize_options
(
self
):
self
.
user
=
True
# always use `develop --user`
# if `--user` or `--no-user` is explicitly set, do nothing
# otherwise activate `--user` if using system python
if
not
self
.
user
and
not
self
.
no_user
:
self
.
user
=
not
_using_conda_or_virtual_environment
()
super
().
finalize_options
()
def
run
(
self
):
...
...
setup_ts.py
View file @
d6791c2b
...
...
@@ -131,7 +131,7 @@ def prepare_nni_node():
node_src
=
Path
(
'toolchain/node'
,
node_executable_in_tarball
)
node_dst
=
Path
(
'nni_node'
,
node_executable
)
shutil
.
copy
file
(
node_src
,
node_dst
)
shutil
.
copy
(
node_src
,
node_dst
)
def
compile_ts
():
...
...
ts/nni_manager/yarn.lock
View file @
d6791c2b
...
...
@@ -1336,7 +1336,7 @@ debug@^3.1.0:
dependencies:
ms "^2.1.1"
debuglog@*,
debuglog@^1.0.1:
debuglog@^1.0.1:
version "1.0.1"
resolved "https://registry.yarnpkg.com/debuglog/-/debuglog-1.0.1.tgz#aa24ffb9ac3df9a2351837cfb2d279360cd78492"
...
...
@@ -2392,7 +2392,7 @@ import-lazy@^2.1.0:
version "2.1.0"
resolved "https://registry.yarnpkg.com/import-lazy/-/import-lazy-2.1.0.tgz#05698e3d45c88e8d7e9d92cb0584e77f096f3e43"
imurmurhash@*,
imurmurhash@^0.1.4:
imurmurhash@^0.1.4:
version "0.1.4"
resolved "https://registry.yarnpkg.com/imurmurhash/-/imurmurhash-0.1.4.tgz#9218b9b2b928a238b13dc4fb6b6d576f231453ea"
integrity sha1-khi5srkoojixPcT7a21XbyMUU+o=
...
...
@@ -3074,11 +3074,6 @@ lockfile@^1.0.4:
dependencies:
signal-exit "^3.0.2"
lodash._baseindexof@*:
version "3.1.0"
resolved "https://registry.yarnpkg.com/lodash._baseindexof/-/lodash._baseindexof-3.1.0.tgz#fe52b53a1c6761e42618d654e4a25789ed61822c"
integrity sha1-/lK1OhxnYeQmGNZU5KJXie1hgiw=
lodash._baseuniq@~4.6.0:
version "4.6.0"
resolved "https://registry.yarnpkg.com/lodash._baseuniq/-/lodash._baseuniq-4.6.0.tgz#0ebb44e456814af7905c6212fa2c9b2d51b841e8"
...
...
@@ -3086,32 +3081,10 @@ lodash._baseuniq@~4.6.0:
lodash._createset "~4.0.0"
lodash._root "~3.0.0"
lodash._bindcallback@*:
version "3.0.1"
resolved "https://registry.yarnpkg.com/lodash._bindcallback/-/lodash._bindcallback-3.0.1.tgz#e531c27644cf8b57a99e17ed95b35c748789392e"
integrity sha1-5THCdkTPi1epnhftlbNcdIeJOS4=
lodash._cacheindexof@*:
version "3.0.2"
resolved "https://registry.yarnpkg.com/lodash._cacheindexof/-/lodash._cacheindexof-3.0.2.tgz#3dc69ac82498d2ee5e3ce56091bafd2adc7bde92"
integrity sha1-PcaayCSY0u5ePOVgkbr9Ktx73pI=
lodash._createcache@*:
version "3.1.2"
resolved "https://registry.yarnpkg.com/lodash._createcache/-/lodash._createcache-3.1.2.tgz#56d6a064017625e79ebca6b8018e17440bdcf093"
integrity sha1-VtagZAF2JeeevKa4AY4XRAvc8JM=
dependencies:
lodash._getnative "^3.0.0"
lodash._createset@~4.0.0:
version "4.0.3"
resolved "https://registry.yarnpkg.com/lodash._createset/-/lodash._createset-4.0.3.tgz#0f4659fbb09d75194fa9e2b88a6644d363c9fe26"
lodash._getnative@*, lodash._getnative@^3.0.0:
version "3.9.1"
resolved "https://registry.yarnpkg.com/lodash._getnative/-/lodash._getnative-3.9.1.tgz#570bc7dede46d61cdcde687d65d3eecbaa3aaff5"
integrity sha1-VwvH3t5G1hzc3mh9ZdPuy6o6r/U=
lodash._root@~3.0.0:
version "3.0.1"
resolved "https://registry.yarnpkg.com/lodash._root/-/lodash._root-3.0.1.tgz#fba1c4524c19ee9a5f8136b4609f017cf4ded692"
...
...
@@ -3160,11 +3133,6 @@ lodash.pick@^4.4.0:
version "4.4.0"
resolved "https://registry.yarnpkg.com/lodash.pick/-/lodash.pick-4.4.0.tgz#52f05610fff9ded422611441ed1fc123a03001b3"
lodash.restparam@*:
version "3.6.1"
resolved "https://registry.yarnpkg.com/lodash.restparam/-/lodash.restparam-3.6.1.tgz#936a4e309ef330a7645ed4145986c85ae5b20805"
integrity sha1-k2pOMJ7zMKdkXtQUWYbIWuWyCAU=
lodash.unescape@4.0.1:
version "4.0.1"
resolved "https://registry.yarnpkg.com/lodash.unescape/-/lodash.unescape-4.0.1.tgz#bf2249886ce514cda112fae9218cdc065211fc9c"
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment