Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
d6791c2b
Commit
d6791c2b
authored
Nov 05, 2020
by
Yuge Zhang
Browse files
Merge branch 'master' into dev-retiarii
parents
19726d4d
16dc45b1
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
36 additions
and
41 deletions
+36
-41
nni/algorithms/nas/pytorch/spos/trainer.py
nni/algorithms/nas/pytorch/spos/trainer.py
+2
-0
nni/nas/pytorch/fixed.py
nni/nas/pytorch/fixed.py
+11
-5
setup.py
setup.py
+20
-1
setup_ts.py
setup_ts.py
+1
-1
ts/nni_manager/yarn.lock
ts/nni_manager/yarn.lock
+2
-34
No files found.
nni/algorithms/nas/pytorch/spos/trainer.py
View file @
d6791c2b
...
@@ -63,6 +63,7 @@ class SPOSSupernetTrainer(Trainer):
...
@@ -63,6 +63,7 @@ class SPOSSupernetTrainer(Trainer):
self
.
model
.
train
()
self
.
model
.
train
()
meters
=
AverageMeterGroup
()
meters
=
AverageMeterGroup
()
for
step
,
(
x
,
y
)
in
enumerate
(
self
.
train_loader
):
for
step
,
(
x
,
y
)
in
enumerate
(
self
.
train_loader
):
x
,
y
=
x
.
to
(
self
.
device
),
y
.
to
(
self
.
device
)
self
.
optimizer
.
zero_grad
()
self
.
optimizer
.
zero_grad
()
self
.
mutator
.
reset
()
self
.
mutator
.
reset
()
logits
=
self
.
model
(
x
)
logits
=
self
.
model
(
x
)
...
@@ -82,6 +83,7 @@ class SPOSSupernetTrainer(Trainer):
...
@@ -82,6 +83,7 @@ class SPOSSupernetTrainer(Trainer):
meters
=
AverageMeterGroup
()
meters
=
AverageMeterGroup
()
with
torch
.
no_grad
():
with
torch
.
no_grad
():
for
step
,
(
x
,
y
)
in
enumerate
(
self
.
valid_loader
):
for
step
,
(
x
,
y
)
in
enumerate
(
self
.
valid_loader
):
x
,
y
=
x
.
to
(
self
.
device
),
y
.
to
(
self
.
device
)
self
.
mutator
.
reset
()
self
.
mutator
.
reset
()
logits
=
self
.
model
(
x
)
logits
=
self
.
model
(
x
)
loss
=
self
.
loss
(
logits
,
y
)
loss
=
self
.
loss
(
logits
,
y
)
...
...
nni/nas/pytorch/fixed.py
View file @
d6791c2b
...
@@ -24,11 +24,14 @@ class FixedArchitecture(Mutator):
...
@@ -24,11 +24,14 @@ class FixedArchitecture(Mutator):
Preloaded architecture object.
Preloaded architecture object.
strict : bool
strict : bool
Force everything that appears in ``fixed_arc`` to be used at least once.
Force everything that appears in ``fixed_arc`` to be used at least once.
verbose : bool
Print log messages if set to True
"""
"""
def
__init__
(
self
,
model
,
fixed_arc
,
strict
=
True
):
def
__init__
(
self
,
model
,
fixed_arc
,
strict
=
True
,
verbose
=
True
):
super
().
__init__
(
model
)
super
().
__init__
(
model
)
self
.
_fixed_arc
=
fixed_arc
self
.
_fixed_arc
=
fixed_arc
self
.
verbose
=
verbose
mutable_keys
=
set
([
mutable
.
key
for
mutable
in
self
.
mutables
if
not
isinstance
(
mutable
,
MutableScope
)])
mutable_keys
=
set
([
mutable
.
key
for
mutable
in
self
.
mutables
if
not
isinstance
(
mutable
,
MutableScope
)])
fixed_arc_keys
=
set
(
self
.
_fixed_arc
.
keys
())
fixed_arc_keys
=
set
(
self
.
_fixed_arc
.
keys
())
...
@@ -99,10 +102,11 @@ class FixedArchitecture(Mutator):
...
@@ -99,10 +102,11 @@ class FixedArchitecture(Mutator):
if
sum
(
chosen
)
==
1
and
max
(
chosen
)
==
1
and
not
mutable
.
return_mask
:
if
sum
(
chosen
)
==
1
and
max
(
chosen
)
==
1
and
not
mutable
.
return_mask
:
# sum is one, max is one, there has to be an only one
# sum is one, max is one, there has to be an only one
# this is compatible with both integer arrays, boolean arrays and float arrays
# this is compatible with both integer arrays, boolean arrays and float arrays
if
self
.
verbose
:
_logger
.
info
(
"Replacing %s with candidate number %d."
,
global_name
,
chosen
.
index
(
1
))
_logger
.
info
(
"Replacing %s with candidate number %d."
,
global_name
,
chosen
.
index
(
1
))
setattr
(
module
,
name
,
mutable
[
chosen
.
index
(
1
)])
setattr
(
module
,
name
,
mutable
[
chosen
.
index
(
1
)])
else
:
else
:
if
mutable
.
return_mask
:
if
mutable
.
return_mask
and
self
.
verbose
:
_logger
.
info
(
"`return_mask` flag of %s is true. As it relies on the behavior of LayerChoice, "
\
_logger
.
info
(
"`return_mask` flag of %s is true. As it relies on the behavior of LayerChoice, "
\
"LayerChoice will not be replaced."
)
"LayerChoice will not be replaced."
)
# remove unused parameters
# remove unused parameters
...
@@ -113,7 +117,7 @@ class FixedArchitecture(Mutator):
...
@@ -113,7 +117,7 @@ class FixedArchitecture(Mutator):
self
.
replace_layer_choice
(
mutable
,
global_name
)
self
.
replace_layer_choice
(
mutable
,
global_name
)
def
apply_fixed_architecture
(
model
,
fixed_arc
):
def
apply_fixed_architecture
(
model
,
fixed_arc
,
verbose
=
True
):
"""
"""
Load architecture from `fixed_arc` and apply to model.
Load architecture from `fixed_arc` and apply to model.
...
@@ -123,6 +127,8 @@ def apply_fixed_architecture(model, fixed_arc):
...
@@ -123,6 +127,8 @@ def apply_fixed_architecture(model, fixed_arc):
Model with mutables.
Model with mutables.
fixed_arc : str or dict
fixed_arc : str or dict
Path to the JSON that stores the architecture, or dict that stores the exported architecture.
Path to the JSON that stores the architecture, or dict that stores the exported architecture.
verbose : bool
Print log messages if set to True
Returns
Returns
-------
-------
...
@@ -133,7 +139,7 @@ def apply_fixed_architecture(model, fixed_arc):
...
@@ -133,7 +139,7 @@ def apply_fixed_architecture(model, fixed_arc):
if
isinstance
(
fixed_arc
,
str
):
if
isinstance
(
fixed_arc
,
str
):
with
open
(
fixed_arc
)
as
f
:
with
open
(
fixed_arc
)
as
f
:
fixed_arc
=
json
.
load
(
f
)
fixed_arc
=
json
.
load
(
f
)
architecture
=
FixedArchitecture
(
model
,
fixed_arc
)
architecture
=
FixedArchitecture
(
model
,
fixed_arc
,
verbose
)
architecture
.
reset
()
architecture
.
reset
()
# for the convenience of parameters counting
# for the convenience of parameters counting
...
...
setup.py
View file @
d6791c2b
...
@@ -48,6 +48,7 @@ from distutils.command.clean import clean
...
@@ -48,6 +48,7 @@ from distutils.command.clean import clean
import
glob
import
glob
import
os
import
os
import
shutil
import
shutil
import
sys
import
setuptools
import
setuptools
from
setuptools.command.develop
import
develop
from
setuptools.command.develop
import
develop
...
@@ -131,6 +132,8 @@ def _find_python_packages():
...
@@ -131,6 +132,8 @@ def _find_python_packages():
def
_find_node_files
():
def
_find_node_files
():
if
not
os
.
path
.
exists
(
'nni_node'
):
if
not
os
.
path
.
exists
(
'nni_node'
):
if
release
and
'built_ts'
not
in
sys
.
argv
:
sys
.
exit
(
'ERROR: To build a release version, run "python setup.py built_ts" first'
)
return
[]
return
[]
files
=
[]
files
=
[]
for
dirpath
,
dirnames
,
filenames
in
os
.
walk
(
'nni_node'
):
for
dirpath
,
dirnames
,
filenames
in
os
.
walk
(
'nni_node'
):
...
@@ -140,6 +143,9 @@ def _find_node_files():
...
@@ -140,6 +143,9 @@ def _find_node_files():
files
.
remove
(
'__init__.py'
)
files
.
remove
(
'__init__.py'
)
return
sorted
(
files
)
return
sorted
(
files
)
def
_using_conda_or_virtual_environment
():
return
sys
.
prefix
!=
sys
.
base_prefix
or
os
.
path
.
isdir
(
os
.
path
.
join
(
sys
.
prefix
,
'conda-meta'
))
class
BuildTs
(
Command
):
class
BuildTs
(
Command
):
description
=
'build TypeScript modules'
description
=
'build TypeScript modules'
...
@@ -163,8 +169,21 @@ class Build(build):
...
@@ -163,8 +169,21 @@ class Build(build):
super
().
run
()
super
().
run
()
class
Develop
(
develop
):
class
Develop
(
develop
):
user_options
=
develop
.
user_options
+
[
(
'no-user'
,
None
,
'Prevent automatically adding "--user"'
)
]
boolean_options
=
develop
.
boolean_options
+
[
'no-user'
]
def
initialize_options
(
self
):
super
().
initialize_options
()
self
.
no_user
=
None
def
finalize_options
(
self
):
def
finalize_options
(
self
):
self
.
user
=
True
# always use `develop --user`
# if `--user` or `--no-user` is explicitly set, do nothing
# otherwise activate `--user` if using system python
if
not
self
.
user
and
not
self
.
no_user
:
self
.
user
=
not
_using_conda_or_virtual_environment
()
super
().
finalize_options
()
super
().
finalize_options
()
def
run
(
self
):
def
run
(
self
):
...
...
setup_ts.py
View file @
d6791c2b
...
@@ -131,7 +131,7 @@ def prepare_nni_node():
...
@@ -131,7 +131,7 @@ def prepare_nni_node():
node_src
=
Path
(
'toolchain/node'
,
node_executable_in_tarball
)
node_src
=
Path
(
'toolchain/node'
,
node_executable_in_tarball
)
node_dst
=
Path
(
'nni_node'
,
node_executable
)
node_dst
=
Path
(
'nni_node'
,
node_executable
)
shutil
.
copy
file
(
node_src
,
node_dst
)
shutil
.
copy
(
node_src
,
node_dst
)
def
compile_ts
():
def
compile_ts
():
...
...
ts/nni_manager/yarn.lock
View file @
d6791c2b
...
@@ -1336,7 +1336,7 @@ debug@^3.1.0:
...
@@ -1336,7 +1336,7 @@ debug@^3.1.0:
dependencies:
dependencies:
ms "^2.1.1"
ms "^2.1.1"
debuglog@*,
debuglog@^1.0.1:
debuglog@^1.0.1:
version "1.0.1"
version "1.0.1"
resolved "https://registry.yarnpkg.com/debuglog/-/debuglog-1.0.1.tgz#aa24ffb9ac3df9a2351837cfb2d279360cd78492"
resolved "https://registry.yarnpkg.com/debuglog/-/debuglog-1.0.1.tgz#aa24ffb9ac3df9a2351837cfb2d279360cd78492"
...
@@ -2392,7 +2392,7 @@ import-lazy@^2.1.0:
...
@@ -2392,7 +2392,7 @@ import-lazy@^2.1.0:
version "2.1.0"
version "2.1.0"
resolved "https://registry.yarnpkg.com/import-lazy/-/import-lazy-2.1.0.tgz#05698e3d45c88e8d7e9d92cb0584e77f096f3e43"
resolved "https://registry.yarnpkg.com/import-lazy/-/import-lazy-2.1.0.tgz#05698e3d45c88e8d7e9d92cb0584e77f096f3e43"
imurmurhash@*,
imurmurhash@^0.1.4:
imurmurhash@^0.1.4:
version "0.1.4"
version "0.1.4"
resolved "https://registry.yarnpkg.com/imurmurhash/-/imurmurhash-0.1.4.tgz#9218b9b2b928a238b13dc4fb6b6d576f231453ea"
resolved "https://registry.yarnpkg.com/imurmurhash/-/imurmurhash-0.1.4.tgz#9218b9b2b928a238b13dc4fb6b6d576f231453ea"
integrity sha1-khi5srkoojixPcT7a21XbyMUU+o=
integrity sha1-khi5srkoojixPcT7a21XbyMUU+o=
...
@@ -3074,11 +3074,6 @@ lockfile@^1.0.4:
...
@@ -3074,11 +3074,6 @@ lockfile@^1.0.4:
dependencies:
dependencies:
signal-exit "^3.0.2"
signal-exit "^3.0.2"
lodash._baseindexof@*:
version "3.1.0"
resolved "https://registry.yarnpkg.com/lodash._baseindexof/-/lodash._baseindexof-3.1.0.tgz#fe52b53a1c6761e42618d654e4a25789ed61822c"
integrity sha1-/lK1OhxnYeQmGNZU5KJXie1hgiw=
lodash._baseuniq@~4.6.0:
lodash._baseuniq@~4.6.0:
version "4.6.0"
version "4.6.0"
resolved "https://registry.yarnpkg.com/lodash._baseuniq/-/lodash._baseuniq-4.6.0.tgz#0ebb44e456814af7905c6212fa2c9b2d51b841e8"
resolved "https://registry.yarnpkg.com/lodash._baseuniq/-/lodash._baseuniq-4.6.0.tgz#0ebb44e456814af7905c6212fa2c9b2d51b841e8"
...
@@ -3086,32 +3081,10 @@ lodash._baseuniq@~4.6.0:
...
@@ -3086,32 +3081,10 @@ lodash._baseuniq@~4.6.0:
lodash._createset "~4.0.0"
lodash._createset "~4.0.0"
lodash._root "~3.0.0"
lodash._root "~3.0.0"
lodash._bindcallback@*:
version "3.0.1"
resolved "https://registry.yarnpkg.com/lodash._bindcallback/-/lodash._bindcallback-3.0.1.tgz#e531c27644cf8b57a99e17ed95b35c748789392e"
integrity sha1-5THCdkTPi1epnhftlbNcdIeJOS4=
lodash._cacheindexof@*:
version "3.0.2"
resolved "https://registry.yarnpkg.com/lodash._cacheindexof/-/lodash._cacheindexof-3.0.2.tgz#3dc69ac82498d2ee5e3ce56091bafd2adc7bde92"
integrity sha1-PcaayCSY0u5ePOVgkbr9Ktx73pI=
lodash._createcache@*:
version "3.1.2"
resolved "https://registry.yarnpkg.com/lodash._createcache/-/lodash._createcache-3.1.2.tgz#56d6a064017625e79ebca6b8018e17440bdcf093"
integrity sha1-VtagZAF2JeeevKa4AY4XRAvc8JM=
dependencies:
lodash._getnative "^3.0.0"
lodash._createset@~4.0.0:
lodash._createset@~4.0.0:
version "4.0.3"
version "4.0.3"
resolved "https://registry.yarnpkg.com/lodash._createset/-/lodash._createset-4.0.3.tgz#0f4659fbb09d75194fa9e2b88a6644d363c9fe26"
resolved "https://registry.yarnpkg.com/lodash._createset/-/lodash._createset-4.0.3.tgz#0f4659fbb09d75194fa9e2b88a6644d363c9fe26"
lodash._getnative@*, lodash._getnative@^3.0.0:
version "3.9.1"
resolved "https://registry.yarnpkg.com/lodash._getnative/-/lodash._getnative-3.9.1.tgz#570bc7dede46d61cdcde687d65d3eecbaa3aaff5"
integrity sha1-VwvH3t5G1hzc3mh9ZdPuy6o6r/U=
lodash._root@~3.0.0:
lodash._root@~3.0.0:
version "3.0.1"
version "3.0.1"
resolved "https://registry.yarnpkg.com/lodash._root/-/lodash._root-3.0.1.tgz#fba1c4524c19ee9a5f8136b4609f017cf4ded692"
resolved "https://registry.yarnpkg.com/lodash._root/-/lodash._root-3.0.1.tgz#fba1c4524c19ee9a5f8136b4609f017cf4ded692"
...
@@ -3160,11 +3133,6 @@ lodash.pick@^4.4.0:
...
@@ -3160,11 +3133,6 @@ lodash.pick@^4.4.0:
version "4.4.0"
version "4.4.0"
resolved "https://registry.yarnpkg.com/lodash.pick/-/lodash.pick-4.4.0.tgz#52f05610fff9ded422611441ed1fc123a03001b3"
resolved "https://registry.yarnpkg.com/lodash.pick/-/lodash.pick-4.4.0.tgz#52f05610fff9ded422611441ed1fc123a03001b3"
lodash.restparam@*:
version "3.6.1"
resolved "https://registry.yarnpkg.com/lodash.restparam/-/lodash.restparam-3.6.1.tgz#936a4e309ef330a7645ed4145986c85ae5b20805"
integrity sha1-k2pOMJ7zMKdkXtQUWYbIWuWyCAU=
lodash.unescape@4.0.1:
lodash.unescape@4.0.1:
version "4.0.1"
version "4.0.1"
resolved "https://registry.yarnpkg.com/lodash.unescape/-/lodash.unescape-4.0.1.tgz#bf2249886ce514cda112fae9218cdc065211fc9c"
resolved "https://registry.yarnpkg.com/lodash.unescape/-/lodash.unescape-4.0.1.tgz#bf2249886ce514cda112fae9218cdc065211fc9c"
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment