Commit d6791c2b authored by Yuge Zhang's avatar Yuge Zhang
Browse files

Merge branch 'master' into dev-retiarii

parents 19726d4d 16dc45b1
...@@ -63,6 +63,7 @@ class SPOSSupernetTrainer(Trainer): ...@@ -63,6 +63,7 @@ class SPOSSupernetTrainer(Trainer):
self.model.train() self.model.train()
meters = AverageMeterGroup() meters = AverageMeterGroup()
for step, (x, y) in enumerate(self.train_loader): for step, (x, y) in enumerate(self.train_loader):
x, y = x.to(self.device), y.to(self.device)
self.optimizer.zero_grad() self.optimizer.zero_grad()
self.mutator.reset() self.mutator.reset()
logits = self.model(x) logits = self.model(x)
...@@ -82,6 +83,7 @@ class SPOSSupernetTrainer(Trainer): ...@@ -82,6 +83,7 @@ class SPOSSupernetTrainer(Trainer):
meters = AverageMeterGroup() meters = AverageMeterGroup()
with torch.no_grad(): with torch.no_grad():
for step, (x, y) in enumerate(self.valid_loader): for step, (x, y) in enumerate(self.valid_loader):
x, y = x.to(self.device), y.to(self.device)
self.mutator.reset() self.mutator.reset()
logits = self.model(x) logits = self.model(x)
loss = self.loss(logits, y) loss = self.loss(logits, y)
......
...@@ -24,11 +24,14 @@ class FixedArchitecture(Mutator): ...@@ -24,11 +24,14 @@ class FixedArchitecture(Mutator):
Preloaded architecture object. Preloaded architecture object.
strict : bool strict : bool
Force everything that appears in ``fixed_arc`` to be used at least once. Force everything that appears in ``fixed_arc`` to be used at least once.
verbose : bool
Print log messages if set to True
""" """
def __init__(self, model, fixed_arc, strict=True): def __init__(self, model, fixed_arc, strict=True, verbose=True):
super().__init__(model) super().__init__(model)
self._fixed_arc = fixed_arc self._fixed_arc = fixed_arc
self.verbose = verbose
mutable_keys = set([mutable.key for mutable in self.mutables if not isinstance(mutable, MutableScope)]) mutable_keys = set([mutable.key for mutable in self.mutables if not isinstance(mutable, MutableScope)])
fixed_arc_keys = set(self._fixed_arc.keys()) fixed_arc_keys = set(self._fixed_arc.keys())
...@@ -99,10 +102,11 @@ class FixedArchitecture(Mutator): ...@@ -99,10 +102,11 @@ class FixedArchitecture(Mutator):
if sum(chosen) == 1 and max(chosen) == 1 and not mutable.return_mask: if sum(chosen) == 1 and max(chosen) == 1 and not mutable.return_mask:
# sum is one, max is one, there has to be an only one # sum is one, max is one, there has to be an only one
# this is compatible with both integer arrays, boolean arrays and float arrays # this is compatible with both integer arrays, boolean arrays and float arrays
_logger.info("Replacing %s with candidate number %d.", global_name, chosen.index(1)) if self.verbose:
_logger.info("Replacing %s with candidate number %d.", global_name, chosen.index(1))
setattr(module, name, mutable[chosen.index(1)]) setattr(module, name, mutable[chosen.index(1)])
else: else:
if mutable.return_mask: if mutable.return_mask and self.verbose:
_logger.info("`return_mask` flag of %s is true. As it relies on the behavior of LayerChoice, " \ _logger.info("`return_mask` flag of %s is true. As it relies on the behavior of LayerChoice, " \
"LayerChoice will not be replaced.") "LayerChoice will not be replaced.")
# remove unused parameters # remove unused parameters
...@@ -113,7 +117,7 @@ class FixedArchitecture(Mutator): ...@@ -113,7 +117,7 @@ class FixedArchitecture(Mutator):
self.replace_layer_choice(mutable, global_name) self.replace_layer_choice(mutable, global_name)
def apply_fixed_architecture(model, fixed_arc): def apply_fixed_architecture(model, fixed_arc, verbose=True):
""" """
Load architecture from `fixed_arc` and apply to model. Load architecture from `fixed_arc` and apply to model.
...@@ -123,6 +127,8 @@ def apply_fixed_architecture(model, fixed_arc): ...@@ -123,6 +127,8 @@ def apply_fixed_architecture(model, fixed_arc):
Model with mutables. Model with mutables.
fixed_arc : str or dict fixed_arc : str or dict
Path to the JSON that stores the architecture, or dict that stores the exported architecture. Path to the JSON that stores the architecture, or dict that stores the exported architecture.
verbose : bool
Print log messages if set to True
Returns Returns
------- -------
...@@ -133,7 +139,7 @@ def apply_fixed_architecture(model, fixed_arc): ...@@ -133,7 +139,7 @@ def apply_fixed_architecture(model, fixed_arc):
if isinstance(fixed_arc, str): if isinstance(fixed_arc, str):
with open(fixed_arc) as f: with open(fixed_arc) as f:
fixed_arc = json.load(f) fixed_arc = json.load(f)
architecture = FixedArchitecture(model, fixed_arc) architecture = FixedArchitecture(model, fixed_arc, verbose)
architecture.reset() architecture.reset()
# for the convenience of parameters counting # for the convenience of parameters counting
......
...@@ -48,6 +48,7 @@ from distutils.command.clean import clean ...@@ -48,6 +48,7 @@ from distutils.command.clean import clean
import glob import glob
import os import os
import shutil import shutil
import sys
import setuptools import setuptools
from setuptools.command.develop import develop from setuptools.command.develop import develop
...@@ -131,6 +132,8 @@ def _find_python_packages(): ...@@ -131,6 +132,8 @@ def _find_python_packages():
def _find_node_files(): def _find_node_files():
if not os.path.exists('nni_node'): if not os.path.exists('nni_node'):
if release and 'built_ts' not in sys.argv:
sys.exit('ERROR: To build a release version, run "python setup.py built_ts" first')
return [] return []
files = [] files = []
for dirpath, dirnames, filenames in os.walk('nni_node'): for dirpath, dirnames, filenames in os.walk('nni_node'):
...@@ -140,6 +143,9 @@ def _find_node_files(): ...@@ -140,6 +143,9 @@ def _find_node_files():
files.remove('__init__.py') files.remove('__init__.py')
return sorted(files) return sorted(files)
def _using_conda_or_virtual_environment():
return sys.prefix != sys.base_prefix or os.path.isdir(os.path.join(sys.prefix, 'conda-meta'))
class BuildTs(Command): class BuildTs(Command):
description = 'build TypeScript modules' description = 'build TypeScript modules'
...@@ -163,8 +169,21 @@ class Build(build): ...@@ -163,8 +169,21 @@ class Build(build):
super().run() super().run()
class Develop(develop): class Develop(develop):
user_options = develop.user_options + [
('no-user', None, 'Prevent automatically adding "--user"')
]
boolean_options = develop.boolean_options + ['no-user']
def initialize_options(self):
super().initialize_options()
self.no_user = None
def finalize_options(self): def finalize_options(self):
self.user = True # always use `develop --user` # if `--user` or `--no-user` is explicitly set, do nothing
# otherwise activate `--user` if using system python
if not self.user and not self.no_user:
self.user = not _using_conda_or_virtual_environment()
super().finalize_options() super().finalize_options()
def run(self): def run(self):
......
...@@ -131,7 +131,7 @@ def prepare_nni_node(): ...@@ -131,7 +131,7 @@ def prepare_nni_node():
node_src = Path('toolchain/node', node_executable_in_tarball) node_src = Path('toolchain/node', node_executable_in_tarball)
node_dst = Path('nni_node', node_executable) node_dst = Path('nni_node', node_executable)
shutil.copyfile(node_src, node_dst) shutil.copy(node_src, node_dst)
def compile_ts(): def compile_ts():
......
...@@ -1336,7 +1336,7 @@ debug@^3.1.0: ...@@ -1336,7 +1336,7 @@ debug@^3.1.0:
dependencies: dependencies:
ms "^2.1.1" ms "^2.1.1"
debuglog@*, debuglog@^1.0.1: debuglog@^1.0.1:
version "1.0.1" version "1.0.1"
resolved "https://registry.yarnpkg.com/debuglog/-/debuglog-1.0.1.tgz#aa24ffb9ac3df9a2351837cfb2d279360cd78492" resolved "https://registry.yarnpkg.com/debuglog/-/debuglog-1.0.1.tgz#aa24ffb9ac3df9a2351837cfb2d279360cd78492"
...@@ -2392,7 +2392,7 @@ import-lazy@^2.1.0: ...@@ -2392,7 +2392,7 @@ import-lazy@^2.1.0:
version "2.1.0" version "2.1.0"
resolved "https://registry.yarnpkg.com/import-lazy/-/import-lazy-2.1.0.tgz#05698e3d45c88e8d7e9d92cb0584e77f096f3e43" resolved "https://registry.yarnpkg.com/import-lazy/-/import-lazy-2.1.0.tgz#05698e3d45c88e8d7e9d92cb0584e77f096f3e43"
imurmurhash@*, imurmurhash@^0.1.4: imurmurhash@^0.1.4:
version "0.1.4" version "0.1.4"
resolved "https://registry.yarnpkg.com/imurmurhash/-/imurmurhash-0.1.4.tgz#9218b9b2b928a238b13dc4fb6b6d576f231453ea" resolved "https://registry.yarnpkg.com/imurmurhash/-/imurmurhash-0.1.4.tgz#9218b9b2b928a238b13dc4fb6b6d576f231453ea"
integrity sha1-khi5srkoojixPcT7a21XbyMUU+o= integrity sha1-khi5srkoojixPcT7a21XbyMUU+o=
...@@ -3074,11 +3074,6 @@ lockfile@^1.0.4: ...@@ -3074,11 +3074,6 @@ lockfile@^1.0.4:
dependencies: dependencies:
signal-exit "^3.0.2" signal-exit "^3.0.2"
lodash._baseindexof@*:
version "3.1.0"
resolved "https://registry.yarnpkg.com/lodash._baseindexof/-/lodash._baseindexof-3.1.0.tgz#fe52b53a1c6761e42618d654e4a25789ed61822c"
integrity sha1-/lK1OhxnYeQmGNZU5KJXie1hgiw=
lodash._baseuniq@~4.6.0: lodash._baseuniq@~4.6.0:
version "4.6.0" version "4.6.0"
resolved "https://registry.yarnpkg.com/lodash._baseuniq/-/lodash._baseuniq-4.6.0.tgz#0ebb44e456814af7905c6212fa2c9b2d51b841e8" resolved "https://registry.yarnpkg.com/lodash._baseuniq/-/lodash._baseuniq-4.6.0.tgz#0ebb44e456814af7905c6212fa2c9b2d51b841e8"
...@@ -3086,32 +3081,10 @@ lodash._baseuniq@~4.6.0: ...@@ -3086,32 +3081,10 @@ lodash._baseuniq@~4.6.0:
lodash._createset "~4.0.0" lodash._createset "~4.0.0"
lodash._root "~3.0.0" lodash._root "~3.0.0"
lodash._bindcallback@*:
version "3.0.1"
resolved "https://registry.yarnpkg.com/lodash._bindcallback/-/lodash._bindcallback-3.0.1.tgz#e531c27644cf8b57a99e17ed95b35c748789392e"
integrity sha1-5THCdkTPi1epnhftlbNcdIeJOS4=
lodash._cacheindexof@*:
version "3.0.2"
resolved "https://registry.yarnpkg.com/lodash._cacheindexof/-/lodash._cacheindexof-3.0.2.tgz#3dc69ac82498d2ee5e3ce56091bafd2adc7bde92"
integrity sha1-PcaayCSY0u5ePOVgkbr9Ktx73pI=
lodash._createcache@*:
version "3.1.2"
resolved "https://registry.yarnpkg.com/lodash._createcache/-/lodash._createcache-3.1.2.tgz#56d6a064017625e79ebca6b8018e17440bdcf093"
integrity sha1-VtagZAF2JeeevKa4AY4XRAvc8JM=
dependencies:
lodash._getnative "^3.0.0"
lodash._createset@~4.0.0: lodash._createset@~4.0.0:
version "4.0.3" version "4.0.3"
resolved "https://registry.yarnpkg.com/lodash._createset/-/lodash._createset-4.0.3.tgz#0f4659fbb09d75194fa9e2b88a6644d363c9fe26" resolved "https://registry.yarnpkg.com/lodash._createset/-/lodash._createset-4.0.3.tgz#0f4659fbb09d75194fa9e2b88a6644d363c9fe26"
lodash._getnative@*, lodash._getnative@^3.0.0:
version "3.9.1"
resolved "https://registry.yarnpkg.com/lodash._getnative/-/lodash._getnative-3.9.1.tgz#570bc7dede46d61cdcde687d65d3eecbaa3aaff5"
integrity sha1-VwvH3t5G1hzc3mh9ZdPuy6o6r/U=
lodash._root@~3.0.0: lodash._root@~3.0.0:
version "3.0.1" version "3.0.1"
resolved "https://registry.yarnpkg.com/lodash._root/-/lodash._root-3.0.1.tgz#fba1c4524c19ee9a5f8136b4609f017cf4ded692" resolved "https://registry.yarnpkg.com/lodash._root/-/lodash._root-3.0.1.tgz#fba1c4524c19ee9a5f8136b4609f017cf4ded692"
...@@ -3160,11 +3133,6 @@ lodash.pick@^4.4.0: ...@@ -3160,11 +3133,6 @@ lodash.pick@^4.4.0:
version "4.4.0" version "4.4.0"
resolved "https://registry.yarnpkg.com/lodash.pick/-/lodash.pick-4.4.0.tgz#52f05610fff9ded422611441ed1fc123a03001b3" resolved "https://registry.yarnpkg.com/lodash.pick/-/lodash.pick-4.4.0.tgz#52f05610fff9ded422611441ed1fc123a03001b3"
lodash.restparam@*:
version "3.6.1"
resolved "https://registry.yarnpkg.com/lodash.restparam/-/lodash.restparam-3.6.1.tgz#936a4e309ef330a7645ed4145986c85ae5b20805"
integrity sha1-k2pOMJ7zMKdkXtQUWYbIWuWyCAU=
lodash.unescape@4.0.1: lodash.unescape@4.0.1:
version "4.0.1" version "4.0.1"
resolved "https://registry.yarnpkg.com/lodash.unescape/-/lodash.unescape-4.0.1.tgz#bf2249886ce514cda112fae9218cdc065211fc9c" resolved "https://registry.yarnpkg.com/lodash.unescape/-/lodash.unescape-4.0.1.tgz#bf2249886ce514cda112fae9218cdc065211fc9c"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment