Unverified Commit 59cd3982 authored by Yuge Zhang, committed by GitHub

[Retiarii] Coding style improvements for pylint and flake8 (#3190)

parent 593a275c
 import inspect
 import logging
-from typing import Any, List
 import torch
 import torch.nn as nn
+from typing import (Any, Tuple, List, Optional)
+
 from ...utils import add_record
@@ -10,7 +11,7 @@ _logger = logging.getLogger(__name__)
 __all__ = [
     'LayerChoice', 'InputChoice', 'Placeholder',
     'Module', 'Sequential', 'ModuleList',  # TODO: 'ModuleDict', 'ParameterList', 'ParameterDict',
     'Identity', 'Linear', 'Conv1d', 'Conv2d', 'Conv3d', 'ConvTranspose1d',
     'ConvTranspose2d', 'ConvTranspose3d', 'Threshold', 'ReLU', 'Hardtanh', 'ReLU6',
     'Sigmoid', 'Tanh', 'Softmax', 'Softmax2d', 'LogSoftmax', 'ELU', 'SELU', 'CELU', 'GLU', 'GELU', 'Hardshrink',
@@ -30,7 +31,7 @@ __all__ = [
     'TransformerEncoderLayer', 'TransformerDecoderLayer', 'Transformer',
     #'LazyLinear', 'LazyConv1d', 'LazyConv2d', 'LazyConv3d',
     #'LazyConvTranspose1d', 'LazyConvTranspose2d', 'LazyConvTranspose3d',
     #'Unflatten', 'SiLU', 'TripletMarginWithDistanceLoss', 'ChannelShuffle',
     'Flatten', 'Hardsigmoid', 'Hardswish'
 ]
@@ -57,9 +58,10 @@ class InputChoice(nn.Module):
         if n_candidates or choose_from or return_mask:
             _logger.warning('input arguments `n_candidates`, `choose_from` and `return_mask` are deprecated!')

-    def forward(self, candidate_inputs: List['Tensor']) -> 'Tensor':
+    def forward(self, candidate_inputs: List[torch.Tensor]) -> torch.Tensor:
         # fake return
-        return torch.tensor(candidate_inputs)
+        return torch.tensor(candidate_inputs)  # pylint: disable=not-callable
+

 class ValueChoice:
     """
@@ -67,6 +69,7 @@ class ValueChoice:
     when instantiating a pytorch module.
     TODO: can also be used in training approach
     """
+
     def __init__(self, candidate_values: List[Any]):
         self.candidate_values = candidate_values
@@ -81,6 +84,7 @@ class Placeholder(nn.Module):
     def forward(self, x):
         return x

+
 class ChosenInputs(nn.Module):
     def __init__(self, chosen: int):
         super().__init__()
@@ -92,20 +96,24 @@ class ChosenInputs(nn.Module):
 # the following are pytorch modules

+
 class Module(nn.Module):
     def __init__(self):
         super(Module, self).__init__()

+
 class Sequential(nn.Sequential):
     def __init__(self, *args):
         add_record(id(self), {})
         super(Sequential, self).__init__(*args)

+
 class ModuleList(nn.ModuleList):
     def __init__(self, *args):
         add_record(id(self), {})
         super(ModuleList, self).__init__(*args)

+
 def wrap_module(original_class):
     orig_init = original_class.__init__
     argname_list = list(inspect.signature(original_class).parameters.keys())
@@ -115,14 +123,15 @@ def wrap_module(original_class):
         full_args = {}
         full_args.update(kws)
         for i, arg in enumerate(args):
-            full_args[argname_list[i]] = args[i]
+            full_args[argname_list[i]] = arg
         add_record(id(self), full_args)
         orig_init(self, *args, **kws)  # Call the original __init__

     original_class.__init__ = __init__  # Set the class' __init__ to the new one
     return original_class

+
 # TODO: support different versions of pytorch
 Identity = wrap_module(nn.Identity)
 Linear = wrap_module(nn.Linear)
...
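The `wrap_module` helper above is the recording trick at the heart of this file: it patches a class's `__init__` so every constructor argument is saved by parameter name before the original initializer runs. A minimal self-contained sketch of the same pattern (the names `record_init_args` and `_records` here are illustrative, not NNI API):

```python
import inspect

_records = {}  # id(instance) -> {parameter name: value}

def record_init_args(original_class):
    orig_init = original_class.__init__
    argnames = list(inspect.signature(original_class).parameters.keys())

    def __init__(self, *args, **kws):
        full_args = dict(kws)
        # map positional args onto their declared parameter names
        for name, arg in zip(argnames, args):
            full_args[name] = arg
        _records[id(self)] = full_args
        orig_init(self, *args, **kws)  # then run the real initializer

    original_class.__init__ = __init__
    return original_class
```

Note that the `args[i]` → `arg` change in the diff is behavior-preserving: the loop already binds each value to `arg`, so indexing `args[i]` left `arg` unused and triggered pylint.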
@@ -4,12 +4,14 @@ from . import debug_configs
 __all__ = ['Operation', 'Cell']

+
 def _convert_name(name: str) -> str:
     """
     Convert the names using separator '.' to valid variable name in code
     """
     return name.replace('.', '__')

+
 class Operation:
     """
     Calculation logic of a graph node.
@@ -152,6 +154,7 @@ class PyTorchOperation(Operation):
         else:
             raise RuntimeError(f'unsupported operation type: {self.type} ? {self._to_class_name()}')

+
 class TensorFlowOperation(Operation):
     def _to_class_name(self) -> str:
         return 'K.layers.' + self.type
@@ -191,6 +194,7 @@ class Cell(PyTorchOperation):
     framework
         No real usage. Exists for compatibility with base class.
     """
+
     def __init__(self, cell_name: str, parameters: Dict[str, Any] = {}):
         self.type = '_cell'
         self.cell_name = cell_name
@@ -207,6 +211,7 @@ class _IOPseudoOperation(Operation):
     The benefit is that users no longer need to verify `Node.operation is not None`,
     especially in static type checking.
     """
+
     def __init__(self, type_name: str, io_names: List = None):
         assert type_name.startswith('_')
         super(_IOPseudoOperation, self).__init__(type_name, {}, True)
...
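`_convert_name` only has to turn a hierarchical node name into a legal Python identifier for generated code, so replacing the separator is enough. For example (the node name is illustrative):

```python
assert _convert_name('stem.conv.0') == 'stem__conv__0'
```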
 from ..operation import TensorFlowOperation

 class Conv2D(TensorFlowOperation):
     def __init__(self, type_name, parameters, _internal):
         if 'padding' not in parameters:
             parameters['padding'] = 'same'
-        super().__init__(type_name, parameters, _internal)
\ No newline at end of file
+        super().__init__(type_name, parameters, _internal)
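`Conv2D` exists only to inject a TensorFlow-friendly default before delegating to the base constructor. Since the check mutates the passed-in dict, the effect is observable on the caller's side; a hedged sketch with illustrative parameter values:

```python
params = {'filters': 32, 'kernel_size': 3}  # illustrative values
Conv2D('Conv2D', params, _internal=True)
assert params['padding'] == 'same'          # default injected before the base __init__ runs
```

Mutating the caller's dict in place is a deliberate (if slightly surprising) design choice: the base class then records the completed parameter set.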
 from ..operation import PyTorchOperation

+
 class relu(PyTorchOperation):
     def to_init_code(self, field):
         return ''
@@ -17,6 +18,7 @@ class Flatten(PyTorchOperation):
         assert len(inputs) == 1
         return f'{output} = {inputs[0]}.view({inputs[0]}.size(0), -1)'

+
 class ToDevice(PyTorchOperation):
     def to_init_code(self, field):
         return ''
...
 import abc
 from typing import List

+from ..graph import Model
+from ..mutator import Mutator
+

 class BaseStrategy(abc.ABC):
     @abc.abstractmethod
-    def run(self, base_model: 'Model', applied_mutators: List['Mutator']) -> None:
+    def run(self, base_model: Model, applied_mutators: List[Mutator]) -> None:
         pass
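With `Model` and `Mutator` imported as real names instead of string-literal annotations, pylint and type checkers can resolve the `run` signature. A concrete strategy then just overrides `run`; a do-nothing skeleton (the class name `MyStrategy` is hypothetical):

```python
class MyStrategy(BaseStrategy):
    def run(self, base_model: Model, applied_mutators: List[Mutator]) -> None:
        # a real strategy would sample the mutators, apply them to
        # base_model, and submit the resulting models for training
        pass
```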
-import json
 import logging
-import random
-import os

-from .. import Model, submit_models, wait_models
-from .. import Sampler
+from .. import Sampler, submit_models, wait_models
 from .strategy import BaseStrategy
 from ...algorithms.hpo.hyperopt_tuner.hyperopt_tuner import HyperoptTuner

 _logger = logging.getLogger(__name__)

+
 class TPESampler(Sampler):
     def __init__(self, optimize_mode='minimize'):
         self.tpe_tuner = HyperoptTuner('tpe', optimize_mode)
@@ -37,6 +34,7 @@ class TPESampler(Sampler):
         self.index += 1
         return chosen

+
 class TPEStrategy(BaseStrategy):
     def __init__(self):
         self.tpe_sampler = TPESampler()
@@ -55,7 +53,7 @@ class TPEStrategy(BaseStrategy):
             while True:
                 model = base_model
                 _logger.info('apply mutators...')
-                _logger.info('mutators: {}'.format(applied_mutators))
+                _logger.info('mutators: %s', str(applied_mutators))
                 self.tpe_sampler.generate_samples(self.model_id)
                 for mutator in applied_mutators:
                     _logger.info('mutate model...')
@@ -66,6 +64,6 @@ class TPEStrategy(BaseStrategy):
                 wait_models(model)
                 self.tpe_sampler.receive_result(self.model_id, model.metric)
                 self.model_id += 1
-                _logger.info('Strategy says:', model.metric)
-        except Exception as e:
+                _logger.info('Strategy says: %s', model.metric)
+        except Exception:
             _logger.error(logging.exception('message'))
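Two logging fixes in this file are worth unpacking. `%s`-style arguments defer formatting until a record is actually emitted, while `.format()` builds the string even when INFO is disabled; and the old `_logger.info('Strategy says:', model.metric)` passed the metric as a format argument with no `%s` placeholder, so logging reports a formatting error at emit time instead of the message. A before/after sketch (the metric value is illustrative):

```python
import logging

logger = logging.getLogger(__name__)
metric = 0.93  # illustrative value

logger.info('metric: {}'.format(metric))  # eager: string built even if INFO is off
logger.info('metric:', metric)            # bug: no placeholder, formatting fails when emitted
logger.info('metric: %s', metric)         # lazy: formatted only if the record is handled
```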
 import abc
-import inspect
-from typing import *
+from typing import Any

 class BaseTrainer(abc.ABC):
...
-import abc
-from typing import *
+from typing import Any, List, Dict, Tuple

 import numpy as np
 import torch
@@ -42,6 +41,7 @@ def get_default_transform(dataset: str) -> Any:
     # unsupported dataset, return None
     return None

+
 @register_trainer()
 class PyTorchImageClassificationTrainer(BaseTrainer):
     """
@@ -94,7 +94,7 @@ class PyTorchImageClassificationTrainer(BaseTrainer):
         self._dataloader = DataLoader(
             self._dataset, **(dataloader_kwargs or {}))

-    def _accuracy(self, input, target):
+    def _accuracy(self, input, target):  # pylint: disable=redefined-builtin
         _, predict = torch.max(input.data, 1)
         correct = predict.eq(target.data).cpu().sum().item()
         return correct / input.size(0)
@@ -176,7 +176,7 @@ class PyTorchMultiModelTrainer(BaseTrainer):
             dataloader = DataLoader(dataset, **(dataloader_kwargs or {}))
             self._datasets.append(dataset)
             self._dataloaders.append(dataloader)
             if m['use_output']:
                 optimizer_cls = m['optimizer_cls']
                 optimizer_kwargs = m['optimizer_kwargs']
@@ -186,7 +186,7 @@ class PyTorchMultiModelTrainer(BaseTrainer):
                     name_prefix = '_'.join(name.split('_')[:2])
                     if m_header == name_prefix:
                         one_model_params.append(param)
                 optimizer = getattr(torch.optim, optimizer_cls)(one_model_params, **(optimizer_kwargs or {}))
                 self._optimizers.append(optimizer)
@@ -206,7 +206,7 @@ class PyTorchMultiModelTrainer(BaseTrainer):
             x, y = self.training_step_before_model(batch, batch_idx, f'cuda:{idx}')
             xs.append(x)
             ys.append(y)
         y_hats = self.multi_model(*xs)
         if len(ys) != len(xs):
             raise ValueError('len(ys) should be equal to len(xs)')
@@ -230,13 +230,12 @@ class PyTorchMultiModelTrainer(BaseTrainer):
             if self.max_steps and batch_idx >= self.max_steps:
                 return

-
     def training_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> Dict[str, Any]:
         x, y = self.training_step_before_model(batch, batch_idx)
         y_hat = self.model(x)
         return self.training_step_after_model(x, y, y_hat)

-    def training_step_before_model(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int, device = None):
+    def training_step_before_model(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int, device=None):
         x, y = batch
         if device:
             x, y = x.cuda(torch.device(device)), y.cuda(torch.device(device))
@@ -259,4 +258,4 @@ class PyTorchMultiModelTrainer(BaseTrainer):
     def validation_step_after_model(self, x, y, y_hat):
         acc = self._accuracy(y_hat, y)
-        return {'val_acc': acc}
\ No newline at end of file
+        return {'val_acc': acc}
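Aside from the `device = None` → `device=None` spacing fix (PEP 8 forbids spaces around `=` in keyword defaults), the `_accuracy` helper these trainers share is plain top-1 accuracy; the `redefined-builtin` pragma keeps the public `input` parameter name rather than renaming it. A standalone equivalent sketch:

```python
import torch

def top1_accuracy(logits: torch.Tensor, target: torch.Tensor) -> float:
    _, predict = torch.max(logits, 1)          # argmax over the class dimension
    correct = predict.eq(target).sum().item()  # count matching predictions
    return correct / logits.size(0)

logits = torch.tensor([[0.1, 0.9], [0.8, 0.2]])  # illustrative batch of 2, 2 classes
target = torch.tensor([1, 1])
assert top1_accuracy(logits, target) == 0.5
```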
@@ -6,7 +6,6 @@ import logging
 import torch
 import torch.nn as nn

-from nni.nas.pytorch.mutables import LayerChoice
 from ..interface import BaseOneShotTrainer
 from .utils import AverageMeterGroup, replace_layer_choice, replace_input_choice
...
@@ -86,8 +86,8 @@ class ReinforceController(nn.Module):
         self.attn_query = nn.Linear(self.lstm_size, self.lstm_size, bias=False)
         self.v_attn = nn.Linear(self.lstm_size, 1, bias=False)
         self.g_emb = nn.Parameter(torch.randn(1, self.lstm_size) * 0.1)
-        self.skip_targets = nn.Parameter(torch.tensor([1.0 - self.skip_target, self.skip_target]),
-                                         requires_grad=False)  # pylint: disable=not-callable
+        self.skip_targets = nn.Parameter(torch.tensor([1.0 - self.skip_target, self.skip_target]),  # pylint: disable=not-callable
+                                         requires_grad=False)
         assert entropy_reduction in ['sum', 'mean'], 'Entropy reduction must be one of sum and mean.'
         self.entropy_reduction = torch.sum if entropy_reduction == 'sum' else torch.mean
         self.cross_entropy_loss = nn.CrossEntropyLoss(reduction='none')
...
@@ -16,7 +16,7 @@ _logger = logging.getLogger(__name__)
 def _get_mask(sampled, total):
     multihot = [i == sampled or (isinstance(sampled, list) and i in sampled) for i in range(total)]
-    return torch.tensor(multihot, dtype=torch.bool)
+    return torch.tensor(multihot, dtype=torch.bool)  # pylint: disable=not-callable

 class PathSamplingLayerChoice(nn.Module):
@@ -44,9 +44,9 @@ class PathSamplingLayerChoice(nn.Module):
     def forward(self, *args, **kwargs):
         assert self.sampled is not None, 'At least one path needs to be sampled before fprop.'
         if isinstance(self.sampled, list):
-            return sum([getattr(self, self.op_names[i])(*args, **kwargs) for i in self.sampled])
+            return sum([getattr(self, self.op_names[i])(*args, **kwargs) for i in self.sampled])  # pylint: disable=not-an-iterable
         else:
-            return getattr(self, self.op_names[self.sampled])(*args, **kwargs)
+            return getattr(self, self.op_names[self.sampled])(*args, **kwargs)  # pylint: disable=invalid-sequence-index

     def __len__(self):
         return len(self.op_names)
@@ -76,7 +76,7 @@ class PathSamplingInputChoice(nn.Module):
     def forward(self, input_tensors):
         if isinstance(self.sampled, list):
-            return sum([input_tensors[t] for t in self.sampled])
+            return sum([input_tensors[t] for t in self.sampled])  # pylint: disable=not-an-iterable
         else:
             return input_tensors[self.sampled]
...
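The `not-an-iterable` / `invalid-sequence-index` pragmas are needed because `self.sampled` is deliberately loose: it may hold a single index or a list of indices, and `_get_mask` accepts both shapes. For instance (illustrative values):

```python
_get_mask(2, 4)       # tensor([False, False,  True, False])
_get_mask([0, 2], 4)  # tensor([ True, False,  True, False])
```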
@@ -123,13 +123,13 @@ class AverageMeter:
         return fmtstr.format(**self.__dict__)

-def _replace_module_with_type(root_module, init_fn, type, modules):
+def _replace_module_with_type(root_module, init_fn, type_name, modules):
     if modules is None:
         modules = []

     def apply(m):
         for name, child in m.named_children():
-            if isinstance(child, type):
+            if isinstance(child, type_name):
                 setattr(m, name, init_fn(child))
                 modules.append((child.key, getattr(m, name)))
             else:
...
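Renaming the `type` parameter to `type_name` silences pylint's `redefined-builtin` warning without changing behavior: the argument is always a class object handed to `isinstance`, so either name works. A sketch of how the exported `replace_layer_choice` helper plausibly wraps it (the wrapper body is an assumption; `LayerChoice` would come from the mutables module):

```python
def replace_layer_choice(root_module, init_fn, modules=None):
    # pin type_name to LayerChoice and collect (key, replacement) pairs
    return _replace_module_with_type(root_module, init_fn, LayerChoice, modules)
```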
-from collections import defaultdict
 import inspect
+from collections import defaultdict
+from typing import Any

+
-def import_(target: str, allow_none: bool = False) -> 'Any':
+def import_(target: str, allow_none: bool = False) -> Any:
     if target is None:
         return None
     path, identifier = target.rsplit('.', 1)
     module = __import__(path, globals(), locals(), [identifier])
     return getattr(module, identifier)

+
 _records = {}

+
 def get_records():
     global _records
     return _records

+
 def add_record(key, value):
     """
     """
@@ -22,6 +27,7 @@ def add_record(key, value):
     assert key not in _records, '{} already in _records'.format(key)
     _records[key] = value

+
 def _register_module(original_class):
     orig_init = original_class.__init__
     argname_list = list(inspect.signature(original_class).parameters.keys())
@@ -31,14 +37,15 @@ def _register_module(original_class):
         full_args = {}
         full_args.update(kws)
         for i, arg in enumerate(args):
-            full_args[argname_list[i]] = args[i]
+            full_args[argname_list[i]] = arg
         add_record(id(self), full_args)
         orig_init(self, *args, **kws)  # Call the original __init__

     original_class.__init__ = __init__  # Set the class' __init__ to the new one
     return original_class

+
 def register_module():
     """
     Register a module.
@@ -68,14 +75,15 @@ def _register_trainer(original_class):
             if isinstance(args[i], Module):
                 # ignore the base model object
                 continue
-            full_args[argname_list[i]] = args[i]
+            full_args[argname_list[i]] = arg
         add_record(id(self), {'modulename': full_class_name, 'args': full_args})
         orig_init(self, *args, **kws)  # Call the original __init__

     original_class.__init__ = __init__  # Set the class' __init__ to the new one
     return original_class

+
 def register_trainer():
     def _register(cls):
         m = _register_trainer(
@@ -84,8 +92,10 @@ def register_trainer():
     return _register

+
 _last_uid = defaultdict(int)

+
 def uid(namespace: str = 'default') -> int:
     _last_uid[namespace] += 1
     return _last_uid[namespace]
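`uid` is a per-namespace monotonically increasing counter built on `defaultdict(int)`: the first call in any namespace returns 1, and namespaces are independent. For example:

```python
assert uid() == 1          # first id in the 'default' namespace
assert uid() == 2
assert uid('graph') == 1   # a fresh namespace starts over
```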
@@ -41,7 +41,7 @@ jobs:
       python3 -m pip install --upgrade pygments
       python3 -m pip install --upgrade torch>=1.7.0+cpu torchvision>=0.8.1+cpu -f https://download.pytorch.org/whl/torch_stable.html
       python3 -m pip install --upgrade tensorflow
-      python3 -m pip install --upgrade gym onnx peewee thop
+      python3 -m pip install --upgrade gym onnx peewee thop graphviz
       python3 -m pip install sphinx==1.8.3 sphinx-argparse==0.2.5 sphinx-markdown-tables==0.0.9 sphinx-rtd-theme==0.4.2 sphinxcontrib-websupport==1.1.0 recommonmark==0.5.0 nbsphinx
       sudo apt-get install swig -y
       python3 -m pip install -e .[SMAC,BOHB]
...