Unverified Commit 59cd3982 authored by Yuge Zhang, committed by GitHub

[Retiarii] Coding style improvements for pylint and flake8 (#3190)

parent 593a275c
import inspect
import logging
+from typing import Any, List
import torch
import torch.nn as nn
-from typing import (Any, Tuple, List, Optional)
from ...utils import add_record
......@@ -10,7 +11,7 @@ _logger = logging.getLogger(__name__)
__all__ = [
'LayerChoice', 'InputChoice', 'Placeholder',
-'Module', 'Sequential', 'ModuleList', # TODO: 'ModuleDict', 'ParameterList', 'ParameterDict',
+'Module', 'Sequential', 'ModuleList',  # TODO: 'ModuleDict', 'ParameterList', 'ParameterDict',
'Identity', 'Linear', 'Conv1d', 'Conv2d', 'Conv3d', 'ConvTranspose1d',
'ConvTranspose2d', 'ConvTranspose3d', 'Threshold', 'ReLU', 'Hardtanh', 'ReLU6',
'Sigmoid', 'Tanh', 'Softmax', 'Softmax2d', 'LogSoftmax', 'ELU', 'SELU', 'CELU', 'GLU', 'GELU', 'Hardshrink',
......@@ -30,7 +31,7 @@ __all__ = [
'TransformerEncoderLayer', 'TransformerDecoderLayer', 'Transformer',
#'LazyLinear', 'LazyConv1d', 'LazyConv2d', 'LazyConv3d',
#'LazyConvTranspose1d', 'LazyConvTranspose2d', 'LazyConvTranspose3d',
-#'Unflatten', 'SiLU', 'TripletMarginWithDistanceLoss', 'ChannelShuffle',
+# 'Unflatten', 'SiLU', 'TripletMarginWithDistanceLoss', 'ChannelShuffle',
'Flatten', 'Hardsigmoid', 'Hardswish'
]
......@@ -57,9 +58,10 @@ class InputChoice(nn.Module):
if n_candidates or choose_from or return_mask:
_logger.warning('input arguments `n_candidates`, `choose_from` and `return_mask` are deprecated!')
-def forward(self, candidate_inputs: List['Tensor']) -> 'Tensor':
+def forward(self, candidate_inputs: List[torch.Tensor]) -> torch.Tensor:
# fake return
-return torch.tensor(candidate_inputs)
+return torch.tensor(candidate_inputs)  # pylint: disable=not-callable
class ValueChoice:
"""
......@@ -67,6 +69,7 @@ class ValueChoice:
when instantiating a pytorch module.
TODO: can also be used in training approach
"""
def __init__(self, candidate_values: List[Any]):
self.candidate_values = candidate_values
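# Illustrative sketch, not part of this diff: a ValueChoice only stores its
# candidate list and would be passed where a concrete argument is expected, e.g.
#   self.conv = nn.Conv2d(3, ValueChoice([16, 32, 64]), kernel_size=3)
# so that a search strategy can later substitute one of the candidate values.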
......@@ -81,6 +84,7 @@ class Placeholder(nn.Module):
def forward(self, x):
return x
class ChosenInputs(nn.Module):
def __init__(self, chosen: int):
super().__init__()
......@@ -92,20 +96,24 @@ class ChosenInputs(nn.Module):
# the following are pytorch modules
class Module(nn.Module):
def __init__(self):
super(Module, self).__init__()
class Sequential(nn.Sequential):
def __init__(self, *args):
add_record(id(self), {})
super(Sequential, self).__init__(*args)
class ModuleList(nn.ModuleList):
def __init__(self, *args):
add_record(id(self), {})
super(ModuleList, self).__init__(*args)
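# Note: Sequential and ModuleList register an empty argument record on
# construction (Module does not); presumably the record lets the graph
# converter recognize these containers later. Runtime behavior is unchanged
# from the torch.nn originals.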
def wrap_module(original_class):
orig_init = original_class.__init__
argname_list = list(inspect.signature(original_class).parameters.keys())
......@@ -115,14 +123,15 @@ def wrap_module(original_class):
full_args = {}
full_args.update(kws)
for i, arg in enumerate(args):
-full_args[argname_list[i]] = args[i]
+full_args[argname_list[i]] = arg
add_record(id(self), full_args)
-orig_init(self, *args, **kws) # Call the original __init__
+orig_init(self, *args, **kws)  # Call the original __init__
-original_class.__init__ = __init__ # Set the class' __init__ to the new one
+original_class.__init__ = __init__  # Set the class' __init__ to the new one
return original_class
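# For illustration: once nn.Linear is wrapped below, Linear(10, 20) records
# {'in_features': 10, 'out_features': 20} under id(self) via add_record,
# because inspect.signature maps the positional arguments to parameter names.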
# TODO: support different versions of pytorch
Identity = wrap_module(nn.Identity)
Linear = wrap_module(nn.Linear)
......
......@@ -4,12 +4,14 @@ from . import debug_configs
__all__ = ['Operation', 'Cell']
def _convert_name(name: str) -> str:
"""
Convert names that use the '.' separator into valid variable names in generated code
"""
return name.replace('.', '__')
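# For example:
#   _convert_name('cells.0.conv') == 'cells__0__conv'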
class Operation:
"""
Calculation logic of a graph node.
......@@ -152,6 +154,7 @@ class PyTorchOperation(Operation):
else:
raise RuntimeError(f'unsupported operation type: {self.type} ? {self._to_class_name()}')
class TensorFlowOperation(Operation):
def _to_class_name(self) -> str:
return 'K.layers.' + self.type
......@@ -191,6 +194,7 @@ class Cell(PyTorchOperation):
framework
No real usage. Exists for compatibility with base class.
"""
def __init__(self, cell_name: str, parameters: Dict[str, Any] = {}):
self.type = '_cell'
self.cell_name = cell_name
......@@ -207,6 +211,7 @@ class _IOPseudoOperation(Operation):
The benefit is that users no longer need to verify `Node.operation is not None`,
especially in static type checking.
"""
def __init__(self, type_name: str, io_names: List = None):
assert type_name.startswith('_')
super(_IOPseudoOperation, self).__init__(type_name, {}, True)
......
from ..operation import TensorFlowOperation
class Conv2D(TensorFlowOperation):
def __init__(self, type_name, parameters, _internal):
if 'padding' not in parameters:
parameters['padding'] = 'same'
-super().__init__(type_name, parameters, _internal)
\ No newline at end of file
+super().__init__(type_name, parameters, _internal)
from ..operation import PyTorchOperation
class relu(PyTorchOperation):
def to_init_code(self, field):
return ''
......@@ -17,6 +18,7 @@ class Flatten(PyTorchOperation):
assert len(inputs) == 1
return f'{output} = {inputs[0]}.view({inputs[0]}.size(0), -1)'
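# For example, with output == 'out' and inputs == ['x0'], the emitted line is:
#   out = x0.view(x0.size(0), -1)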
class ToDevice(PyTorchOperation):
def to_init_code(self, field):
return ''
......
import abc
from typing import List
from ..graph import Model
from ..mutator import Mutator
class BaseStrategy(abc.ABC):
@abc.abstractmethod
-def run(self, base_model: 'Model', applied_mutators: List['Mutator']) -> None:
+def run(self, base_model: Model, applied_mutators: List[Mutator]) -> None:
pass
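# A minimal sketch of a concrete strategy (hypothetical; Mutator.apply is an
# assumption here, only submit_models appears elsewhere in this diff):
#
# class RandomStrategy(BaseStrategy):
#     def run(self, base_model: Model, applied_mutators: List[Mutator]) -> None:
#         model = base_model
#         for mutator in applied_mutators:
#             model = mutator.apply(model)
#         submit_models(model)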
import json
import logging
import random
import os
-from .. import Model, submit_models, wait_models
-from .. import Sampler
+from .. import Sampler, submit_models, wait_models
from .strategy import BaseStrategy
from ...algorithms.hpo.hyperopt_tuner.hyperopt_tuner import HyperoptTuner
_logger = logging.getLogger(__name__)
class TPESampler(Sampler):
def __init__(self, optimize_mode='minimize'):
self.tpe_tuner = HyperoptTuner('tpe', optimize_mode)
......@@ -37,6 +34,7 @@ class TPESampler(Sampler):
self.index += 1
return chosen
class TPEStrategy(BaseStrategy):
def __init__(self):
self.tpe_sampler = TPESampler()
......@@ -55,7 +53,7 @@ class TPEStrategy(BaseStrategy):
while True:
model = base_model
_logger.info('apply mutators...')
-_logger.info('mutators: {}'.format(applied_mutators))
+_logger.info('mutators: %s', str(applied_mutators))
self.tpe_sampler.generate_samples(self.model_id)
for mutator in applied_mutators:
_logger.info('mutate model...')
......@@ -66,6 +64,6 @@ class TPEStrategy(BaseStrategy):
wait_models(model)
self.tpe_sampler.receive_result(self.model_id, model.metric)
self.model_id += 1
-_logger.info('Strategy says:', model.metric)
-except Exception as e:
+_logger.info('Strategy says: %s', model.metric)
+except Exception:
_logger.error(logging.exception('message'))
import abc
import inspect
-from typing import *
+from typing import Any
class BaseTrainer(abc.ABC):
......
import abc
-from typing import *
+from typing import Any, List, Dict, Tuple
import numpy as np
import torch
......@@ -42,6 +41,7 @@ def get_default_transform(dataset: str) -> Any:
# unsupported dataset, return None
return None
@register_trainer()
class PyTorchImageClassificationTrainer(BaseTrainer):
"""
......@@ -94,7 +94,7 @@ class PyTorchImageClassificationTrainer(BaseTrainer):
self._dataloader = DataLoader(
self._dataset, **(dataloader_kwargs or {}))
-def _accuracy(self, input, target):
+def _accuracy(self, input, target):  # pylint: disable=redefined-builtin
_, predict = torch.max(input.data, 1)
correct = predict.eq(target.data).cpu().sum().item()
return correct / input.size(0)
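# input is a [batch, num_classes] score tensor; e.g. if 6 of 8 argmax
# predictions match target, this returns 0.75.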
......@@ -176,7 +176,7 @@ class PyTorchMultiModelTrainer(BaseTrainer):
dataloader = DataLoader(dataset, **(dataloader_kwargs or {}))
self._datasets.append(dataset)
self._dataloaders.append(dataloader)
if m['use_output']:
optimizer_cls = m['optimizer_cls']
optimizer_kwargs = m['optimizer_kwargs']
......@@ -186,7 +186,7 @@ class PyTorchMultiModelTrainer(BaseTrainer):
name_prefix = '_'.join(name.split('_')[:2])
if m_header == name_prefix:
one_model_params.append(param)
optimizer = getattr(torch.optim, optimizer_cls)(one_model_params, **(optimizer_kwargs or {}))
self._optimizers.append(optimizer)
......@@ -206,7 +206,7 @@ class PyTorchMultiModelTrainer(BaseTrainer):
x, y = self.training_step_before_model(batch, batch_idx, f'cuda:{idx}')
xs.append(x)
ys.append(y)
y_hats = self.multi_model(*xs)
if len(ys) != len(xs):
raise ValueError('len(ys) should be equal to len(xs)')
......@@ -230,13 +230,12 @@ class PyTorchMultiModelTrainer(BaseTrainer):
if self.max_steps and batch_idx >= self.max_steps:
return
def training_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> Dict[str, Any]:
x, y = self.training_step_before_model(batch, batch_idx)
y_hat = self.model(x)
return self.training_step_after_model(x, y, y_hat)
-def training_step_before_model(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int, device = None):
+def training_step_before_model(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int, device=None):
x, y = batch
if device:
x, y = x.cuda(torch.device(device)), y.cuda(torch.device(device))
......@@ -259,4 +258,4 @@ class PyTorchMultiModelTrainer(BaseTrainer):
def validation_step_after_model(self, x, y, y_hat):
acc = self._accuracy(y_hat, y)
-return {'val_acc': acc}
\ No newline at end of file
+return {'val_acc': acc}
......@@ -6,7 +6,6 @@ import logging
import torch
import torch.nn as nn
from nni.nas.pytorch.mutables import LayerChoice
from ..interface import BaseOneShotTrainer
from .utils import AverageMeterGroup, replace_layer_choice, replace_input_choice
......
......@@ -86,8 +86,8 @@ class ReinforceController(nn.Module):
self.attn_query = nn.Linear(self.lstm_size, self.lstm_size, bias=False)
self.v_attn = nn.Linear(self.lstm_size, 1, bias=False)
self.g_emb = nn.Parameter(torch.randn(1, self.lstm_size) * 0.1)
-self.skip_targets = nn.Parameter(torch.tensor([1.0 - self.skip_target, self.skip_target]),
-requires_grad=False)  # pylint: disable=not-callable
+self.skip_targets = nn.Parameter(torch.tensor([1.0 - self.skip_target, self.skip_target]),  # pylint: disable=not-callable
+requires_grad=False)
assert entropy_reduction in ['sum', 'mean'], 'Entropy reduction must be one of sum and mean.'
self.entropy_reduction = torch.sum if entropy_reduction == 'sum' else torch.mean
self.cross_entropy_loss = nn.CrossEntropyLoss(reduction='none')
......
......@@ -16,7 +16,7 @@ _logger = logging.getLogger(__name__)
def _get_mask(sampled, total):
multihot = [i == sampled or (isinstance(sampled, list) and i in sampled) for i in range(total)]
-return torch.tensor(multihot, dtype=torch.bool)
+return torch.tensor(multihot, dtype=torch.bool)  # pylint: disable=not-callable
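# For example:
#   _get_mask(1, 3)      -> tensor([False, True, False])
#   _get_mask([0, 2], 4) -> tensor([True, False, True, False])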
class PathSamplingLayerChoice(nn.Module):
......@@ -44,9 +44,9 @@ class PathSamplingLayerChoice(nn.Module):
def forward(self, *args, **kwargs):
assert self.sampled is not None, 'At least one path needs to be sampled before fprop.'
if isinstance(self.sampled, list):
-return sum([getattr(self, self.op_names[i])(*args, **kwargs) for i in self.sampled])
+return sum([getattr(self, self.op_names[i])(*args, **kwargs) for i in self.sampled])  # pylint: disable=not-an-iterable
else:
-return getattr(self, self.op_names[self.sampled])(*args, **kwargs)
+return getattr(self, self.op_names[self.sampled])(*args, **kwargs)  # pylint: disable=invalid-sequence-index
def __len__(self):
return len(self.op_names)
......@@ -76,7 +76,7 @@ class PathSamplingInputChoice(nn.Module):
def forward(self, input_tensors):
if isinstance(self.sampled, list):
-return sum([input_tensors[t] for t in self.sampled])
+return sum([input_tensors[t] for t in self.sampled])  # pylint: disable=not-an-iterable
else:
return input_tensors[self.sampled]
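# e.g. sampled == [0, 2] returns input_tensors[0] + input_tensors[2], while
# sampled == 1 passes input_tensors[1] through unchanged.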
......
......@@ -123,13 +123,13 @@ class AverageMeter:
return fmtstr.format(**self.__dict__)
-def _replace_module_with_type(root_module, init_fn, type, modules):
+def _replace_module_with_type(root_module, init_fn, type_name, modules):
if modules is None:
modules = []
def apply(m):
for name, child in m.named_children():
-if isinstance(child, type):
+if isinstance(child, type_name):
setattr(m, name, init_fn(child))
modules.append((child.key, getattr(m, name)))
else:
......
-from collections import defaultdict
import inspect
+from collections import defaultdict
+from typing import Any
-def import_(target: str, allow_none: bool = False) -> 'Any':
+def import_(target: str, allow_none: bool = False) -> Any:
if target is None:
return None
path, identifier = target.rsplit('.', 1)
module = __import__(path, globals(), locals(), [identifier])
return getattr(module, identifier)
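# For example, import_('torch.nn.Conv2d') returns the torch.nn.Conv2d class.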
_records = {}
def get_records():
global _records
return _records
def add_record(key, value):
"""
"""
......@@ -22,6 +27,7 @@ def add_record(key, value):
assert key not in _records, '{} already in _records'.format(key)
_records[key] = value
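# Records are keyed by caller-chosen keys (id(self) in the wrappers below);
# registering the same key twice trips the assertion above.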
def _register_module(original_class):
orig_init = original_class.__init__
argname_list = list(inspect.signature(original_class).parameters.keys())
......@@ -31,14 +37,15 @@ def _register_module(original_class):
full_args = {}
full_args.update(kws)
for i, arg in enumerate(args):
-full_args[argname_list[i]] = args[i]
+full_args[argname_list[i]] = arg
add_record(id(self), full_args)
-orig_init(self, *args, **kws) # Call the original __init__
+orig_init(self, *args, **kws)  # Call the original __init__
-original_class.__init__ = __init__ # Set the class' __init__ to the new one
+original_class.__init__ = __init__  # Set the class' __init__ to the new one
return original_class
def register_module():
"""
Register a module.
......@@ -68,14 +75,15 @@ def _register_trainer(original_class):
if isinstance(args[i], Module):
# ignore the base model object
continue
-full_args[argname_list[i]] = args[i]
+full_args[argname_list[i]] = arg
add_record(id(self), {'modulename': full_class_name, 'args': full_args})
-orig_init(self, *args, **kws) # Call the original __init__
+orig_init(self, *args, **kws)  # Call the original __init__
-original_class.__init__ = __init__ # Set the class' __init__ to the new one
+original_class.__init__ = __init__  # Set the class' __init__ to the new one
return original_class
def register_trainer():
def _register(cls):
m = _register_trainer(
......@@ -84,8 +92,10 @@ def register_trainer():
return _register
_last_uid = defaultdict(int)
def uid(namespace: str = 'default') -> int:
_last_uid[namespace] += 1
return _last_uid[namespace]
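# Counters are independent per namespace, e.g.:
#   uid() -> 1, uid() -> 2, uid('graph') -> 1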
......@@ -41,7 +41,7 @@ jobs:
python3 -m pip install --upgrade pygments
python3 -m pip install --upgrade torch>=1.7.0+cpu torchvision>=0.8.1+cpu -f https://download.pytorch.org/whl/torch_stable.html
python3 -m pip install --upgrade tensorflow
-python3 -m pip install --upgrade gym onnx peewee thop
+python3 -m pip install --upgrade gym onnx peewee thop graphviz
python3 -m pip install sphinx==1.8.3 sphinx-argparse==0.2.5 sphinx-markdown-tables==0.0.9 sphinx-rtd-theme==0.4.2 sphinxcontrib-websupport==1.1.0 recommonmark==0.5.0 nbsphinx
sudo apt-get install swig -y
python3 -m pip install -e .[SMAC,BOHB]
......