"vscode:/vscode.git/clone" did not exist on "e8deff5206fc0adb94b8203c726b0a6b3b84ffe0"
Unverified commit a0fd0036 authored by Yuge Zhang, committed by GitHub

Merge pull request #5036 from microsoft/promote-retiarii-to-nas

[DO NOT SQUASH] Promote retiarii to NAS
parents d6dcb483 bc6d8796
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import copy
import itertools
import logging
import random
import time
from typing import Any, Dict, List, Sequence, Optional
from nni.nas.execution import submit_models, query_available_resources, budget_exhausted
from nni.nas.mutable import InvalidMutation, Sampler
from .base import BaseStrategy
from .utils import dry_run_for_search_space, get_targeted_model, filter_model
_logger = logging.getLogger(__name__)
def grid_generator(search_space: Dict[Any, List[Any]], shuffle=True):
keys = list(search_space.keys())
search_space_values = copy.deepcopy(list(search_space.values()))
if shuffle:
for values in search_space_values:
random.shuffle(values)
for values in itertools.product(*search_space_values):
yield {key: value for key, value in zip(keys, values)}
def random_generator(search_space: Dict[Any, List[Any]], dedup=True, retries=500):
keys = list(search_space.keys())
history = set()
search_space_values = copy.deepcopy(list(search_space.values()))
while True:
selected: Optional[Sequence[int]] = None
for retry_count in range(retries):
selected = [random.choice(v) for v in search_space_values]
if not dedup:
break
selected = tuple(selected)
if selected not in history:
history.add(selected)
break
if retry_count + 1 == retries:
_logger.debug('Random generation has run out of patience. There is nothing to search. Exiting.')
return
assert selected is not None, 'Retry attempts exhausted.'
yield {key: value for key, value in zip(keys, selected)}
class GridSearch(BaseStrategy):
"""
Traverse the search space and try all the possible combinations one by one.
Parameters
----------
shuffle : bool
Shuffle the order of each candidate list, so that candidates are tried in a random order. Default: true.
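Examples
--------
A minimal sketch of how the underlying ``grid_generator`` enumerates a toy search space
(the keys here are illustrative; real keys are ``(mutator, index)`` pairs produced by
``dry_run_for_search_space``):

.. code-block:: python

    space = {'kernel_size': [3, 5, 7], 'activation': ['relu', 'gelu']}
    for sample in grid_generator(space, shuffle=False):
        print(sample)
    # {'kernel_size': 3, 'activation': 'relu'}
    # {'kernel_size': 3, 'activation': 'gelu'}
    # ... 6 combinations in total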
"""
def __init__(self, shuffle=True):
self._polling_interval = 2.
self.shuffle = shuffle
def run(self, base_model, applied_mutators):
search_space = dry_run_for_search_space(base_model, applied_mutators)
for sample in grid_generator(search_space, shuffle=self.shuffle):
_logger.debug('New model created. Waiting for resource. %s', str(sample))
while query_available_resources() <= 0:
if budget_exhausted():
return
time.sleep(self._polling_interval)
submit_models(get_targeted_model(base_model, applied_mutators, sample))
class _RandomSampler(Sampler):
def choice(self, candidates, mutator, model, index):
return random.choice(candidates)
class Random(BaseStrategy):
"""
Random search on the search space.
Parameters
----------
variational : bool
Do not dry run to get the full search space. Used when the search space has variational size or candidates. Default: false.
dedup : bool
Do not try the same configuration twice. When variational is true, deduplication is not supported. Default: true.
model_filter: Callable[[Model], bool]
Takes a model and returns a bool. Used to filter models in the search space and decide which ones to submit.
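Examples
--------
A minimal sketch of the underlying ``random_generator`` on a toy search space
(illustrative keys; with ``dedup=True`` the same sample is never yielded twice):

.. code-block:: python

    space = {'kernel_size': [3, 5, 7], 'activation': ['relu', 'gelu']}
    gen = random_generator(space, dedup=True)
    print(next(gen))  # e.g. {'kernel_size': 5, 'activation': 'relu'}
    print(next(gen))  # a different combination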
"""
def __init__(self, variational=False, dedup=True, model_filter=None):
self.variational = variational
self.dedup = dedup
if variational and dedup:
raise ValueError('Dedup is not supported in variational mode.')
self.random_sampler = _RandomSampler()
self._polling_interval = 2.
self.filter = model_filter
def run(self, base_model, applied_mutators):
if self.variational:
_logger.info('Random search running in variational mode.')
sampler = _RandomSampler()
for mutator in applied_mutators:
mutator.bind_sampler(sampler)
while True:
avail_resource = query_available_resources()
if avail_resource > 0:
model = base_model
for mutator in applied_mutators:
model = mutator.apply(model)
_logger.debug('New model created. Applied mutators are: %s', str(applied_mutators))
if filter_model(self.filter, model):
submit_models(model)
elif budget_exhausted():
break
else:
time.sleep(self._polling_interval)
else:
_logger.info('Random search running in fixed size mode. Dedup: %s.', 'on' if self.dedup else 'off')
search_space = dry_run_for_search_space(base_model, applied_mutators)
for sample in random_generator(search_space, dedup=self.dedup):
_logger.debug('New model created. Waiting for resource. %s', str(sample))
while query_available_resources() <= 0:
if budget_exhausted():
return
time.sleep(self._polling_interval)
_logger.debug('Still waiting for resource.')
try:
model = get_targeted_model(base_model, applied_mutators, sample)
if filter_model(self.filter, model):
_logger.debug('Submitting model: %s', model)
submit_models(model)
except InvalidMutation as e:
_logger.warning(f'Invalid mutation: {e}. Skip.')
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import logging
import os
import random
import string
from nni.nas import Sampler, utils
from nni.nas.execution.pytorch import codegen
from nni.nas.execution.pytorch.graph import BaseGraphData
from nni.nas.execution.common import get_mutation_summary
from .base import BaseStrategy
_logger = logging.getLogger(__name__)
class ChooseFirstSampler(Sampler):
def choice(self, candidates, mutator, model, index):
return candidates[0]
class _LocalDebugStrategy(BaseStrategy):
"""
This class is supposed to be used internally, for debugging trial mutation
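A minimal sketch of the intended usage (assuming a graph-based ``base_model`` and its
``applied_mutators`` have already been constructed):

.. code-block:: python

    strategy = _LocalDebugStrategy()
    # mutates with the first choice of every mutator, then runs the generated model in-process
    strategy.run(base_model, applied_mutators)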
"""
def run_one_model(self, model):
mutation_summary = get_mutation_summary(model)
graph_data = BaseGraphData(codegen.pytorch.model_to_pytorch_script(model), model.evaluator, mutation_summary) # type: ignore
random_str = ''.join(random.choice(string.ascii_uppercase + string.digits) for _ in range(6))
file_name = f'_generated_model/{random_str}.py'
os.makedirs(os.path.dirname(file_name), exist_ok=True)
with open(file_name, 'w') as f:
f.write(graph_data.model_script)
model_cls = utils.import_(f'_generated_model.{random_str}._model')
graph_data.evaluator._execute(model_cls)
os.remove(file_name)
def run(self, base_model, applied_mutators):
_logger.info('local debug strategy has been started.')
model = base_model
_logger.debug('New model created. Applied mutators: %s', str(applied_mutators))
choose_first_sampler = ChooseFirstSampler()
for mutator in applied_mutators:
mutator.bind_sampler(choose_first_sampler)
model = mutator.apply(model)
# directly run models
self.run_one_model(model)
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import collections
import dataclasses
import logging
import random
import time
from nni.nas.execution import query_available_resources, submit_models
from nni.nas.execution.common import ModelStatus
from .base import BaseStrategy
from .utils import dry_run_for_search_space, get_targeted_model, filter_model
_logger = logging.getLogger(__name__)
@dataclasses.dataclass
class Individual:
"""
A class that represents an individual.
Holds two attributes, where ``x`` is the model and ``y`` is the metric (e.g., accuracy).
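For example, an individual might look like the following (illustrative values):

.. code-block:: python

    ind = Individual(x={'kernel_size': 3}, y=0.93)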
"""
x: dict
y: float
class RegularizedEvolution(BaseStrategy):
"""
Algorithm for regularized evolution (i.e. aging evolution).
Follows "Algorithm 1" in Real et al. "Regularized Evolution for Image Classifier Architecture Search".
Parameters
----------
optimize_mode : str
Can be one of "maximize" and "minimize". Default: maximize.
population_size : int
The number of individuals to keep in the population. Default: 100.
cycles : int
The number of cycles (trials) the algorithm should run for. Default: 20000.
sample_size : int
The number of individuals that should participate in each tournament. Default: 25.
mutation_prob : float
Probability that mutation happens in each dim. Default: 0.05
on_failure : str
Can be one of "ignore" and "worst". If "ignore", simply give up the model and find a new one.
If "worst", mark the model as -inf (if maximize, inf if minimize), so that the algorithm "learns" to avoid such model.
Default: ignore.
model_filter: Callable[[Model], bool]
Takes a model and returns a bool. Used to filter models in the search space and decide which ones to submit.
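Examples
--------
The core loop can be sketched in plain Python, without NNI's trial submission
(the toy objective and all names below are illustrative):

.. code-block:: python

    import collections, random

    space = {f'x{i}': [0, 1, 2] for i in range(8)}
    metric = lambda cfg: sum(v == 2 for v in cfg.values())  # toy metric to maximize

    population = collections.deque()
    while len(population) < 20:  # population_size
        cfg = {k: random.choice(v) for k, v in space.items()}
        population.append((cfg, metric(cfg)))
    for _ in range(200):  # cycles
        # tournament selection: best of a random sample
        parent = max(random.sample(list(population), 5), key=lambda s: s[1])[0]
        # mutate each dimension independently with probability mutation_prob
        child = {k: random.choice(space[k]) if random.random() < 0.05 else v
                 for k, v in parent.items()}
        population.append((child, metric(child)))
        population.popleft()  # aging: always discard the oldest individual
    print(max(population, key=lambda s: s[1]))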
"""
def __init__(self, optimize_mode='maximize', population_size=100, sample_size=25, cycles=20000,
mutation_prob=0.05, on_failure='ignore', model_filter=None):
assert optimize_mode in ['maximize', 'minimize']
assert on_failure in ['ignore', 'worst']
assert sample_size < population_size
self.optimize_mode = optimize_mode
self.population_size = population_size
self.sample_size = sample_size
self.cycles = cycles
self.mutation_prob = mutation_prob
self.on_failure = on_failure
self._worst = float('-inf') if self.optimize_mode == 'maximize' else float('inf')
self._success_count = 0
self._population = collections.deque()
self._running_models = []
self._polling_interval = 2.
self.filter = model_filter
def random(self, search_space):
return {k: random.choice(v) for k, v in search_space.items()}
def mutate(self, parent, search_space):
child = {}
for k, v in parent.items():
if random.uniform(0, 1) < self.mutation_prob:
# NOTE: we do not exclude the original choice here for simplicity,
# which is slightly different from the original paper.
child[k] = random.choice(search_space[k])
else:
child[k] = v
return child
def best_parent(self):
samples = [p for p in self._population] # copy population
random.shuffle(samples)
samples = list(samples)[:self.sample_size]
if self.optimize_mode == 'maximize':
parent = max(samples, key=lambda sample: sample.y)
else:
parent = min(samples, key=lambda sample: sample.y)
return parent.x
def run(self, base_model, applied_mutators):
search_space = dry_run_for_search_space(base_model, applied_mutators)
# Run the first population regardless of concurrency
_logger.info('Initializing the first population.')
while len(self._population) + len(self._running_models) <= self.population_size:
# try to submit new models
while len(self._population) + len(self._running_models) < self.population_size:
config = self.random(search_space)
self._submit_config(config, base_model, applied_mutators)
# collect results
self._move_succeeded_models_to_population()
self._remove_failed_models_from_running_list()
time.sleep(self._polling_interval)
if len(self._population) >= self.population_size:
break
# Resource-aware mutation of models
_logger.info('Running mutations.')
while self._success_count + len(self._running_models) <= self.cycles:
# try to submit new models
while query_available_resources() > 0 and self._success_count + len(self._running_models) < self.cycles:
config = self.mutate(self.best_parent(), search_space)
self._submit_config(config, base_model, applied_mutators)
# collect results
self._move_succeeded_models_to_population()
self._remove_failed_models_from_running_list()
time.sleep(self._polling_interval)
if self._success_count >= self.cycles:
break
def _submit_config(self, config, base_model, mutators):
_logger.debug('Model submitted to running queue: %s', config)
model = get_targeted_model(base_model, mutators, config)
if not filter_model(self.filter, model):
if self.on_failure == "worst":
model.status = ModelStatus.Failed
self._running_models.append((config, model))
else:
submit_models(model)
self._running_models.append((config, model))
return model
def _move_succeeded_models_to_population(self):
completed_indices = []
for i, (config, model) in enumerate(self._running_models):
metric = None
if self.on_failure == 'worst' and model.status == ModelStatus.Failed:
metric = self._worst
elif model.status == ModelStatus.Trained:
metric = model.metric
if metric is not None:
individual = Individual(config, metric)
_logger.debug('Individual created: %s', str(individual))
self._population.append(individual)
if len(self._population) > self.population_size:
self._population.popleft()
completed_indices.append(i)
for i in completed_indices[::-1]:
# delete from end to start so that earlier indices are not affected.
self._success_count += 1
self._running_models.pop(i)
def _remove_failed_models_from_running_list(self):
# This is only done when on_failure policy is set to "ignore".
# Otherwise, failed models will be treated as inf when processed.
if self.on_failure == 'ignore':
number_of_failed_models = len([g for g in self._running_models if g[1].status == ModelStatus.Failed])
self._running_models = [g for g in self._running_models if g[1].status != ModelStatus.Failed]
if number_of_failed_models > 0:
_logger.info('%d failed models are ignored. Will retry.', number_of_failed_models)
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
"""Wrappers of HPO tuners as NAS strategy."""
import logging
import time
from typing import Optional
from nni.nas import Sampler, submit_models, query_available_resources, is_stopped_exec, budget_exhausted
from .base import BaseStrategy
_logger = logging.getLogger(__name__)
class TPESampler(Sampler):
def __init__(self, optimize_mode='minimize'):
# Move import here to eliminate some warning messages about dill.
from nni.algorithms.hpo.hyperopt_tuner import HyperoptTuner
self.tpe_tuner = HyperoptTuner('tpe', optimize_mode)
self.cur_sample: Optional[dict] = None
self.index: Optional[int] = None
self.total_parameters = {}
def update_sample_space(self, sample_space):
search_space = {}
for i, each in enumerate(sample_space):
search_space[str(i)] = {'_type': 'choice', '_value': each}
self.tpe_tuner.update_search_space(search_space)
def generate_samples(self, model_id):
self.cur_sample = self.tpe_tuner.generate_parameters(model_id)
self.total_parameters[model_id] = self.cur_sample
self.index = 0
def receive_result(self, model_id, result):
self.tpe_tuner.receive_trial_result(model_id, self.total_parameters[model_id], result)
def choice(self, candidates, mutator, model, index):
assert isinstance(self.index, int) and isinstance(self.cur_sample, dict)
chosen = self.cur_sample[str(self.index)]
self.index += 1
return chosen
class TPE(BaseStrategy):
"""
The Tree-structured Parzen Estimator (TPE) is a sequential model-based optimization (SMBO) approach.
Find the details in
`Algorithms for Hyper-Parameter Optimization <https://papers.nips.cc/paper/2011/file/86e8f7ab32cfd12577bc2619bc635690-Paper.pdf>`__.
SMBO methods sequentially construct models to approximate the performance of hyperparameters based on historical measurements,
and then subsequently choose new hyperparameters to test based on this model.
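The ``TPESampler`` defined above can be driven directly to see the SMBO loop in isolation
(a rough sketch; in practice the strategy binds the sampler to mutators automatically):

.. code-block:: python

    sampler = TPESampler(optimize_mode='maximize')
    # one candidate list per mutation decision
    sampler.update_sample_space([[3, 5, 7], ['relu', 'gelu']])
    sampler.generate_samples(model_id=0)
    kernel = sampler.choice([3, 5, 7], None, None, 0)             # first decision
    activation = sampler.choice(['relu', 'gelu'], None, None, 1)  # second decision
    sampler.receive_result(0, 0.93)  # report the metric so later samples are informed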
"""
def __init__(self):
self.tpe_sampler = TPESampler()
self.model_id = 0
self.running_models = {}
def run(self, base_model, applied_mutators):
sample_space = []
new_model = base_model
for mutator in applied_mutators:
recorded_candidates, new_model = mutator.dry_run(new_model)
sample_space.extend(recorded_candidates)
self.tpe_sampler.update_sample_space(sample_space)
_logger.info('TPE strategy has been started.')
while not budget_exhausted():
avail_resource = query_available_resources()
if avail_resource > 0:
model = base_model
_logger.debug('New model created. Applied mutators: %s', str(applied_mutators))
self.tpe_sampler.generate_samples(self.model_id)
for mutator in applied_mutators:
mutator.bind_sampler(self.tpe_sampler)
model = mutator.apply(model)
# run models
submit_models(model)
self.running_models[self.model_id] = model
self.model_id += 1
else:
time.sleep(2)
_logger.debug('num of running models: %d', len(self.running_models))
to_be_deleted = []
for _id, _model in self.running_models.items():
if is_stopped_exec(_model):
if _model.metric is not None:
self.tpe_sampler.receive_result(_id, _model.metric)
_logger.debug('tpe receive results: %d, %s', _id, _model.metric)
to_be_deleted.append(_id)
for _id in to_be_deleted:
del self.running_models[_id]
# alias for backward compatibility
TPEStrategy = TPE
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .base import BaseStrategy
try:
from nni.nas.oneshot.pytorch.strategy import ( # pylint: disable=unused-import
DARTS, GumbelDARTS, Proxyless, ENAS, RandomOneShot
)
except ImportError as import_err:
_import_err = import_err
class ImportFailedStrategy(BaseStrategy):
def run(self, base_model, applied_mutators):
raise _import_err
# otherwise type checking would point to the wrong location
globals()['DARTS'] = ImportFailedStrategy
globals()['GumbelDARTS'] = ImportFailedStrategy
globals()['Proxyless'] = ImportFailedStrategy
globals()['ENAS'] = ImportFailedStrategy
globals()['RandomOneShot'] = ImportFailedStrategy
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import logging
from typing import Optional, Callable
from nni.nas.execution import query_available_resources
from .base import BaseStrategy
from .utils import dry_run_for_search_space
try:
has_tianshou = True
import torch
from tianshou.data import Collector, VectorReplayBuffer
from tianshou.env import BaseVectorEnv
from tianshou.policy import BasePolicy, PPOPolicy # pylint: disable=unused-import
from ._rl_impl import ModelEvaluationEnv, MultiThreadEnvWorker, Preprocessor, Actor, Critic
except ImportError:
has_tianshou = False
_logger = logging.getLogger(__name__)
class PolicyBasedRL(BaseStrategy):
"""
Algorithm for policy-based reinforcement learning.
This is a wrapper of algorithms provided in tianshou (PPO by default),
and can be easily customized with other algorithms that inherit ``BasePolicy``
(e.g., `REINFORCE <https://link.springer.com/content/pdf/10.1007/BF00992696.pdf>`__
as in `this paper <https://arxiv.org/abs/1611.01578>`__).
Parameters
----------
max_collect : int
How many times the collector runs to collect trials for RL. Default: 100.
trial_per_collect : int
How many trials (trajectories) the collector collects each time.
After each collect, the trainer samples a batch from the replay buffer and performs the update. Default: 20.
policy_fn : function
Takes :class:`ModelEvaluationEnv` as input and returns a policy.
See :meth:`PolicyBasedRL._default_policy_fn` for an example.
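A custom ``policy_fn`` can mirror ``_default_policy_fn`` with different hyper-parameters
(a sketch; requires ``tianshou`` and ``torch``, and the values below are illustrative):

.. code-block:: python

    def my_policy_fn(env):
        net = Preprocessor(env.observation_space)
        actor = Actor(env.action_space, net)
        critic = Critic(net)
        optim = torch.optim.Adam(set(actor.parameters()).union(critic.parameters()), lr=5e-4)
        return PPOPolicy(actor, critic, optim, torch.distributions.Categorical,
                         discount_factor=0.99, action_space=env.action_space)

    strategy = PolicyBasedRL(max_collect=50, trial_per_collect=10, policy_fn=my_policy_fn)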
"""
def __init__(self, max_collect: int = 100, trial_per_collect = 20,
policy_fn: Optional[Callable[['ModelEvaluationEnv'], 'BasePolicy']] = None):
if not has_tianshou:
raise ImportError('`tianshou` is required to run RL-based strategy. '
'Please use "pip install tianshou" to install it beforehand.')
self.policy_fn = policy_fn or self._default_policy_fn
self.max_collect = max_collect
self.trial_per_collect = trial_per_collect
@staticmethod
def _default_policy_fn(env):
net = Preprocessor(env.observation_space)
actor = Actor(env.action_space, net)
critic = Critic(net)
optim = torch.optim.Adam(set(actor.parameters()).union(critic.parameters()), lr=1e-4)
return PPOPolicy(actor, critic, optim, torch.distributions.Categorical,
discount_factor=1., action_space=env.action_space)
def run(self, base_model, applied_mutators):
search_space = dry_run_for_search_space(base_model, applied_mutators)
concurrency = query_available_resources()
env_fn = lambda: ModelEvaluationEnv(base_model, applied_mutators, search_space)
policy = self.policy_fn(env_fn())
env = BaseVectorEnv([env_fn for _ in range(concurrency)], MultiThreadEnvWorker)
collector = Collector(policy, env, VectorReplayBuffer(20000, len(env)))
for cur_collect in range(1, self.max_collect + 1):
_logger.info('Collect [%d] Running...', cur_collect)
result = collector.collect(n_episode=self.trial_per_collect)
_logger.info('Collect [%d] Result: %s', cur_collect, str(result))
policy.update(0, collector.buffer, batch_size=64, repeat=5)
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import collections
import logging
from typing import Dict, Any, List
from nni.nas.execution.common import Model
from nni.nas.mutable import Mutator, Sampler
_logger = logging.getLogger(__name__)
class _FixedSampler(Sampler):
def __init__(self, sample):
self.sample = sample
def choice(self, candidates, mutator, model, index):
return self.sample[(mutator, index)]
def dry_run_for_search_space(model: Model, mutators: List[Mutator]) -> Dict[Any, List[Any]]:
search_space = collections.OrderedDict()
for mutator in mutators:
recorded_candidates, model = mutator.dry_run(model)
for i, candidates in enumerate(recorded_candidates):
search_space[(mutator, i)] = candidates
return search_space
def dry_run_for_formatted_search_space(model: Model, mutators: List[Mutator]) -> Dict[Any, Dict[Any, Any]]:
search_space = collections.OrderedDict()
for mutator in mutators:
recorded_candidates, model = mutator.dry_run(model)
if len(recorded_candidates) == 1:
search_space[mutator.label] = {'_type': 'choice', '_value': recorded_candidates[0]}
else:
for i, candidate in enumerate(recorded_candidates):
search_space[f'{mutator.label}_{i}'] = {'_type': 'choice', '_value': candidate}
return search_space
def get_targeted_model(base_model: Model, mutators: List[Mutator], sample: dict) -> Model:
sampler = _FixedSampler(sample)
model = base_model
for mutator in mutators:
model = mutator.bind_sampler(sampler).apply(model)
return model
def filter_model(model_filter, ir_model):
if model_filter is not None:
_logger.debug(f'Check if model satisfies constraints.')
if model_filter(ir_model):
_logger.debug(f'Model satisfied. Submit the model.')
return True
else:
_logger.debug(f'Model unsatisfied. Discard the model.')
return False
else:
return True
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from tensorflow.keras import Model
from .mutables import Mutable, MutableScope, InputChoice
from .utils import StructuredMutableTreeNode
class BaseMutator(Model):
def __init__(self, model):
super().__init__()
self.__dict__['model'] = model
self._structured_mutables = self._parse_search_space(self.model)
def _parse_search_space(self, module, root=None, prefix='', memo=None, nested_detection=None):
if memo is None:
memo = set()
if root is None:
root = StructuredMutableTreeNode(None)
if module not in memo:
memo.add(module)
if isinstance(module, Mutable):
if nested_detection is not None:
raise RuntimeError('Cannot have nested search space. Error at {} in {}'
.format(module, nested_detection))
module.name = prefix
module.set_mutator(self)
root = root.add_child(module)
if not isinstance(module, MutableScope):
nested_detection = module
if isinstance(module, InputChoice):
for k in module.choose_from:
if k != InputChoice.NO_KEY and k not in [m.key for m in memo if isinstance(m, Mutable)]:
raise RuntimeError('"{}" required by "{}" not found in keys that appeared before, and is not NO_KEY.'
.format(k, module.key))
for submodule in module.layers:
if not isinstance(submodule, Model):
continue
submodule_prefix = prefix + ('.' if prefix else '') + submodule.name
self._parse_search_space(submodule, root, submodule_prefix, memo=memo, nested_detection=nested_detection)
return root
@property
def mutables(self):
return self._structured_mutables
def undedup_mutables(self):
return self._structured_mutables.traverse(deduplicate=False)
def call(self, *inputs):
raise RuntimeError('Call is undefined for mutators.')
def __setattr__(self, name, value):
if name == 'model':
raise AttributeError("Attribute `model` can be set at most once, and you shouldn't use `self.model = model` to "
"include your network, as it will include all parameters in model into the mutator.")
return super().__setattr__(name, value)
def enter_mutable_scope(self, mutable_scope):
pass
def exit_mutable_scope(self, mutable_scope):
pass
def on_forward_layer_choice(self, mutable, *inputs):
raise NotImplementedError
def on_forward_input_choice(self, mutable, tensor_list):
raise NotImplementedError
def export(self):
raise NotImplementedError
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import logging
from collections import OrderedDict
from tensorflow.keras import Model
from .utils import global_mutable_counting
_logger = logging.getLogger(__name__)
class Mutable(Model):
def __init__(self, key=None):
super().__init__()
if key is None:
self._key = '{}_{}'.format(type(self).__name__, global_mutable_counting())
elif isinstance(key, str):
self._key = key
else:
self._key = str(key)
_logger.warning('Key "%s" is not string, converted to string.', key)
self.init_hook = None
self.forward_hook = None
def __deepcopy__(self, memodict=None):
raise NotImplementedError("Deep copy doesn't work for mutables.")
def set_mutator(self, mutator):
if hasattr(self, 'mutator'):
raise RuntimeError('`set_mutator` is called more than once. '
'Did you parse the search space multiple times? '
'Or did you apply multiple fixed architectures?')
self.mutator = mutator
def call(self, *inputs):
raise NotImplementedError('Method `call` of Mutable must be overridden')
def build(self, input_shape):
self._check_built()
@property
def key(self):
return self._key
@property
def name(self):
return self._name if hasattr(self, '_name') else self._key
@name.setter
def name(self, name):
self._name = name
def _check_built(self):
if not hasattr(self, 'mutator'):
raise ValueError(
"Mutator not set for {}. You might have forgotten to initialize and apply your mutator. "
"Or did you initialize a mutable on the fly in forward pass? Move to `__init__` "
"so that trainer can locate all your mutables. See NNI docs for more details.".format(self))
def __repr__(self):
return '{} ({})'.format(self.name, self.key)
class MutableScope(Mutable):
def __call__(self, *args, **kwargs):
try:
self.mutator.enter_mutable_scope(self)
return super().__call__(*args, **kwargs)
finally:
self.mutator.exit_mutable_scope(self)
class LayerChoice(Mutable):
def __init__(self, op_candidates, reduction='sum', return_mask=False, key=None):
super().__init__(key=key)
self.names = []
if isinstance(op_candidates, OrderedDict):
for name in op_candidates:
assert name not in ["length", "reduction", "return_mask", "_key", "key", "names"], \
"Please don't use a reserved name '{}' for your module.".format(name)
self.names.append(name)
elif isinstance(op_candidates, list):
for i, _ in enumerate(op_candidates):
self.names.append(str(i))
else:
raise TypeError("Unsupported op_candidates type: {}".format(type(op_candidates)))
self.length = len(op_candidates)
self.choices = op_candidates
self.reduction = reduction
self.return_mask = return_mask
def call(self, *inputs):
out, mask = self.mutator.on_forward_layer_choice(self, *inputs)
if self.return_mask:
return out, mask
return out
def build(self, input_shape):
self._check_built()
for op in self.choices:
op.build(input_shape)
def __len__(self):
return len(self.choices)
class InputChoice(Mutable):
NO_KEY = ''
def __init__(self, n_candidates=None, choose_from=None, n_chosen=None, reduction='sum', return_mask=False, key=None):
super().__init__(key=key)
assert n_candidates is not None or choose_from is not None, \
'At least one of `n_candidates` and `choose_from` must be not None.'
if choose_from is not None and n_candidates is None:
n_candidates = len(choose_from)
elif choose_from is None and n_candidates is not None:
choose_from = [self.NO_KEY] * n_candidates
assert n_candidates == len(choose_from), 'Number of candidates must be equal to the length of `choose_from`.'
assert n_candidates > 0, 'Number of candidates must be greater than 0.'
assert n_chosen is None or 0 <= n_chosen <= n_candidates, \
'Expected selected number must be None or no more than number of candidates.'
self.n_candidates = n_candidates
self.choose_from = choose_from.copy()
self.n_chosen = n_chosen
self.reduction = reduction
self.return_mask = return_mask
def call(self, optional_inputs):
optional_input_list = optional_inputs
if isinstance(optional_inputs, dict):
optional_input_list = [optional_inputs[tag] for tag in self.choose_from]
assert isinstance(optional_input_list, list), \
'Optional input list must be a list, not a {}.'.format(type(optional_input_list))
assert len(optional_inputs) == self.n_candidates, \
'Length of the input list must be equal to number of candidates.'
out, mask = self.mutator.on_forward_input_choice(self, optional_input_list)
if self.return_mask:
return out, mask
return out
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import logging
import tensorflow as tf
from .base_mutator import BaseMutator
_logger = logging.getLogger(__name__)
class Mutator(BaseMutator):
def __init__(self, model):
super().__init__(model)
self._cache = {}
def sample_search(self):
raise NotImplementedError('Method `sample_search` must be overridden')
def sample_final(self):
raise NotImplementedError('Method `sample_final` must be overridden for exporting')
def reset(self):
self._cache = self.sample_search()
def export(self):
return self.sample_final()
# TODO: status
# TODO: graph
def on_forward_layer_choice(self, mutable, *inputs):
mask = self._get_decision(mutable)
assert len(mask) == len(mutable), \
'Invalid mask, expected {} to be of length {}.'.format(mask, len(mutable))
out = self._select_with_mask(lambda choice: choice(*inputs), mutable.choices, mask)
return self._tensor_reduction(mutable.reduction, out), mask
def on_forward_input_choice(self, mutable, tensor_list):
mask = self._get_decision(mutable)
assert len(mask) == mutable.n_candidates, \
'Invalid mask, expected {} to be of length {}.'.format(mask, mutable.n_candidates)
out = self._select_with_mask(lambda tensor: tensor, tensor_list, mask)
return self._tensor_reduction(mutable.reduction, out), mask
def _select_with_mask(self, map_fn, candidates, mask):
if mask.dtype.is_bool:
out = [map_fn(cand) for cand, m in zip(candidates, mask) if m]
elif mask.dtype.is_floating:
out = [map_fn(cand) * m for cand, m in zip(candidates, mask) if m]
else:
raise ValueError('Unrecognized mask, dtype is {}'.format(mask.dtype.name))
return out
def _tensor_reduction(self, reduction_type, tensor_list):
if reduction_type == 'none':
return tensor_list
if not tensor_list:
return None
if len(tensor_list) == 1:
return tensor_list[0]
if reduction_type == 'sum':
return sum(tensor_list)
if reduction_type == 'mean':
return sum(tensor_list) / len(tensor_list)
if reduction_type == 'concat':
image_data_format = tf.keras.backend.image_data_format()
if image_data_format == "channels_first":
axis = 0
else:
axis = -1
return tf.concat(tensor_list, axis=axis) # pylint: disable=E1120,E1123
# pylint issue #3613
raise ValueError('Unrecognized reduction policy: "{}"'.format(reduction_type))
def _get_decision(self, mutable):
if mutable.key not in self._cache:
raise ValueError('"{}" not found in decision cache.'.format(mutable.key))
result = self._cache[mutable.key]
_logger.debug('Decision %s: %s', mutable.key, result)
return result
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import tensorflow as tf
_counter = 0
def global_mutable_counting():
global _counter
_counter += 1
return _counter
class AverageMeter:
def __init__(self, name):
self.name = name
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
def update(self, val):
self.val = val
self.sum += val
self.count += 1
self.avg = self.sum / self.count
def __str__(self):
return '{name} {val:4f} ({avg:4f})'.format(**self.__dict__)
def summary(self):
return '{name}: {avg:4f}'.format(**self.__dict__)
class AverageMeterGroup:
def __init__(self):
self.meters = {}
def update(self, data):
for k, v in data.items():
if k not in self.meters:
self.meters[k] = AverageMeter(k)
self.meters[k].update(v)
def __str__(self):
return ' '.join(str(v) for v in self.meters.values())
def summary(self):
return ' '.join(v.summary() for v in self.meters.values())
class StructuredMutableTreeNode:
def __init__(self, mutable):
self.mutable = mutable
self.children = []
def add_child(self, mutable):
self.children.append(StructuredMutableTreeNode(mutable))
return self.children[-1]
def type(self):
return type(self.mutable)
def __iter__(self):
return self.traverse()
def traverse(self, order="pre", deduplicate=True, memo=None):
if memo is None:
memo = set()
assert order in ["pre", "post"]
if order == "pre":
if self.mutable is not None:
if not deduplicate or self.mutable.key not in memo:
memo.add(self.mutable.key)
yield self.mutable
for child in self.children:
for m in child.traverse(order=order, deduplicate=deduplicate, memo=memo):
yield m
if order == "post":
if self.mutable is not None:
if not deduplicate or self.mutable.key not in memo:
memo.add(self.mutable.key)
yield self.mutable
def fill_zero_grads(grads, weights):
ret = []
for grad, weight in zip(grads, weights):
if grad is not None:
ret.append(grad)
else:
ret.append(tf.zeros_like(weight))
return ret
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .misc import *
from .serializer import *
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import inspect
import itertools
import warnings
from collections import defaultdict
from contextlib import contextmanager
from typing import Any, List, Dict, cast
from pathlib import Path
from nni.common.hpo_utils import ParameterSpec
__all__ = [
'NoContextError', 'ContextStack', 'ModelNamespace', 'original_state_dict_hooks',
'uid', 'import_', 'reset_uid', 'get_module_name', 'get_importable_name', 'get_current_context',
'STATE_DICT_PY_MAPPING', 'STATE_DICT_PY_MAPPING_PARTIAL',
]
def import_(target: str, allow_none: bool = False) -> Any:
if target is None:
return None
path, identifier = target.rsplit('.', 1)
module = __import__(path, globals(), locals(), [identifier])
return getattr(module, identifier)
_last_uid = defaultdict(int)
_DEFAULT_MODEL_NAMESPACE = 'model'
def uid(namespace: str = 'default') -> int:
_last_uid[namespace] += 1
return _last_uid[namespace]
def reset_uid(namespace: str = 'default') -> None:
_last_uid[namespace] = 0
def get_module_name(cls_or_func):
module_name = cls_or_func.__module__
if module_name == '__main__':
# infer the module name with inspect
for frm in inspect.stack():
module = inspect.getmodule(frm[0])
if module is not None and module.__name__ == '__main__':
# main module found
main_file_path = Path(cast(str, inspect.getsourcefile(frm[0])))
if not Path().samefile(main_file_path.parent):
raise RuntimeError(f'You are using "{main_file_path}" to launch your experiment, '
f'please launch the experiment under the directory where "{main_file_path.name}" is located.')
module_name = main_file_path.stem
break
if module_name == '__main__':
warnings.warn('Callstack exhausted but main module still not found. This will probably cause issues where the '
'function/class cannot be imported.')
# NOTE: this is hacky. TorchScript retrieves LSTM's source code for its own processing.
# To make sure LSTM's source code can be found, we assign the original LSTM's __module__ to
# the wrapped LSTM's __module__.
# TODO: find out all the modules that have the same requirement as LSTM
if f'{cls_or_func.__module__}.{cls_or_func.__name__}' == 'torch.nn.modules.rnn.LSTM':
module_name = cls_or_func.__module__
return module_name
def get_importable_name(cls, relocate_module=False):
module_name = get_module_name(cls) if relocate_module else cls.__module__
return module_name + '.' + cls.__name__
class NoContextError(Exception):
"""Exception raised when context is missing."""
pass
class ContextStack:
"""
This maintains a globally accessible context environment that is visible everywhere.
Use ``with ContextStack(namespace, value):`` to initiate, and use ``get_current_context(namespace)`` to
get the corresponding value in the namespace.
Note that this is not multi-processing safe. Also, the values will get cleared for a new process.
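For example (a minimal sketch):

.. code-block:: python

    with ContextStack('my_namespace', 42):
        assert get_current_context('my_namespace') == 42
    # outside the with-block the value is popped, and
    # get_current_context('my_namespace') raises NoContextError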
"""
_stack: Dict[str, List[Any]] = defaultdict(list)
def __init__(self, key: str, value: Any):
self.key = key
self.value = value
def __enter__(self):
self.push(self.key, self.value)
return self
def __exit__(self, *args, **kwargs):
self.pop(self.key)
@classmethod
def push(cls, key: str, value: Any):
cls._stack[key].append(value)
@classmethod
def pop(cls, key: str) -> None:
cls._stack[key].pop()
@classmethod
def top(cls, key: str) -> Any:
if not cls._stack[key]:
raise NoContextError('Context is empty.')
return cls._stack[key][-1]
class ModelNamespace:
"""
To create an individual namespace for models:
1. to enable automatic numbering;
2. to trace general information (like creation of hyper-parameters) of model.
A namespace is bound to a key. Namespaces bound to different keys are completely isolated.
A namespace can have sub-namespaces (with the same key), whose numbering is chained (e.g., ``model_1_4_2``).
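For example (a sketch; the exact numbers depend on namespaces created earlier in the process):

.. code-block:: python

    with ModelNamespace():                   # default key is 'model'
        ModelNamespace.next_label()          # e.g. 'model_1'
        with ModelNamespace():               # nested namespace
            ModelNamespace.next_label()      # e.g. 'model_2_1' (chained numbering)
        ModelNamespace.next_label()          # e.g. 'model_3'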
"""
def __init__(self, key: str = _DEFAULT_MODEL_NAMESPACE):
# for example, key: "model_wrapper"
self.key = key
# the "path" of current name
# By default, it's ``[]``
# If a ``@model_wrapper`` is nested inside a model_wrapper, it will become something like ``[1, 3, 2]``.
# See ``__enter__``.
self.name_path: List[int] = []
# parameter specs.
# Currently only used to trace calls of ModelParameterChoice.
self.parameter_specs: List[ParameterSpec] = []
def __enter__(self):
# For example, if the top of the stack is currently [1, 2, 2] and [1, 2, 2, 3] has been used,
# the next one will be [1, 2, 2, 4].
# `reset_uid` to count from zero for "model_wrapper_1_2_2_4"
try:
parent_context: 'ModelNamespace' = ModelNamespace.current_context(self.key)
next_uid = uid(parent_context._simple_name())
self.name_path = parent_context.name_path + [next_uid]
ContextStack.push(self.key, self)
reset_uid(self._simple_name())
except NoContextError:
# not found, no existing namespace
self.name_path = []
ContextStack.push(self.key, self)
reset_uid(self._simple_name())
def __exit__(self, *args, **kwargs):
ContextStack.pop(self.key)
def _simple_name(self) -> str:
return self.key + ''.join(['_' + str(k) for k in self.name_path])
def __repr__(self):
return f'ModelNamespace(name={self._simple_name()}, num_specs={len(self.parameter_specs)})'
# Access the current context in the model #
@staticmethod
def current_context(key: str = _DEFAULT_MODEL_NAMESPACE) -> 'ModelNamespace':
"""Get the current context in key."""
try:
return ContextStack.top(key)
except NoContextError:
raise NoContextError('ModelNamespace context is missing. You might have forgotten to use `@model_wrapper`.')
@staticmethod
def next_label(key: str = _DEFAULT_MODEL_NAMESPACE) -> str:
"""Get the next label for API calls, with automatic numbering."""
try:
current_context = ContextStack.top(key)
except NoContextError:
# fallback to use "default" namespace
# it won't be registered
warnings.warn('ModelNamespace is missing. You might have forgotten to use `@model_wrapper`. '
'Some features might not work. This will be an error in future releases.', RuntimeWarning)
current_context = ModelNamespace('default')
next_uid = uid(current_context._simple_name())
return current_context._simple_name() + '_' + str(next_uid)
def get_current_context(key: str) -> Any:
return ContextStack.top(key)
# map variables to prefix in the state dict
# e.g., {'upsample': 'mynet.module.deconv2.upsample_layer'}
STATE_DICT_PY_MAPPING = '_mapping_'
# map variables to `prefix`.`value` in the state dict
# e.g., {'upsample': 'choice3.upsample_layer'},
# which actually means {'upsample': 'mynet.module.choice3.upsample_layer'},
# and 'upsample' is also in `mynet.module`.
STATE_DICT_PY_MAPPING_PARTIAL = '_mapping_partial_'
@contextmanager
def original_state_dict_hooks(model: Any):
"""
Use this patch if you want to save/load state dict in the original state dict hierarchy.
For example, when you already have a state dict for the base model / search space (which often
happens when you have trained a supernet with one-shot strategies), the state dict isn't organized
in the same way as when a sub-model is sampled from the search space. This patch will help
the modules in the sub-model find the corresponding module in the base model.
The code looks like,
.. code-block:: python
with original_state_dict_hooks(model):
model.load_state_dict(state_dict_from_supernet, strict=False) # supernet has extra keys
Or vice-versa,
.. code-block:: python
with original_state_dict_hooks(model):
supernet_style_state_dict = model.state_dict()
"""
import torch.utils.hooks
import torch.nn as nn
assert isinstance(model, nn.Module), 'PyTorch is the only supported framework for now.'
# the following are written for pytorch only
# first get the full mapping
full_mapping = {}
def full_mapping_in_module(src_prefix, tar_prefix, module):
if hasattr(module, STATE_DICT_PY_MAPPING):
# only values are complete
local_map = getattr(module, STATE_DICT_PY_MAPPING)
elif hasattr(module, STATE_DICT_PY_MAPPING_PARTIAL):
# keys and values are both incomplete
local_map = getattr(module, STATE_DICT_PY_MAPPING_PARTIAL)
local_map = {k: tar_prefix + v for k, v in local_map.items()}
else:
# no mapping
local_map = {}
if '__self__' in local_map:
# special case, overwrite prefix
tar_prefix = local_map['__self__'] + '.'
for key, value in local_map.items():
if key != '' and key not in module._modules: # not a sub-module, probably a parameter
full_mapping[src_prefix + key] = value
if src_prefix != tar_prefix: # To deal with leaf nodes.
for name, value in itertools.chain(module._parameters.items(), module._buffers.items()): # direct children
if value is None or name in module._non_persistent_buffers_set:
# it won't appear in state dict
continue
if (src_prefix + name) not in full_mapping:
full_mapping[src_prefix + name] = tar_prefix + name
for name, child in module.named_children():
# sub-modules
full_mapping_in_module(
src_prefix + name + '.',
local_map.get(name, tar_prefix + name) + '.', # if mapping doesn't exist, respect the prefix
child
)
full_mapping_in_module('', '', model)
def load_state_dict_hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
reverse_mapping = defaultdict(list)
for src, tar in full_mapping.items():
reverse_mapping[tar].append(src)
transf_state_dict = {}
for src, tar_keys in reverse_mapping.items():
if src in state_dict:
value = state_dict.pop(src)
for tar in tar_keys:
transf_state_dict[tar] = value
else:
missing_keys.append(src)
state_dict.update(transf_state_dict)
def state_dict_hook(module, destination, prefix, local_metadata):
result = {}
for src, tar in full_mapping.items():
if src in destination:
result[tar] = destination.pop(src)
else:
raise KeyError(f'"{src}" not in state dict, but found in mapping.')
destination.update(result)
hooks: List[torch.utils.hooks.RemovableHandle] = []
try:
hooks.append(model._register_load_state_dict_pre_hook(load_state_dict_hook))
hooks.append(model._register_state_dict_hook(state_dict_hook))
yield
finally:
for hook in hooks:
hook.remove()
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import inspect
import os
import warnings
from typing import Any, TypeVar, Type
from nni.common.serializer import is_traceable, is_wrapped_with_trace, trace, _copy_class_wrapper_attributes
from .misc import ModelNamespace
__all__ = ['get_init_parameters_or_fail', 'serialize', 'serialize_cls', 'basic_unit', 'model_wrapper',
'is_basic_unit', 'is_model_wrapped']
T = TypeVar('T')
def get_init_parameters_or_fail(obj: Any):
if is_traceable(obj):
return obj.trace_kwargs
raise ValueError(f'Object {obj} needs to be serializable but `trace_kwargs` is not available. '
'If it is a built-in module (like Conv2d), please import it from retiarii.nn. '
'If it is a customized module, please decorate it with @basic_unit. '
'For other complex objects (e.g., trainer, optimizer, dataset, dataloader), '
'try to use @nni.trace.')
def serialize(cls, *args, **kwargs):
"""
Create a serializable instance inline, without a decorator. For example,
.. code-block:: python
self.op = serialize(MyCustomOp, hidden_units=128)
"""
warnings.warn('nni.retiarii.serialize is deprecated and will be removed in future release. ' +
'Try to use nni.trace, e.g., nni.trace(torch.optim.Adam)(learning_rate=1e-4) instead.',
category=DeprecationWarning)
return trace(cls)(*args, **kwargs)
def serialize_cls(cls):
"""
Create a serializable class.
"""
warnings.warn('nni.retiarii.serialize is deprecated and will be removed in future release. ' +
'Try to use nni.trace instead.', category=DeprecationWarning)
return trace(cls)
def basic_unit(cls: T, basic_unit_tag: bool = True) -> T:
"""
Wrapping a module as a basic unit makes it a primitive and stops the engine from digging deeper into it.
``basic_unit_tag`` is true by default. If set to false, the module will not be explicitly marked as a basic unit, and
the graph parser will continue to parse it. Currently, this is to handle a special case in ``nn.Sequential``.
Although ``basic_unit`` calls ``trace`` in its implementation, it is not for serialization. Rather, it is meant
to capture the initialization arguments for mutation. Also, the graph execution engine will stop digging into the inner
modules when it reaches a module that is decorated with ``basic_unit``.
.. code-block:: python
@basic_unit
class PrimitiveOp(nn.Module):
...
"""
# Internal flag. See nni.trace
nni_trace_flag = os.environ.get('NNI_TRACE_FLAG', '')
if nni_trace_flag.lower() == 'disable':
return cls
if _check_wrapped(cls, 'basic_unit'):
return cls
import torch.nn as nn
assert issubclass(cls, nn.Module), 'When using @basic_unit, the class must be a subclass of nn.Module.' # type: ignore
cls = trace(cls)
cls._nni_basic_unit = basic_unit_tag # type: ignore
_torchscript_patch(cls)
return cls
def model_wrapper(cls: T) -> T:
"""
Wrap the base model (search space). For example,
.. code-block:: python
@model_wrapper
class MyModel(nn.Module):
...
The wrapper serves two purposes:
1. Capture the init parameters of the Python class so that it can be re-instantiated in another process.
2. Reset the uid in the namespace so that the auto label counting in each model reliably starts from zero.
Currently, NNI might not complain in simple cases where ``@model_wrapper`` is actually not needed.
But in the future, we might make ``@model_wrapper`` mandatory for the base model.
"""
# Internal flag. See nni.trace
nni_trace_flag = os.environ.get('NNI_TRACE_FLAG', '')
if nni_trace_flag.lower() == 'disable':
return cls
if _check_wrapped(cls, 'model_wrapper'):
return cls
import torch.nn as nn
assert issubclass(cls, nn.Module) # type: ignore
# subclass can still use trace info
wrapper = trace(cls, inheritable=True)
class reset_wrapper(wrapper):
def __init__(self, *args, **kwargs):
self._model_namespace = ModelNamespace()
with self._model_namespace:
super().__init__(*args, **kwargs)
_copy_class_wrapper_attributes(wrapper, reset_wrapper)
reset_wrapper.__wrapped__ = getattr(wrapper, '__wrapped__', wrapper)
reset_wrapper._nni_model_wrapper = True
reset_wrapper._traced = True
_torchscript_patch(cls)
return reset_wrapper
def is_basic_unit(cls_or_instance) -> bool:
if not inspect.isclass(cls_or_instance):
cls_or_instance = cls_or_instance.__class__
return getattr(cls_or_instance, '_nni_basic_unit', False)
def is_model_wrapped(cls_or_instance) -> bool:
if not inspect.isclass(cls_or_instance):
cls_or_instance = cls_or_instance.__class__
return getattr(cls_or_instance, '_nni_model_wrapper', False)
def _check_wrapped(cls: Type, rewrap: str) -> bool:
wrapped = None
if is_model_wrapped(cls):
wrapped = 'model_wrapper'
elif is_basic_unit(cls):
wrapped = 'basic_unit'
elif is_wrapped_with_trace(cls):
wrapped = 'nni.trace'
if wrapped:
if wrapped != rewrap:
raise TypeError(f'{cls} is already wrapped with {wrapped}. Cannot rewrap with {rewrap}.')
return True
return False
def _torchscript_patch(cls) -> None:
# HACK: for torch script
# https://github.com/pytorch/pytorch/pull/45261
# https://github.com/pytorch/pytorch/issues/54688
# I'm not sure whether there will be potential issues
import torch
if hasattr(cls, '_get_nni_attr'): # might not exist on non-Linux
cls._get_nni_attr = torch.jit.ignore(cls._get_nni_attr)
if hasattr(cls, 'trace_symbol'):
# these must either all exist or all be absent
try:
cls.trace_symbol = torch.jit.unused(cls.trace_symbol)
cls.trace_args = torch.jit.unused(cls.trace_args)
cls.trace_kwargs = torch.jit.unused(cls.trace_kwargs)
cls.trace_copy = torch.jit.ignore(cls.trace_copy)
except AttributeError as e:
if 'property' in str(e):
raise RuntimeError('Trace on PyTorch module failed. Your PyTorch version might be outdated. '
'Please try to upgrade PyTorch.')
raise
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
__all__ = ['model_to_pytorch_script']
# pylint: disable=wildcard-import,unused-wildcard-import
import logging
import re
from typing import Dict, List, Tuple, Any, cast
from nni.retiarii.operation import PyTorchOperation
from nni.retiarii.operation_def.torch_op_def import ToDevice
from nni.retiarii.utils import STATE_DICT_PY_MAPPING
from nni.common.device import Device, GPUDevice
from ..graph import IllegalGraphError, Edge, Graph, Node, Model
_logger = logging.getLogger(__name__)
def model_to_pytorch_script(model: Model, placement=None) -> str:
graphs = []
total_pkgs = set()
for name, cell in model.graphs.items():
import_pkgs, graph_code = graph_to_pytorch_model(name, cell, placement=placement)
graphs.append(graph_code)
total_pkgs.update(import_pkgs)
pkgs_code = '\n'.join(['import {}'.format(pkg) for pkg in total_pkgs])
return _PyTorchScriptTemplate.format(pkgs_code, '\n\n'.join(graphs)).strip()
def _sorted_incoming_edges(node: Node) -> List[Edge]:
edges = [edge for edge in node.graph.edges if edge.tail is node]
_logger.debug('sorted_incoming_edges: %s', str(edges))
if not edges:
return []
_logger.debug('all tail_slots are None: %s', str([edge.tail_slot for edge in edges]))
if all(edge.tail_slot is None for edge in edges):
return edges
if all(isinstance(edge.tail_slot, int) for edge in edges):
edges = sorted(edges, key=(lambda edge: cast(int, edge.tail_slot)))
if [edge.tail_slot for edge in edges] == list(range(len(edges))):
return edges
raise IllegalGraphError(node.graph, 'Node {} has bad inputs'.format(node.name))
def _format_inputs(node: Node, graph_name: str) -> Tuple[List[str], List[Any]]:
"""
Format the inputs of a given node.
Inputs will be formatted with ``_format_variable_name``
Parameters
----------
node : Node
a graph node, get and format its inputs
graph_name : str
subgraph name, to format variable names
Returns
-------
list
the list of input names
list
the list of input values, if an input is simple type, record its value,
otherwise the value is None
"""
edges = _sorted_incoming_edges(node)
inputs = []
inputs_value = []
for edge in edges:
if edge.head.name == '_inputs':
assert isinstance(edge.head_slot, int)
if edge.head.operation.io_names is not None:
# when input has names, e.g., forward(self, tensor1, tensor2, another_one)
inputs.append(_format_variable_name(edge.head.operation.io_names[edge.head_slot], graph_name))
else:
# when input has no name, e.g., forward(*_inputs)
inputs.append('_inputs[{}]'.format(edge.head_slot))
inputs_value.append(None)
else:
if edge.head_slot is None:
# when the input comes from a single-output operator
inputs.append(_format_variable_name(edge.head.name, graph_name))
if edge.head.operation.type in ('prim::Constant', 'prim::GetAttr') and \
'value' in edge.head.operation.parameters:
inputs_value.append(edge.head.operation.parameters['value'])
else:
inputs_value.append(None)
else:
# when the input comes from a multi-output operator: needs to know which one it comes from
inputs.append('{}[{}]'.format(_format_variable_name(edge.head.name, graph_name), edge.head_slot))
inputs_value.append(None)
return inputs, inputs_value
def _format_variable_name(name: str, graph_name: str) -> str:
"""
1. Replace invalid characters in the node name.
2. Variable names (with the full namespace) can be too long; shorten them by removing the ``graph_name`` prefix.
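For example (a sketch of the transformation):

.. code-block:: python

    _format_variable_name('mygraph/cell1/conv.0', 'mygraph')  # -> '_cell1__conv_0'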
"""
name = name[len(graph_name):] if name.startswith(graph_name) else name
name = name.replace('/', '__')
# https://stackoverflow.com/questions/3303312/how-do-i-convert-a-string-to-a-valid-variable-name-in-python
name = re.sub(r'\W|^(?=\d)','_', name)
if name.startswith('__') and (len(name) > 2 and name[2] != '_'):
# name can't start with double underscore
# it's reserved in Python: https://stackoverflow.com/a/1301409/6837658
# but it's actually very common in our generated code
name = name[1:]
elif name.startswith('_'):
# to avoid conflicts between '_' and '__'
name = 'i' + name
return name
def generate_cuda_mapping(placement: Dict[Node, Device]) -> Dict[Device, int]:
'''
Since CUDA_VISIBLE_DEVICES will be set to the list of real GPU IDs,
we need to remap the GPU IDs when generating code so that they match correctly.
For example, when CUDA_VISIBLE_DEVICES="0,3", we need to use "cuda:0" and "cuda:1" in the generated code.
'''
unique_devices = sorted(list(set([e for e in placement.values() if isinstance(e, GPUDevice)])))
node_gpu_cnt = {}
cuda_remapped_id = {}
for d in unique_devices:
if d.node_id not in node_gpu_cnt:
node_gpu_cnt[d.node_id] = 0
node_gpu_cnt[d.node_id] += 1
cuda_remapped_id[d] = node_gpu_cnt[d.node_id] - 1
return cuda_remapped_id
def graph_to_pytorch_model(graph_name: str, graph: Graph, placement=None) -> Tuple[set, str]:
nodes = graph.topo_sort()
# handle module node and function node differently
# only need to generate code for module here
import_pkgs = set()
node_codes = []
node_python_mappings = {}
cuda_remapped_id = None
if placement:
cuda_remapped_id = generate_cuda_mapping(placement)
for node in nodes:
if node.operation:
if placement and isinstance(node.operation, ToDevice):
cuda_remapped_id = cast(dict, cuda_remapped_id)
node.operation.override_device_repr("cuda:%d" % cuda_remapped_id[node.operation.device])
if node.operation.type == 'shared':
continue
pkg_name = cast(PyTorchOperation, node.operation).get_import_pkg()
if pkg_name is not None:
import_pkgs.add(pkg_name)
py_variable_name = _format_variable_name(node.name, graph_name)
node_code = node.operation.to_init_code(py_variable_name)
if node_code is not None:
if placement and node in placement and len(node_code) > 0:
if isinstance(placement[node], GPUDevice):
assert cuda_remapped_id is not None
device_repr = "cuda:%d" % cuda_remapped_id[placement[node]]
else:
device_repr = placement[node].device_repr()
node_codes.append(f"{node_code}.to('{device_repr}')")
else:
node_codes.append(node_code)
# Map to module hierarchies in the original search space Python code
node_python_mappings[py_variable_name] = node.python_name
node_codes.append(f'self.{STATE_DICT_PY_MAPPING} = {node_python_mappings}')
if graph.input_node.operation.io_names is None:
input_code = '*_inputs'
else:
for name in graph.input_node.operation.io_names:
assert not name.startswith(graph_name)
input_code = ', '.join(graph.input_node.operation.io_names)
edge_codes = []
sorted_nodes = graph.topo_sort()
for node in sorted_nodes:
if node.operation:
inputs, inputs_value = _format_inputs(node, graph_name)
node_name = _format_variable_name(node.name, graph_name)
submodule_name = node_name
if node.operation.type == 'shared':
submodule_name = _format_variable_name(node.operation.parameters['reference'], graph_name)
edge_codes.append(node.operation.to_forward_code(submodule_name, node_name, inputs, inputs_value))
output_names, _ = _format_inputs(graph.output_node, graph_name)
if not output_names:
raise RuntimeError('"forward" function should have return value(s): {}, {}, {}'.format(output_names, graph_name, graph.output_node))
output_code = ', '.join(output_names)
linebreak = '\n '
return import_pkgs, _PyTorchModelTemplate.format(
graph_name=('Graph' if graph_name == '_graph' else graph_name),
inputs=input_code,
outputs=output_code,
nodes=linebreak.join(node_codes),
edges=linebreak.join(edge_codes)
)
# TODO: handle imports
_PyTorchScriptTemplate = '''
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import nni.retiarii.nn.pytorch
{}
{}
'''
_PyTorchModelTemplate = '''
class {graph_name}(nn.Module):
def __init__(self):
super().__init__()
{nodes}
def forward(self, {inputs}):
{edges}
return {outputs}
'''
from nni.nas.execution.pytorch.codegen import *
@@ -107,4 +107,4 @@ class {graph_name}(K.Model):
def call(self, {inputs}):
{edges}
return {outputs}
'''
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import re
# pylint: disable=wildcard-import,unused-wildcard-import
import torch
from ..graph import Graph, Model, Node
from ..nn.pytorch import InputChoice, Placeholder, LayerChoice
from ..operation import Cell, Operation
from ..serializer import get_init_parameters_or_fail
from ..utils import get_importable_name
from .op_types import MODULE_EXCEPT_LIST, OpTypeName
from .utils import (
_convert_name, build_full_name, _without_shape_info,
_extract_info_from_trace_node, get_full_name_by_scope_name,
is_layerchoice_node, match_node, build_cand_name,
build_python_name
)
class GraphConverter:
def __init__(self):
self.global_seq = 0
self.global_graph_id = 0
def _add_edge_handle_source_node(self, _input, graph_inputs, ir_graph, output_remap, node_index):
if _input in output_remap:
assert output_remap[_input].kind() == 'aten::append'
predecessor_node = output_remap[_input]
assert predecessor_node in node_index, 'predecessor node: {}'.format(predecessor_node)
src_node_idx = None
src_node = node_index[predecessor_node]
assert isinstance(src_node, Node)
elif _input in graph_inputs:
idx = graph_inputs.index(_input)
src_node = ir_graph.input_node
src_node_idx = idx
else:
predecessor_node = _input.node()
assert predecessor_node in node_index, 'predecessor node: {}'.format(predecessor_node)
# find out the index of _input in the outputs of predecessor_node
predecessor_outputs = [_output for _output in predecessor_node.outputs()]
if len(predecessor_outputs) == 1:
idx = None
else:
idx = predecessor_outputs.index(_input)
ir_predecessor_node = node_index[predecessor_node]
src_node_idx = idx
assert isinstance(ir_predecessor_node, Node)
src_node = ir_predecessor_node
return src_node, src_node_idx
def _add_edge(self, ir_graph, node, graph_inputs, node_index, new_node, output_remap, ignore_first=False):
"""
Parameters
----------
ir_graph : Graph
node : torch._C.Node
graph_inputs : List[torch._C.Value]
a list of a script graph's inputs
node_index : Dict
new_node : Node
newly created ir node corresponding to `node`
output_remap : Dict
ignore_first : bool
if it is true, skip the first input
"""
is_single_input = (len([_input for _input in node.inputs()]) - (1 if ignore_first else 0)) == 1
new_node_input_idx = 0
for _input in node.inputs():
if ignore_first:
ignore_first = False
continue
# handle source node
src_node, src_node_idx = self._add_edge_handle_source_node(_input, graph_inputs, ir_graph, output_remap, node_index)
# handle destination node
dst_node = new_node
if is_single_input:
dst_node_idx = None
else:
dst_node_idx = new_node_input_idx
# create edge
ir_graph.add_edge(head=(src_node, src_node_idx), tail=(dst_node, dst_node_idx))
new_node_input_idx += 1
def create_prim_constant_node(self, ir_graph, node, module_name):
# NOTE: compare with the string rather than the type, because the type is defined in pytorch C code.
# `.kind()` can also be used here
if node.outputsAt(0).type().str() == 'None':
attrs = {'type': 'None'}
else:
attrs = {'type': node.outputsAt(0).type().str(), 'value': node.outputsAt(0).toIValue()}
self.global_seq += 1
new_node = ir_graph.add_node(build_full_name(module_name, OpTypeName.Constant, self.global_seq),
node.kind(), attrs)
return new_node
def handle_prim_attr_node(self, node, module):
assert node.hasAttribute('name')
value = None
if node.inputsAt(0).debugName() == 'self':
_val = getattr(module, node.s('name'))
# TODO: serialize complex data types, and output a proper error message
if isinstance(_val, (int, float, str, bool)):
value = _val
attrs = {'name': node.s('name'), 'input': node.inputsAt(0).debugName(), 'value': value}
return node.kind(), attrs
def _remove_mangle(self, module_type_str):
return re.sub('\\.___torch_mangle_\\d+', '', module_type_str)
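# For example, '__torch__.torch.nn.modules.container.___torch_mangle_0.ModuleList'
# becomes '__torch__.torch.nn.modules.container.ModuleList'.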
def remove_unconnected_nodes(self, ir_graph, targeted_type=None):
"""
Parameters
----------
ir_graph : Graph
our ir graph representation
targeted_type : str
nodes with ``targeted_type`` will be removed from the graph if their fanout is 0.
``None`` means removing all the nodes whose fanout is 0.
"""
# build index of outputs of Node(s)
node_fanout = set()
for edge in ir_graph.edges:
if edge.head.id not in node_fanout:
node_fanout.add(edge.head.id)
to_removes = []
for hidden_node in ir_graph.hidden_nodes:
if hidden_node.id not in node_fanout:
assert isinstance(hidden_node, Node)
if targeted_type is None:
to_removes.append(hidden_node)
elif hidden_node.operation.type == targeted_type:
to_removes.append(hidden_node)
for hidden_node in to_removes:
hidden_node.remove()
def handle_graph_nodes(self, script_module, sm_graph,
module, module_name, module_python_name,
ir_model, ir_graph,
shared_module_index=None):
"""
Convert torch script node to our node ir, and build our graph ir
Parameters
----------
script_module : torch.jit.RecursiveScriptModule
the torch script of ``module``
sm_graph : torch._C.Graph
the graph in torch script
module : nn.Module
the targeted pytorch module
module_name : str
``module``'s name
ir_model : Model
the whole graph ir
ir_graph : Graph
the graph ir of ``module``
shared_module_index : dict
it records which modules already have an ir node created,
so that when a module is invoked again, the new ir node can simply reference the existing one.
this is how we identify shared modules (i.e., one module invoked multiple times in the `forward` function)
Returns
-------
dict
the mapping from graph node to our graph ir node
"""
# handle inputs
graph_inputs = []
for _input in sm_graph.inputs():
if _input.debugName() == 'self':
assert _input.unique() == 0
continue
graph_inputs.append(_input)
# TODO: add scope name
ir_graph._add_input(_convert_name(_input.debugName()))
node_index = {} # graph node to graph ir node
if shared_module_index is None:
shared_module_index = {}
# some nodes do not have an output but modify a variable, for example aten::append
# %17 : Tensor[] = aten::append(%out.1, %16)
# %out.1 is updated, and %17 is None
# we add an output to this type of node and connect it to the following node that uses %out.1
# key: tensor (%out.1), value: node (this node)
output_remap = {}
# ===================handle control flow: if===================
def handle_if_condition(cond_tensor):
"""
to calculate the condition, we only deal with a limited set of op types by tracing back,
e.g., `prim::GetAttr`, `aten::__getitem__`, `prim::Constant`, `aten::eq`
the expression is generated using recursive calls
NOTE: dynamic graphs are not supported
"""
def _generate_expr(tensor):
if tensor.node().kind() == 'prim::GetAttr':
return f'({getattr(module, tensor.node().s("name"))})'
elif tensor.node().kind() == 'aten::__getitem__':
t = _generate_expr(tensor.node().inputsAt(0))
idx = _generate_expr(tensor.node().inputsAt(1))
return f'({t}[{idx}])'
elif tensor.node().kind() == 'prim::Constant':
return f'{tensor.toIValue()}'
elif tensor.node().kind() == 'aten::eq':
left = _generate_expr(tensor.node().inputsAt(0))
right = _generate_expr(tensor.node().inputsAt(1))
return f'({left} == {right})'
elif tensor.node().kind() == 'aten::le':
left = _generate_expr(tensor.node().inputsAt(0))
right = _generate_expr(tensor.node().inputsAt(1))
return f'({left} <= {right})'
elif tensor.node().kind() == 'aten::ge':
left = _generate_expr(tensor.node().inputsAt(0))
right = _generate_expr(tensor.node().inputsAt(1))
return f'({left} >= {right})'
elif tensor.node().kind() == 'aten::__not__':
value = _generate_expr(tensor.node().inputsAt(0))
return f'(not {value})'
elif tensor.node().kind() == 'aten::Bool':
value = _generate_expr(tensor.node().inputsAt(0))
return f'bool({value})'
elif tensor.node().kind() == 'aten::__is__':
left = _generate_expr(tensor.node().inputsAt(0))
right = _generate_expr(tensor.node().inputsAt(1))
return f'({left} is {right})'
elif tensor.node().kind() == 'aten::__isnot__':
left = _generate_expr(tensor.node().inputsAt(0))
right = _generate_expr(tensor.node().inputsAt(1))
return f'({left} is not {right})'
elif tensor.node().kind() == 'aten::ne':
left = _generate_expr(tensor.node().inputsAt(0))
right = _generate_expr(tensor.node().inputsAt(1))
return f'({left} != {right})'
elif tensor.node().kind() == 'aten::gt':
left = _generate_expr(tensor.node().inputsAt(0))
right = _generate_expr(tensor.node().inputsAt(1))
return f'({left} > {right})'
elif tensor.node().kind() == 'aten::lt':
left = _generate_expr(tensor.node().inputsAt(0))
right = _generate_expr(tensor.node().inputsAt(1))
return f'({left} < {right})'
elif tensor.node().kind() == 'prim::If':
raise RuntimeError('`if A and/or B` is not supported yet, please use two `if` statements instead.')
elif tensor.node().kind() == 'aten::abs':
value = _generate_expr(tensor.node().inputsAt(0))
return f'(torch.abs({value}))'
elif tensor.node().kind() == 'aten::sum':
value = _generate_expr(tensor.node().inputsAt(0))
return f'(torch.sum({value}))'
elif tensor.node().kind() == 'aten::item':
value = _generate_expr(tensor.node().inputsAt(0))
return f'({value}.item())'
else:
raise RuntimeError(f'Unsupported op type {tensor.node().kind()} in if condition, '
'please consider decorating the corresponding class with "@basic_unit".')
expr = _generate_expr(cond_tensor)
return eval(expr)
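# Example sketch (the attribute name `flag` and its value are hypothetical): for Python code
# `if self.flag == 1:` TorchScript emits roughly
#   %a : int = prim::GetAttr[name="flag"](%self)
#   %b : int = prim::Constant[value=1]()
#   %cond : bool = aten::eq(%a, %b)
# and _generate_expr(%cond) builds the string '((1) == 1)' from the module's runtime value
# of `flag`, which eval() then turns into True.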
def handle_if_node(node):
"""
Parameters
----------
node : torch._C.Node
the node from TorchScript graph
Returns
-------
Node
the created node ir
"""
# for now we only deal with a prim::If whose input is a constant or an attribute
# constant expressions will be supported in the future
inputs = [i for i in node.inputs()]
assert len(inputs) == 1
cond = handle_if_condition(inputs[0])
chosen_block = 0 if cond else 1
blocks = [block for block in node.blocks()]
assert len(blocks) == 2
last_block_node = None
for node in blocks[chosen_block].nodes():
last_block_node = handle_single_node(node)
self.global_seq += 1
new_node = ir_graph.add_node(build_full_name(module_name, 'noop_identity', self.global_seq), 'noop_identity')
self._add_edge(ir_graph, blocks[chosen_block].returnNode(), graph_inputs, node_index, new_node, output_remap)
last_block_node = new_node
return last_block_node
# ===================handle function call===================
def handle_function_callmethod(node):
# get and handle the first input, which should be an nn.Module
assert node.hasAttribute('name')
# NOTE: "forward__0" is hacky, LSTM instance is parsed to call forward__0 in torchscript
if node.s('name') in ['forward', 'forward__0']:
# node.inputsAt(0).type() is <class 'torch._C.ClassType'>
submodule_type_str = self._remove_mangle(node.inputsAt(0).type().str())
submodule = node.inputsAt(0).node()
assert submodule.kind() == 'prim::GetAttr'
assert submodule.hasAttribute('name')
submodule_name = submodule.s('name')
if submodule.inputsAt(0).debugName() == 'self':
# module is usually instantiated in __init__.
# when calling a module in forward,
# prim::GetAttr is used to obtain the module in torch script.
# therefore, we do this check for a module. example below:
# %25 : __torch__.xxx = prim::GetAttr[name="input_switch"](%self)
# %27 : Tensor = prim::CallMethod[name="forward"](%25, %out.1)
assert submodule_name in script_module._modules, "submodule_name: {} not in script_module {}".format(
submodule_name, script_module._modules.keys())
submodule_full_name = build_full_name(module_name, submodule_name)
submodule_python_name = build_python_name(module_python_name, submodule_name)
submodule_obj = getattr(module, submodule_name)
subgraph, sub_m_attrs = self._convert_module(script_module._modules[submodule_name],
submodule_obj,
submodule_full_name, submodule_python_name,
ir_model)
else:
# %8 : __torch__.nni.retiarii.model_apis.nn.___torch_mangle_37.ModuleList = prim::GetAttr[name="cells"](%self)
# %10 : __torch__.darts_model.Cell = prim::GetAttr[name="0"](%8)
# %s1.4 : Tensor = prim::CallMethod[name="forward"](%10, %4, %4)
if submodule.inputsAt(0).type().name() == 'ModuleList':
# handle ModuleList
predecessor = submodule.inputsAt(0).node()
module_name_space = [submodule_name]
while predecessor.inputsAt(0).debugName() != 'self':
# this is for dealing with nested ModuleList. below is an example
# %3 : __torch__.torch.nn.modules.container.___torch_mangle_0.ModuleList = prim::GetAttr[name="ops"](%self)
# %5 : __torch__.torch.nn.modules.container.ModuleList = prim::GetAttr[name="0"](%3)
# %7 : __torch__.torch.nn.modules.container.ModuleList = prim::GetAttr[name="1"](%3)
# %9 : __torch__.torch.nn.modules.container.ModuleList = prim::GetAttr[name="2"](%3)
# %11 : __torch__.torch.nn.modules.container.ModuleList = prim::GetAttr[name="3"](%3)
# %14 : __torch__.torch.nn.modules.linear.Linear = prim::GetAttr[name="0"](%5)
# %16 : __torch__.torch.nn.modules.linear.Linear = prim::GetAttr[name="1"](%5)
# %state.2 : Tensor = prim::CallMethod[name="forward"](%14, %x.1) # modulelist.py:18:24
# %state.4 : Tensor = prim::CallMethod[name="forward"](%16, %state.2) # modulelist.py:18:24
assert predecessor.kind() == 'prim::GetAttr'
module_name_space.append(predecessor.s('name'))
predecessor = predecessor.inputsAt(0).node()
assert predecessor.kind() == 'prim::GetAttr'
assert predecessor.hasAttribute('name')
module_name_space.append(predecessor.s('name'))
submodule_full_name = build_full_name(module_name, list(reversed(module_name_space)))
submodule_python_name = build_python_name(module_python_name, list(reversed(module_name_space)))
submodule_obj = module
script_submodule = script_module
for each_name in list(reversed(module_name_space)):
submodule_obj = getattr(submodule_obj, each_name)
script_submodule = script_submodule._modules[each_name]
subgraph, sub_m_attrs = self._convert_module(script_submodule, submodule_obj, submodule_full_name,
submodule_python_name, ir_model)
else:
raise RuntimeError('Unsupported module case: {}'.format(submodule.inputsAt(0).type().str()))
if submodule_full_name in shared_module_index:
# this module is invoked more than once, the ir node has already been created
# create a reference node for it.
# example: {"name": "conv2", "operation": {"type": "shared", "parameters": {"reference": "conv1"}}}
self.global_seq += 1
shared_node_name = build_full_name(submodule_full_name, '', self.global_seq)
shared_node_python_name = build_python_name(submodule_python_name, self.global_seq)
shared_type_operation = Operation.new('shared', {'reference': submodule_full_name})
subcell = ir_graph.add_node(shared_node_name, shared_type_operation)
subcell.python_name = shared_node_python_name
else:
# this module is processed for the first time, build cell for it
if subgraph is None:
# if we do not parse this module's graph, we create Node for this module
subcell = ir_graph.add_node(submodule_full_name, submodule_type_str, sub_m_attrs)
subcell.python_name = submodule_python_name
if isinstance(submodule_obj, Placeholder):
subcell.update_label(submodule_obj.label)
elif isinstance(submodule_obj, InputChoice):
subcell.update_label(sub_m_attrs['label'])
else:
# Graph already created, create Cell for it
new_cell = Cell(cell_name=submodule_full_name, parameters=sub_m_attrs)
subcell = ir_graph.add_node(submodule_full_name, new_cell)
subcell.python_name = submodule_python_name
shared_module_index[submodule_full_name] = subcell
node_index[node] = subcell
# connect the cell into graph
self._add_edge(ir_graph, node, graph_inputs, node_index, subcell, output_remap, ignore_first=True)
else:
# handle normal member function
assert hasattr(script_module, node.s('name'))
# TODO: support non member functions
assert node.inputsAt(0).debugName() == 'self'
script_method = getattr(script_module, node.s('name')) # <class 'torch._C.ScriptMethod'>
# step #1: generate graph ir for this method
method_ir_graph = Graph(model=ir_model, graph_id=-100, name='temp_graph', _internal=True)
self.handle_graph_nodes(script_module, script_method.graph, module,
module_name, module_python_name, ir_model, method_ir_graph, shared_module_index)
self.refine_graph(method_ir_graph)
# step #2: merge this graph to its module graph
for h_node in method_ir_graph.hidden_nodes:
h_node.graph = ir_graph
ir_graph.hidden_nodes.append(h_node)
for edge in method_ir_graph.edges:
edge.graph = ir_graph
if edge.head == method_ir_graph.input_node:
# this is a member method, 'self' is the first argument, thus +1
assert edge.head_slot is not None
_input = node.inputsAt(edge.head_slot + 1)
src_node, src_node_idx = self._add_edge_handle_source_node(_input, graph_inputs, ir_graph, output_remap, node_index)
edge.head = src_node
edge.head_slot = src_node_idx
if edge.tail == method_ir_graph.output_node:
# since the following nodes have not been created, skip this edge
# edge.head is the node that produces the output of this method
# TODO: check whether there could be multiple output nodes???
node_index[node] = edge.head
continue
ir_graph.edges.append(edge)
# ===================handle each single node===================
def handle_single_node(node):
"""
Parameters
----------
node : torch._C.Node
the node from TorchScript graph
Returns
-------
Node
the created node ir
"""
if node.kind() == 'prim::CallMethod':
handle_function_callmethod(node)
elif node.kind() == 'prim::CallFunction':
func_type_str = self._remove_mangle(node.inputsAt(0).type().str())
func = node.inputsAt(0).node()
assert func.kind() == 'prim::Constant'
assert func.hasAttribute('name')
func_name = func.s('name')
# create node for func
self.global_seq += 1
func_node = ir_graph.add_node(build_full_name(module_name, func_name, self.global_seq),
'{}.{}'.format(func_type_str, func_name))
func_python_name = build_python_name(module_python_name, func_name)
func_node.python_name = func_python_name
node_index[node] = func_node
self._add_edge(ir_graph, node, graph_inputs, node_index, func_node, output_remap, ignore_first=True)
elif node.kind() == 'prim::Constant':
new_node = self.create_prim_constant_node(ir_graph, node, module_name)
node_index[node] = new_node
elif node.kind() in ['prim::ListConstruct', 'prim::ListUnpack', 'prim::TupleConstruct', 'prim::TupleUnpack']:
self.global_seq += 1
prim_op_name = node.kind().split('::')[-1]
new_node = ir_graph.add_node(build_full_name(module_name, prim_op_name, self.global_seq), node.kind())
node_index[node] = new_node
self._add_edge(ir_graph, node, graph_inputs, node_index, new_node, output_remap)
elif node.kind() == 'prim::GetAttr':
node_type, attrs = self.handle_prim_attr_node(node, module)
self.global_seq += 1
new_node = ir_graph.add_node(build_full_name(module_name, OpTypeName.Attr, self.global_seq),
node_type, attrs)
node_index[node] = new_node
elif node.kind() == 'prim::If':
last_block_node = handle_if_node(node)
# a None last_block_node means there is no node in the chosen branch block
node_index[node] = last_block_node
elif node.kind() == 'prim::Loop':
# refer to https://gist.github.com/liuzhe-lz/90c35d9dd6fd7f3f32544940151ab186
raise RuntimeError('Loop is not supported yet!')
elif node.kind().startswith('prim::'):
self.global_seq += 1
prim_op_name = node.kind().replace('::', '__')
prim_node = ir_graph.add_node(build_full_name(module_name, prim_op_name, self.global_seq), node.kind())
node_index[node] = prim_node
self._add_edge(ir_graph, node, graph_inputs, node_index, prim_node, output_remap)
elif node.kind() == 'aten::append':
self.global_seq += 1
aten_op_name = node.kind().replace('::', '__')
aten_node = ir_graph.add_node(build_full_name(module_name, aten_op_name, self.global_seq), node.kind())
node_index[node] = aten_node
self._add_edge(ir_graph, node, graph_inputs, node_index, aten_node, output_remap)
output_remap[node.inputsAt(0)] = node
elif node.kind().startswith('aten::'):
# handle aten::XXX
self.global_seq += 1
aten_op_name = node.kind().replace('::', '__')
aten_op_python_name = node.kind().replace('aten::', '')
aten_node = ir_graph.add_node(build_full_name(module_name, aten_op_name, self.global_seq), node.kind())
aten_python_name = build_python_name(module_python_name, aten_op_python_name)
aten_node.python_name = aten_python_name
node_index[node] = aten_node
self._add_edge(ir_graph, node, graph_inputs, node_index, aten_node, output_remap)
else:
raise RuntimeError('Unsupported kind: {}'.format(node.kind()))
return node_index[node]
for node in sm_graph.nodes():
handle_single_node(node)
if node_index != {}:
for _output in sm_graph.outputs():
ir_graph._add_output(_convert_name(_output.debugName()))
predecessor_node_outputs = [o for o in _output.node().outputs()]
if len(predecessor_node_outputs) == 1:
src_node_idx = None
else:
src_node_idx = predecessor_node_outputs.index(_output)
ir_graph.add_edge(head=(node_index[_output.node()], src_node_idx),
tail=(ir_graph.output_node, None))
else:
# here is an example where ir_graph and node_index are empty
# graph(%self : __torch__.torchmodels.googlenet.GoogLeNet,
# %x.1 : Tensor): return (%x.1)
# add an edge from head to tail to handle this situation
ir_graph.add_edge(head=(ir_graph.input_node, 0), tail=(ir_graph.output_node, None))
def merge_aten_slices(self, ir_graph):
"""
if there are aten::slice nodes, merge the consecutive ones together.
``x[:, :, 1:, 1:]`` in python code is converted into 4 nodes in torch script,
each node has 5 inputs: tensor, dim, x, y, z (i.e., x:y:z)
"""
head_slice_nodes = []
has_slice_node = False
for node in ir_graph.hidden_nodes:
if node.operation.type == 'aten::slice':
has_slice_node = True
for pred in node.predecessors:
if pred.operation.type not in ['aten::slice', 'prim::Constant']:
head_slice_nodes.append(node)
break
if has_slice_node:
assert head_slice_nodes
for head_node in head_slice_nodes:
slot = 0
new_slice_node = ir_graph.add_node(build_full_name(head_node.name, 'merged'), OpTypeName.MergedSlice)
if len(head_node.incoming_edges) == 4:
# when the slice is on a one-dimensional list, there are only 4 inputs, thus no merge is needed
for edge in head_node.incoming_edges:
edge.tail = new_slice_node
for edge in head_node.outgoing_edges:
edge.head = new_slice_node
ir_graph.hidden_nodes.remove(head_node)
break
assert len(head_node.incoming_edges) == 5
for edge in head_node.incoming_edges:
edge.tail = new_slice_node
slot += 5
node = head_node
while len(node.successors) == 1 and node.successors[0].operation.type == 'aten::slice':
suc_node = node.successors[0]
assert len(suc_node.incoming_edges) == 5
for edge in suc_node.incoming_edges:
if edge.tail_slot == 0:
edge.remove()
else:
edge.tail = new_slice_node
edge.tail_slot = slot + edge.tail_slot - 1
slot += 4
ir_graph.hidden_nodes.remove(node)
node = suc_node
for edge in node.outgoing_edges:
edge.head = new_slice_node
ir_graph.hidden_nodes.remove(node)
def refine_graph(self, ir_graph):
"""
Do the following to simplify the graph:
1. remove unconnected constant nodes
2. remove unconnected getattr nodes
3. merge consecutive aten::slice nodes
"""
# some constants are not used, for example, a function name stored as prim::Constant
self.remove_unconnected_nodes(ir_graph, targeted_type='prim::Constant')
self.remove_unconnected_nodes(ir_graph, targeted_type='prim::GetAttr')
self.merge_aten_slices(ir_graph)
def _handle_inputchoice(self, module):
return {
'n_candidates': module.n_candidates,
'n_chosen': module.n_chosen,
'reduction': module.reduction,
'label': module.label
}
def _handle_valuechoice(self, module):
return {
'candidates': module.candidates,
'label': module.label,
}
def _convert_module(self, script_module, module, module_name, module_python_name, ir_model):
# NOTE: nested LayerChoice is not supported yet, i.e., a candidate module
# that itself contains LayerChoice, InputChoice, or ValueChoice
original_type_name = script_module.original_name
m_attrs = None
if original_type_name == OpTypeName.LayerChoice:
graph = Graph(ir_model, -100, module_name, _internal=True) # graph_id is not used now
graph.python_name = module_python_name
candidate_name_list = []
for cand_name in module.names:
cand = module[cand_name]
script_cand = script_module._modules[cand_name]
cand_full_name = build_cand_name(cand_name, module.label)
cand_python_name = build_python_name(module_python_name, cand_name)
candidate_name_list.append(cand_full_name)
subgraph, attrs = self._convert_module(script_cand, cand, cand_full_name, cand_python_name, ir_model)
if subgraph is not None:
cand_node = graph.add_node(subgraph.name, Cell(cell_name=subgraph.name, parameters=attrs))
cand_node.python_name = cand_python_name
else:
cand_type = '__torch__.' + get_importable_name(cand.__class__)
cand_node = graph.add_node(cand_full_name, cand_type, attrs)
cand_node.python_name = cand_python_name
graph._register()
return graph, {'mutation': 'layerchoice', 'label': module.label, 'candidates': candidate_name_list}
elif original_type_name == OpTypeName.InputChoice:
m_attrs = self._handle_inputchoice(module)
elif original_type_name == OpTypeName.ValueChoice:
m_attrs = self._handle_valuechoice(module)
elif original_type_name == OpTypeName.Placeholder:
m_attrs = get_init_parameters_or_fail(module)
elif module.__class__.__module__.startswith('torch.nn') and \
original_type_name in torch.nn.__dict__ and \
original_type_name not in MODULE_EXCEPT_LIST:
# this is a basic module from pytorch, no need to parse its graph
m_attrs = get_init_parameters_or_fail(module)
elif getattr(module, '_nni_basic_unit', False):
# this module is marked as a basic unit, so we do not parse it further
m_attrs = get_init_parameters_or_fail(module)
if m_attrs is not None:
return None, m_attrs
# handle TorchScript graph
sm_graph = script_module.graph
self.global_graph_id += 1
ir_graph = Graph(model=ir_model, graph_id=self.global_graph_id, name=module_name, _internal=True)
ir_graph.python_name = module_python_name
# handle graph nodes
self.handle_graph_nodes(script_module, sm_graph, module,
module_name, module_python_name, ir_model, ir_graph)
self.refine_graph(ir_graph)
ir_graph._register()
# add mutation signal for special modules
if original_type_name == OpTypeName.Repeat:
attrs = {
'mutation': 'repeat',
'label': module.label,
'depth': module.depth_choice,
'max_depth': module.max_depth,
'min_depth': module.min_depth,
}
return ir_graph, attrs
return ir_graph, {}
def convert_module(self, script_module, module, module_name, ir_model):
"""
Convert a module to its graph ir (i.e., Graph) along with its input arguments
Parameters
----------
script_module : torch.jit.RecursiveScriptModule
the script module of ``module`` obtained with torch.jit.script
module : nn.Module
the targeted module instance
module_name : str
the constructed name space of ``module``
ir_model : Model
the whole graph ir
Returns
-------
Graph
the graph ir built from ``module``; ``None`` means the module is not parsed further
dict
the input arguments of this module
"""
return self._convert_module(script_module, module, module_name, None, ir_model)
class GraphConverterWithShape(GraphConverter):
"""
Convert a pytorch model to nni ir along with input/output shape info.
The base ir is acquired through ``torch.jit.script``,
and the shape info is acquired through ``torch.jit.trace``.
.. warning::
Known issues:
1. ``InputChoice`` and ``ValueChoice`` not supported yet.
2. Currently random inputs are fed while tracing layerchoice.
If the forward path of a candidate depends on the input data, the wrong path will be traced.
This will result in incomplete shape info.
"""
def convert_module(self, script_module, module, module_name, ir_model, dummy_input):
module.eval()
ir_graph, attrs = self._convert_module(script_module, module, module_name, None, ir_model)
self.remove_dummy_nodes(ir_model)
self._initialize_parameters(ir_model)
self._trace_module(module, module_name, ir_model, dummy_input)
return ir_graph, attrs
def _initialize_parameters(self, ir_model: 'Model'):
for ir_node in ir_model.get_nodes():
if ir_node.operation.parameters is None:
ir_node.operation.parameters = {}
ir_node.operation.attributes.setdefault('input_shape', [])
ir_node.operation.attributes.setdefault('output_shape', [])
def _trace_module(self, module, module_name, ir_model: 'Model', dummy_input):
# First, trace the whole graph
tm_graph = self._trace(module, dummy_input)
for node in tm_graph.nodes():
shape_parameters, parameters = _extract_info_from_trace_node(node)
# '__module.convpool/__module.convpool.1/__module.convpool.1.conv'
ir_node = match_node(ir_model, node, module_name)
if ir_node is not None:
ir_node.operation.attributes.update(shape_parameters)
if parameters:
ir_node.operation.parameters.update(parameters)
self.propagate_shape(ir_model)
# trace each layerchoice
for name, submodule in module.named_modules():
# TODO: support InputChoice and ValueChoice
if isinstance(submodule, LayerChoice):
full_name = get_full_name_by_scope_name(ir_model, name.split('.'), module_name)
lc_node = ir_model.get_node_by_name(full_name)
assert lc_node is not None, f'Cannot find a node with name {full_name}'
for cand_name in submodule.names:
cand = submodule[cand_name]
cand_name = build_cand_name(cand_name, submodule.label)
# TODO: Feed the exact input tensor if user provides input,
# in case the path changes according to input data.
lc_inputs = [torch.randn(shape) for shape in lc_node.operation.attributes['input_shape']]
self._trace_module(cand, cand_name, ir_model, lc_inputs)
def propagate_shape(self, ir_model: 'Model'):
def propagate_shape_for_graph(graph: 'Graph'):
if graph == ir_model.root_graph:
return
graph_node = ir_model.get_node_by_name(graph.name)
assert graph_node is not None, f'Cannot find a node with name {graph.name}'
if not _without_shape_info(graph_node):
return
if is_layerchoice_node(graph_node):
cand_name = graph_node.operation.parameters['candidates'][0]
cand_node = ir_model.get_node_by_name(cand_name)
assert cand_node is not None, f'Cannot find a node with name {cand_name}'
if _without_shape_info(cand_node):
propagate_shape_for_graph(ir_model.graphs[cand_name])
graph_node.operation.attributes['input_shape'] = cand_node.operation.attributes['input_shape']
graph_node.operation.attributes['output_shape'] = cand_node.operation.attributes['output_shape']
else:
input_shape = [[]] * len(graph.input_node.operation.io_names or [])
output_shape = [[]] * len(graph.output_node.operation.io_names or [])
for edge in graph.input_node.outgoing_edges:
node = edge.tail
if _without_shape_info(node):
if node.name in ir_model.graphs:
propagate_shape_for_graph(ir_model.graphs[node.name])
if node.operation.attributes['input_shape']:
input_shape[edge.head_slot or 0] = node.operation.attributes['input_shape'][edge.tail_slot or 0]
graph_node.operation.attributes['input_shape'] = input_shape
for edge in graph.output_node.incoming_edges:
node = edge.head
if _without_shape_info(node):
if node.name in ir_model.graphs:
propagate_shape_for_graph(ir_model.graphs[node.name])
if node.operation.attributes['output_shape']:
output_shape[edge.tail_slot or 0] = node.operation.attributes['output_shape'][edge.head_slot or 0]
graph_node.operation.attributes['output_shape'] = output_shape
propagate_shape_for_graph(graph_node.graph)
# propagate from node to graph
for node in ir_model.get_nodes():
propagate_shape_for_graph(node.graph)
def _trace(self, module, dummy_input):
traced_module = torch.jit.trace(module, dummy_input)
torch._C._jit_pass_inline(traced_module.graph)
return traced_module.graph
def remove_dummy_nodes(self, ir_model: 'Model'):
# remove identity nodes
for node in ir_model.get_nodes_by_type('noop_identity'):
graph = node.graph
for in_edge in node.incoming_edges:
for out_edge in node.outgoing_edges:
if in_edge.tail_slot == out_edge.head_slot:
graph.add_edge(head=(in_edge.head, in_edge.head_slot), tail=(out_edge.tail, out_edge.tail_slot))
graph.del_edge(in_edge)
graph.del_edge(out_edge)
break
node.remove()
def convert_to_graph(script_module, module, converter=None, **kwargs):
"""
Convert module to our graph ir, i.e., build a :class:`Model` type
Parameters
----------
script_module : torch.jit.RecursiveScriptModule
the script module obtained with torch.jit.script
module : nn.Module
the targeted module instance
converter : `TorchConverter`
default `GraphConverter` is used
kwargs:
will be passed to `converter.convert_module()`
Returns
-------
Model
the constructed IR model
"""
model = Model(_internal=True)
module_name = '_model'
if converter is None:
converter = GraphConverter()
converter.convert_module(script_module, module, module_name, model, **kwargs)
return model
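# A minimal usage sketch (the model class `Net` and the dummy input shape are hypothetical):
#   net = Net()
#   script_module = torch.jit.script(net)
#   ir_model = convert_to_graph(script_module, net)
# To also record shape info, pass the shape-aware converter together with a dummy input:
#   ir_model = convert_to_graph(script_module, net,
#                               converter=GraphConverterWithShape(),
#                               dummy_input=torch.randn(1, 3, 32, 32))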
from nni.nas.execution.pytorch.converter.graph_gen import *
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from enum import Enum
# pylint: disable=wildcard-import,unused-wildcard-import
# except the special cases that cannot be treated as a basic module from pytorch
MODULE_EXCEPT_LIST = ['Sequential']
class OpTypeName(str, Enum):
"""
mapping from op type to its type name string
"""
Attr = 'Attr'
Constant = 'Constant'
LayerChoice = 'LayerChoice'
InputChoice = 'InputChoice'
ValueChoice = 'ValueChoice'
Placeholder = 'Placeholder'
MergedSlice = 'MergedSlice'
Repeat = 'Repeat'
Cell = 'Cell'
from nni.nas.execution.pytorch.converter.op_types import *
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from typing import Optional
from typing_extensions import TypeGuard
from ..operation import Cell
from ..graph import Model, Graph, Node, Edge
def build_full_name(prefix, name, seq=None):
if isinstance(name, list):
name = '__'.join(name)
if seq is None:
return '{}__{}'.format(prefix, name)
else:
return '{}__{}{}'.format(prefix, name, str(seq))
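# For example: build_full_name('_model', 'conv', 3) -> '_model__conv3'
#              build_full_name('_model', ['cells', '0']) -> '_model__cells__0'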
def build_python_name(prefix, name):
if isinstance(name, list):
name = '.'.join(name)
if prefix:
return '{}.{}'.format(prefix, name)
else: # prefix could be None
return name
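# For example: build_python_name('_model', ['cells', '0']) -> '_model.cells.0'
#              build_python_name(None, 'conv') -> 'conv'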
def build_cand_name(name, label):
return f'layerchoice_{label}_{name}'
def _convert_name(name: str) -> str:
"""
Convert a name that uses '.' as the separator into a valid variable name in code
"""
return name.replace('.', '__')
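# For example, 'features.0.conv' becomes 'features__0__conv'.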
def _extract_info_from_trace_node(trace_node):
"""
Extract parameters from a trace node.
Parameters
----------
trace_node: torch._C.Node
"""
input_shape = []
output_shape = []
inputs = list(trace_node.inputs())
# the input tensors of aten::cat are held by its first input's node (a list construct)
if trace_node.kind() == 'aten::cat':
input_shape = [input.type().sizes() for input in inputs[0].node().inputs()]
else:
for _input in inputs:
input_type = _input.type()
if input_type.kind() == 'TensorType':
shape = input_type.sizes()
if shape:
input_shape.append(shape)
for _output in trace_node.outputs():
output_type = _output.type()
if output_type.kind() == 'TensorType':
shape = output_type.sizes()
if shape:
output_shape.append(shape)
shape_parameters = {
'input_shape': input_shape,
'output_shape': output_shape,
}
if trace_node.kind() == 'aten::cat':
parameters = {'dim': inputs[1].toIValue()}
return shape_parameters, parameters
else:
return shape_parameters, None
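# Example sketch (the shapes are hypothetical): for a traced node with a single tensor input
# of shape (1, 3, 32, 32) and a single tensor output of the same shape (e.g. aten::relu),
# the returned shape_parameters is
#   {'input_shape': [[1, 3, 32, 32]], 'output_shape': [[1, 3, 32, 32]]}
# and parameters is None; parameters is only filled for aten::cat, where it carries ``dim``.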
def is_layerchoice_node(ir_node: Optional[Node]) -> TypeGuard[Node]:
if ir_node is not None and isinstance(ir_node.operation, Cell) and ir_node.operation.parameters.get('mutation') == 'layerchoice':
return True
else:
return False
def get_full_name_by_scope_name(ir_model: Model, scope_names, prefix=''):
full_name = prefix
for last_scope in range(len(scope_names)):
ir_node = ir_model.get_node_by_name(full_name)
# check if it's layerchoice
if is_layerchoice_node(ir_node):
full_name = f'layerchoice_{ir_node.operation.parameters["label"]}_{scope_names[last_scope]}'
else:
full_name = build_full_name(full_name, scope_names[last_scope])
return full_name
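# Example sketch (the scope names are hypothetical): with scope_names ['convpool', '1', 'conv']
# and prefix '_model', this walks '_model' -> '_model__convpool' -> '_model__convpool__1'
# -> '_model__convpool__1__conv', switching to a 'layerchoice_<label>_<name>' style name
# whenever the node found so far is a layerchoice node.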
def match_node(ir_model: Model, torch_node, prefix=''):
"""
Match the ir node corresponding to a torch._C.Node
"""
scope_names = torch_node.scopeName().split('/')[-1].split('.')[1:]
full_name = get_full_name_by_scope_name(ir_model, scope_names, prefix)
# handle the case when the node is not an nn.Module but is used directly in forward()
# because the name can't be matched directly, we use a hacky way:
# match the first node of that kind that has no shape info yet
graph = ir_model.graphs.get(full_name)
if graph is not None:
for node in graph.get_nodes_by_type(torch_node.kind()):
if not node.operation.attributes['input_shape']:
return node
return None
else:
return ir_model.get_node_by_name(full_name)
def _without_shape_info(node: Node):
return not node.operation.attributes['input_shape'] and not node.operation.attributes['output_shape']
def flatten_model_graph(ir_model: Model):
"""
Flatten all subgraphs into the root graph.
"""
def _flatten(graph: Graph):
"""
flatten this graph
"""
model = graph.model
node_to_remove = []
for node in graph.hidden_nodes:
node_graph = model.graphs.get(node.name)
if node_graph is not None:
_flatten(node_graph)
# flatten node graph into this graph
id_to_new_node = {}
for node_graph_node in node_graph.hidden_nodes:
new_node = Node(graph, node_graph_node.id, node_graph_node.name, node_graph_node.operation, _internal=True)
new_node.update_label(node_graph_node.label)
new_node._register()
id_to_new_node[new_node.id] = new_node
# reconnect node edges
for in_edge in node.incoming_edges:
graph.del_edge(in_edge)
for input_node_edge in node_graph.input_node.outgoing_edges:
if input_node_edge.head_slot == in_edge.tail_slot:
graph.add_edge(
head=(in_edge.head, in_edge.head_slot),
tail=(id_to_new_node[input_node_edge.tail.id], input_node_edge.tail_slot))
for out_edge in node.outgoing_edges:
graph.del_edge(out_edge)
for output_node_edge in node_graph.output_node.incoming_edges:
if output_node_edge.head_slot == out_edge.tail_slot:
graph.add_edge(
head=(id_to_new_node[output_node_edge.head.id], output_node_edge.head_slot),
tail=(out_edge.tail, out_edge.tail_slot))
for edge in node_graph.edges:
if edge.head == node_graph.input_node or edge.tail == node_graph.output_node:
continue
new_head = id_to_new_node[edge.head.id]
new_tail = id_to_new_node[edge.tail.id]
Edge((new_head, edge.head_slot), (new_tail, edge.tail_slot), _internal=True)._register()
node_to_remove.append(node)
del model.graphs[node.name]
for node in node_to_remove:
node.remove()
new_ir_model = ir_model.fork()
_flatten(new_ir_model.root_graph)
# remove subgraphs
new_ir_model.graphs = {new_ir_model._root_graph_name: new_ir_model.root_graph}
return new_ir_model
def flatten_model_graph_without_layerchoice(ir_model: Model):
"""
Flatten all subgraphs into the root graph, skipping all layerchoice nodes.
"""
def _flatten_without_layerchoice(graph: Graph):
"""
flatten this graph
"""
model = graph.model
node_to_remove = []
for node in graph.hidden_nodes:
if is_layerchoice_node(node):
for in_edge in node.incoming_edges:
graph.del_edge(in_edge)
for out_edge in node.outgoing_edges:
graph.del_edge(out_edge)
del model.graphs[node.name]
node.remove()
return
node_graph = model.graphs.get(node.name)
if node_graph is not None:
_flatten_without_layerchoice(node_graph)
# flatten node graph into this graph
id_to_new_node = {}
for node_graph_node in node_graph.hidden_nodes:
new_node = Node(graph, node_graph_node.id, node_graph_node.name, node_graph_node.operation, _internal=True)
new_node.update_label(node_graph_node.label)
new_node._register()
id_to_new_node[new_node.id] = new_node
# reconnect node edges
for in_edge in node.incoming_edges:
graph.del_edge(in_edge)
for input_node_edge in node_graph.input_node.outgoing_edges:
if input_node_edge.head_slot == in_edge.tail_slot:
graph.add_edge(
head=(in_edge.head, in_edge.head_slot),
tail=(id_to_new_node[input_node_edge.tail.id], input_node_edge.tail_slot))
for out_edge in node.outgoing_edges:
graph.del_edge(out_edge)
for output_node_edge in node_graph.output_node.incoming_edges:
if output_node_edge.head_slot == out_edge.tail_slot:
graph.add_edge(
head=(id_to_new_node[output_node_edge.head.id], output_node_edge.head_slot),
tail=(out_edge.tail, out_edge.tail_slot))
for edge in node_graph.edges:
if edge.head == node_graph.input_node or edge.tail == node_graph.output_node:
continue
new_head = id_to_new_node[edge.head.id]
new_tail = id_to_new_node[edge.tail.id]
Edge((new_head, edge.head_slot), (new_tail, edge.tail_slot), _internal=True)._register()
node_to_remove.append(node)
del model.graphs[node.name]
for node in node_to_remove:
node.remove()
new_ir_model = ir_model.fork()
_flatten_without_layerchoice(new_ir_model.root_graph)
# remove subgraphs
new_ir_model.graphs = {new_ir_model._root_graph_name: new_ir_model.root_graph}
return new_ir_model
# pylint: disable=wildcard-import,unused-wildcard-import
from nni.nas.execution.pytorch.converter.utils import *
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import graphviz
# pylint: disable=wildcard-import,unused-wildcard-import
def convert_to_visualize(graph_ir, vgraph):
for name, graph in graph_ir.items():
if name == '_evaluator':
continue
with vgraph.subgraph(name='cluster'+name) as subgraph:
subgraph.attr(color='blue')
cell_node = {}
ioput = {'_inputs': '{}-{}'.format(name, '_'.join(graph['inputs'])),
'_outputs': '{}-{}'.format(name, '_'.join(graph['outputs']))}
subgraph.node(ioput['_inputs'])
subgraph.node(ioput['_outputs'])
for node_name, node_value in graph['nodes'].items():
value = node_value['operation']
if value['type'] == '_cell':
cell_input_name = '{}-{}'.format(value['cell_name'], '_'.join(graph_ir[value['cell_name']]['inputs']))
cell_output_name = '{}-{}'.format(value['cell_name'], '_'.join(graph_ir[value['cell_name']]['outputs']))
cell_node[node_name] = (cell_input_name, cell_output_name)
print('cell: ', node_name, cell_input_name, cell_output_name)
else:
subgraph.node(node_name)
for edge in graph['edges']:
src = edge['head'][0]
if src == '_inputs':
src = ioput['_inputs']
elif src in cell_node:
src = cell_node[src][1]
dst = edge['tail'][0]
if dst == '_outputs':
dst = ioput['_outputs']
elif dst in cell_node:
dst = cell_node[dst][0]
subgraph.edge(src, dst)
def visualize_model(graph_ir):
vgraph = graphviz.Digraph('G', filename='vgraph', format='jpg')
convert_to_visualize(graph_ir, vgraph)
vgraph.render()
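# A minimal usage sketch (assuming `graph_ir` is the JSON-style graph ir dict dumped from a
# converted Model; the variable name is hypothetical):
#   visualize_model(graph_ir)   # writes the graphviz source to 'vgraph' and renders 'vgraph.jpg'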
from nni.nas.execution.pytorch.converter.visualize import *