Commit d6791c2b authored by Yuge Zhang's avatar Yuge Zhang
Browse files

Merge branch 'master' into dev-retiarii

parents 19726d4d 16dc45b1
......@@ -63,6 +63,7 @@ class SPOSSupernetTrainer(Trainer):
self.model.train()
meters = AverageMeterGroup()
for step, (x, y) in enumerate(self.train_loader):
x, y = x.to(self.device), y.to(self.device)
self.optimizer.zero_grad()
self.mutator.reset()
logits = self.model(x)
......@@ -82,6 +83,7 @@ class SPOSSupernetTrainer(Trainer):
meters = AverageMeterGroup()
with torch.no_grad():
for step, (x, y) in enumerate(self.valid_loader):
x, y = x.to(self.device), y.to(self.device)
self.mutator.reset()
logits = self.model(x)
loss = self.loss(logits, y)
......
......@@ -24,11 +24,14 @@ class FixedArchitecture(Mutator):
Preloaded architecture object.
strict : bool
Force everything that appears in ``fixed_arc`` to be used at least once.
verbose : bool
Print log messages if set to True
"""
def __init__(self, model, fixed_arc, strict=True):
def __init__(self, model, fixed_arc, strict=True, verbose=True):
super().__init__(model)
self._fixed_arc = fixed_arc
self.verbose = verbose
mutable_keys = set([mutable.key for mutable in self.mutables if not isinstance(mutable, MutableScope)])
fixed_arc_keys = set(self._fixed_arc.keys())
......@@ -99,10 +102,11 @@ class FixedArchitecture(Mutator):
if sum(chosen) == 1 and max(chosen) == 1 and not mutable.return_mask:
# sum is one, max is one, there has to be an only one
# this is compatible with both integer arrays, boolean arrays and float arrays
_logger.info("Replacing %s with candidate number %d.", global_name, chosen.index(1))
if self.verbose:
_logger.info("Replacing %s with candidate number %d.", global_name, chosen.index(1))
setattr(module, name, mutable[chosen.index(1)])
else:
if mutable.return_mask:
if mutable.return_mask and self.verbose:
_logger.info("`return_mask` flag of %s is true. As it relies on the behavior of LayerChoice, " \
"LayerChoice will not be replaced.")
# remove unused parameters
......@@ -113,7 +117,7 @@ class FixedArchitecture(Mutator):
self.replace_layer_choice(mutable, global_name)
def apply_fixed_architecture(model, fixed_arc):
def apply_fixed_architecture(model, fixed_arc, verbose=True):
"""
Load architecture from `fixed_arc` and apply to model.
......@@ -123,6 +127,8 @@ def apply_fixed_architecture(model, fixed_arc):
Model with mutables.
fixed_arc : str or dict
Path to the JSON that stores the architecture, or dict that stores the exported architecture.
verbose : bool
Print log messages if set to True
Returns
-------
......@@ -133,7 +139,7 @@ def apply_fixed_architecture(model, fixed_arc):
if isinstance(fixed_arc, str):
with open(fixed_arc) as f:
fixed_arc = json.load(f)
architecture = FixedArchitecture(model, fixed_arc)
architecture = FixedArchitecture(model, fixed_arc, verbose)
architecture.reset()
# for the convenience of parameters counting
......
......@@ -48,6 +48,7 @@ from distutils.command.clean import clean
import glob
import os
import shutil
import sys
import setuptools
from setuptools.command.develop import develop
......@@ -131,6 +132,8 @@ def _find_python_packages():
def _find_node_files():
if not os.path.exists('nni_node'):
if release and 'built_ts' not in sys.argv:
sys.exit('ERROR: To build a release version, run "python setup.py built_ts" first')
return []
files = []
for dirpath, dirnames, filenames in os.walk('nni_node'):
......@@ -140,6 +143,9 @@ def _find_node_files():
files.remove('__init__.py')
return sorted(files)
def _using_conda_or_virtual_environment():
return sys.prefix != sys.base_prefix or os.path.isdir(os.path.join(sys.prefix, 'conda-meta'))
class BuildTs(Command):
description = 'build TypeScript modules'
......@@ -163,8 +169,21 @@ class Build(build):
super().run()
class Develop(develop):
user_options = develop.user_options + [
('no-user', None, 'Prevent automatically adding "--user"')
]
boolean_options = develop.boolean_options + ['no-user']
def initialize_options(self):
super().initialize_options()
self.no_user = None
def finalize_options(self):
self.user = True # always use `develop --user`
# if `--user` or `--no-user` is explicitly set, do nothing
# otherwise activate `--user` if using system python
if not self.user and not self.no_user:
self.user = not _using_conda_or_virtual_environment()
super().finalize_options()
def run(self):
......
......@@ -131,7 +131,7 @@ def prepare_nni_node():
node_src = Path('toolchain/node', node_executable_in_tarball)
node_dst = Path('nni_node', node_executable)
shutil.copyfile(node_src, node_dst)
shutil.copy(node_src, node_dst)
def compile_ts():
......
......@@ -1336,7 +1336,7 @@ debug@^3.1.0:
dependencies:
ms "^2.1.1"
debuglog@*, debuglog@^1.0.1:
debuglog@^1.0.1:
version "1.0.1"
resolved "https://registry.yarnpkg.com/debuglog/-/debuglog-1.0.1.tgz#aa24ffb9ac3df9a2351837cfb2d279360cd78492"
......@@ -2392,7 +2392,7 @@ import-lazy@^2.1.0:
version "2.1.0"
resolved "https://registry.yarnpkg.com/import-lazy/-/import-lazy-2.1.0.tgz#05698e3d45c88e8d7e9d92cb0584e77f096f3e43"
imurmurhash@*, imurmurhash@^0.1.4:
imurmurhash@^0.1.4:
version "0.1.4"
resolved "https://registry.yarnpkg.com/imurmurhash/-/imurmurhash-0.1.4.tgz#9218b9b2b928a238b13dc4fb6b6d576f231453ea"
integrity sha1-khi5srkoojixPcT7a21XbyMUU+o=
......@@ -3074,11 +3074,6 @@ lockfile@^1.0.4:
dependencies:
signal-exit "^3.0.2"
lodash._baseindexof@*:
version "3.1.0"
resolved "https://registry.yarnpkg.com/lodash._baseindexof/-/lodash._baseindexof-3.1.0.tgz#fe52b53a1c6761e42618d654e4a25789ed61822c"
integrity sha1-/lK1OhxnYeQmGNZU5KJXie1hgiw=
lodash._baseuniq@~4.6.0:
version "4.6.0"
resolved "https://registry.yarnpkg.com/lodash._baseuniq/-/lodash._baseuniq-4.6.0.tgz#0ebb44e456814af7905c6212fa2c9b2d51b841e8"
......@@ -3086,32 +3081,10 @@ lodash._baseuniq@~4.6.0:
lodash._createset "~4.0.0"
lodash._root "~3.0.0"
lodash._bindcallback@*:
version "3.0.1"
resolved "https://registry.yarnpkg.com/lodash._bindcallback/-/lodash._bindcallback-3.0.1.tgz#e531c27644cf8b57a99e17ed95b35c748789392e"
integrity sha1-5THCdkTPi1epnhftlbNcdIeJOS4=
lodash._cacheindexof@*:
version "3.0.2"
resolved "https://registry.yarnpkg.com/lodash._cacheindexof/-/lodash._cacheindexof-3.0.2.tgz#3dc69ac82498d2ee5e3ce56091bafd2adc7bde92"
integrity sha1-PcaayCSY0u5ePOVgkbr9Ktx73pI=
lodash._createcache@*:
version "3.1.2"
resolved "https://registry.yarnpkg.com/lodash._createcache/-/lodash._createcache-3.1.2.tgz#56d6a064017625e79ebca6b8018e17440bdcf093"
integrity sha1-VtagZAF2JeeevKa4AY4XRAvc8JM=
dependencies:
lodash._getnative "^3.0.0"
lodash._createset@~4.0.0:
version "4.0.3"
resolved "https://registry.yarnpkg.com/lodash._createset/-/lodash._createset-4.0.3.tgz#0f4659fbb09d75194fa9e2b88a6644d363c9fe26"
lodash._getnative@*, lodash._getnative@^3.0.0:
version "3.9.1"
resolved "https://registry.yarnpkg.com/lodash._getnative/-/lodash._getnative-3.9.1.tgz#570bc7dede46d61cdcde687d65d3eecbaa3aaff5"
integrity sha1-VwvH3t5G1hzc3mh9ZdPuy6o6r/U=
lodash._root@~3.0.0:
version "3.0.1"
resolved "https://registry.yarnpkg.com/lodash._root/-/lodash._root-3.0.1.tgz#fba1c4524c19ee9a5f8136b4609f017cf4ded692"
......@@ -3160,11 +3133,6 @@ lodash.pick@^4.4.0:
version "4.4.0"
resolved "https://registry.yarnpkg.com/lodash.pick/-/lodash.pick-4.4.0.tgz#52f05610fff9ded422611441ed1fc123a03001b3"
lodash.restparam@*:
version "3.6.1"
resolved "https://registry.yarnpkg.com/lodash.restparam/-/lodash.restparam-3.6.1.tgz#936a4e309ef330a7645ed4145986c85ae5b20805"
integrity sha1-k2pOMJ7zMKdkXtQUWYbIWuWyCAU=
lodash.unescape@4.0.1:
version "4.0.1"
resolved "https://registry.yarnpkg.com/lodash.unescape/-/lodash.unescape-4.0.1.tgz#bf2249886ce514cda112fae9218cdc065211fc9c"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment