Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
59cd3982
Unverified
Commit
59cd3982
authored
Dec 14, 2020
by
Yuge Zhang
Committed by
GitHub
Dec 14, 2020
Browse files
[Retiarii] Coding style improvements for pylint and flake8 (#3190)
parent
593a275c
Changes
34
Hide whitespace changes
Inline
Side-by-side
Showing
14 changed files
with
73 additions
and
47 deletions
+73
-47
nni/retiarii/nn/pytorch/nn.py
nni/retiarii/nn/pytorch/nn.py
+17
-8
nni/retiarii/operation.py
nni/retiarii/operation.py
+5
-0
nni/retiarii/operation_def/tf_op_def.py
nni/retiarii/operation_def/tf_op_def.py
+2
-1
nni/retiarii/operation_def/torch_op_def.py
nni/retiarii/operation_def/torch_op_def.py
+2
-0
nni/retiarii/strategies/strategy.py
nni/retiarii/strategies/strategy.py
+5
-1
nni/retiarii/strategies/tpe_strategy.py
nni/retiarii/strategies/tpe_strategy.py
+6
-8
nni/retiarii/trainer/interface.py
nni/retiarii/trainer/interface.py
+1
-2
nni/retiarii/trainer/pytorch/base.py
nni/retiarii/trainer/pytorch/base.py
+8
-9
nni/retiarii/trainer/pytorch/darts.py
nni/retiarii/trainer/pytorch/darts.py
+0
-1
nni/retiarii/trainer/pytorch/enas.py
nni/retiarii/trainer/pytorch/enas.py
+2
-2
nni/retiarii/trainer/pytorch/random.py
nni/retiarii/trainer/pytorch/random.py
+4
-4
nni/retiarii/trainer/pytorch/utils.py
nni/retiarii/trainer/pytorch/utils.py
+2
-2
nni/retiarii/utils.py
nni/retiarii/utils.py
+18
-8
pipelines/fast-test.yml
pipelines/fast-test.yml
+1
-1
No files found.
nni/retiarii/nn/pytorch/nn.py
View file @
59cd3982
import
inspect
import
inspect
import
logging
import
logging
from
typing
import
Any
,
List
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
from
typing
import
(
Any
,
Tuple
,
List
,
Optional
)
from
...utils
import
add_record
from
...utils
import
add_record
...
@@ -10,7 +11,7 @@ _logger = logging.getLogger(__name__)
...
@@ -10,7 +11,7 @@ _logger = logging.getLogger(__name__)
__all__
=
[
__all__
=
[
'LayerChoice'
,
'InputChoice'
,
'Placeholder'
,
'LayerChoice'
,
'InputChoice'
,
'Placeholder'
,
'Module'
,
'Sequential'
,
'ModuleList'
,
# TODO: 'ModuleDict', 'ParameterList', 'ParameterDict',
'Module'
,
'Sequential'
,
'ModuleList'
,
# TODO: 'ModuleDict', 'ParameterList', 'ParameterDict',
'Identity'
,
'Linear'
,
'Conv1d'
,
'Conv2d'
,
'Conv3d'
,
'ConvTranspose1d'
,
'Identity'
,
'Linear'
,
'Conv1d'
,
'Conv2d'
,
'Conv3d'
,
'ConvTranspose1d'
,
'ConvTranspose2d'
,
'ConvTranspose3d'
,
'Threshold'
,
'ReLU'
,
'Hardtanh'
,
'ReLU6'
,
'ConvTranspose2d'
,
'ConvTranspose3d'
,
'Threshold'
,
'ReLU'
,
'Hardtanh'
,
'ReLU6'
,
'Sigmoid'
,
'Tanh'
,
'Softmax'
,
'Softmax2d'
,
'LogSoftmax'
,
'ELU'
,
'SELU'
,
'CELU'
,
'GLU'
,
'GELU'
,
'Hardshrink'
,
'Sigmoid'
,
'Tanh'
,
'Softmax'
,
'Softmax2d'
,
'LogSoftmax'
,
'ELU'
,
'SELU'
,
'CELU'
,
'GLU'
,
'GELU'
,
'Hardshrink'
,
...
@@ -30,7 +31,7 @@ __all__ = [
...
@@ -30,7 +31,7 @@ __all__ = [
'TransformerEncoderLayer'
,
'TransformerDecoderLayer'
,
'Transformer'
,
'TransformerEncoderLayer'
,
'TransformerDecoderLayer'
,
'Transformer'
,
#'LazyLinear', 'LazyConv1d', 'LazyConv2d', 'LazyConv3d',
#'LazyLinear', 'LazyConv1d', 'LazyConv2d', 'LazyConv3d',
#'LazyConvTranspose1d', 'LazyConvTranspose2d', 'LazyConvTranspose3d',
#'LazyConvTranspose1d', 'LazyConvTranspose2d', 'LazyConvTranspose3d',
#'Unflatten', 'SiLU', 'TripletMarginWithDistanceLoss', 'ChannelShuffle',
#'Unflatten', 'SiLU', 'TripletMarginWithDistanceLoss', 'ChannelShuffle',
'Flatten'
,
'Hardsigmoid'
,
'Hardswish'
'Flatten'
,
'Hardsigmoid'
,
'Hardswish'
]
]
...
@@ -57,9 +58,10 @@ class InputChoice(nn.Module):
...
@@ -57,9 +58,10 @@ class InputChoice(nn.Module):
if
n_candidates
or
choose_from
or
return_mask
:
if
n_candidates
or
choose_from
or
return_mask
:
_logger
.
warning
(
'input arguments `n_candidates`, `choose_from` and `return_mask` are deprecated!'
)
_logger
.
warning
(
'input arguments `n_candidates`, `choose_from` and `return_mask` are deprecated!'
)
def
forward
(
self
,
candidate_inputs
:
List
[
'
Tensor
'
])
->
'
Tensor
'
:
def
forward
(
self
,
candidate_inputs
:
List
[
torch
.
Tensor
])
->
torch
.
Tensor
:
# fake return
# fake return
return
torch
.
tensor
(
candidate_inputs
)
return
torch
.
tensor
(
candidate_inputs
)
# pylint: disable=not-callable
class
ValueChoice
:
class
ValueChoice
:
"""
"""
...
@@ -67,6 +69,7 @@ class ValueChoice:
...
@@ -67,6 +69,7 @@ class ValueChoice:
when instantiating a pytorch module.
when instantiating a pytorch module.
TODO: can also be used in training approach
TODO: can also be used in training approach
"""
"""
def
__init__
(
self
,
candidate_values
:
List
[
Any
]):
def
__init__
(
self
,
candidate_values
:
List
[
Any
]):
self
.
candidate_values
=
candidate_values
self
.
candidate_values
=
candidate_values
...
@@ -81,6 +84,7 @@ class Placeholder(nn.Module):
...
@@ -81,6 +84,7 @@ class Placeholder(nn.Module):
def
forward
(
self
,
x
):
def
forward
(
self
,
x
):
return
x
return
x
class
ChosenInputs
(
nn
.
Module
):
class
ChosenInputs
(
nn
.
Module
):
def
__init__
(
self
,
chosen
:
int
):
def
__init__
(
self
,
chosen
:
int
):
super
().
__init__
()
super
().
__init__
()
...
@@ -92,20 +96,24 @@ class ChosenInputs(nn.Module):
...
@@ -92,20 +96,24 @@ class ChosenInputs(nn.Module):
# the following are pytorch modules
# the following are pytorch modules
class
Module
(
nn
.
Module
):
class
Module
(
nn
.
Module
):
def
__init__
(
self
):
def
__init__
(
self
):
super
(
Module
,
self
).
__init__
()
super
(
Module
,
self
).
__init__
()
class
Sequential
(
nn
.
Sequential
):
class
Sequential
(
nn
.
Sequential
):
def
__init__
(
self
,
*
args
):
def
__init__
(
self
,
*
args
):
add_record
(
id
(
self
),
{})
add_record
(
id
(
self
),
{})
super
(
Sequential
,
self
).
__init__
(
*
args
)
super
(
Sequential
,
self
).
__init__
(
*
args
)
class
ModuleList
(
nn
.
ModuleList
):
class
ModuleList
(
nn
.
ModuleList
):
def
__init__
(
self
,
*
args
):
def
__init__
(
self
,
*
args
):
add_record
(
id
(
self
),
{})
add_record
(
id
(
self
),
{})
super
(
ModuleList
,
self
).
__init__
(
*
args
)
super
(
ModuleList
,
self
).
__init__
(
*
args
)
def
wrap_module
(
original_class
):
def
wrap_module
(
original_class
):
orig_init
=
original_class
.
__init__
orig_init
=
original_class
.
__init__
argname_list
=
list
(
inspect
.
signature
(
original_class
).
parameters
.
keys
())
argname_list
=
list
(
inspect
.
signature
(
original_class
).
parameters
.
keys
())
...
@@ -115,14 +123,15 @@ def wrap_module(original_class):
...
@@ -115,14 +123,15 @@ def wrap_module(original_class):
full_args
=
{}
full_args
=
{}
full_args
.
update
(
kws
)
full_args
.
update
(
kws
)
for
i
,
arg
in
enumerate
(
args
):
for
i
,
arg
in
enumerate
(
args
):
full_args
[
argname_list
[
i
]]
=
arg
s
[
i
]
full_args
[
argname_list
[
i
]]
=
arg
add_record
(
id
(
self
),
full_args
)
add_record
(
id
(
self
),
full_args
)
orig_init
(
self
,
*
args
,
**
kws
)
# Call the original __init__
orig_init
(
self
,
*
args
,
**
kws
)
# Call the original __init__
original_class
.
__init__
=
__init__
# Set the class' __init__ to the new one
original_class
.
__init__
=
__init__
# Set the class' __init__ to the new one
return
original_class
return
original_class
# TODO: support different versions of pytorch
# TODO: support different versions of pytorch
Identity
=
wrap_module
(
nn
.
Identity
)
Identity
=
wrap_module
(
nn
.
Identity
)
Linear
=
wrap_module
(
nn
.
Linear
)
Linear
=
wrap_module
(
nn
.
Linear
)
...
...
nni/retiarii/operation.py
View file @
59cd3982
...
@@ -4,12 +4,14 @@ from . import debug_configs
...
@@ -4,12 +4,14 @@ from . import debug_configs
__all__
=
[
'Operation'
,
'Cell'
]
__all__
=
[
'Operation'
,
'Cell'
]
def
_convert_name
(
name
:
str
)
->
str
:
def
_convert_name
(
name
:
str
)
->
str
:
"""
"""
Convert the names using separator '.' to valid variable name in code
Convert the names using separator '.' to valid variable name in code
"""
"""
return
name
.
replace
(
'.'
,
'__'
)
return
name
.
replace
(
'.'
,
'__'
)
class
Operation
:
class
Operation
:
"""
"""
Calculation logic of a graph node.
Calculation logic of a graph node.
...
@@ -152,6 +154,7 @@ class PyTorchOperation(Operation):
...
@@ -152,6 +154,7 @@ class PyTorchOperation(Operation):
else
:
else
:
raise
RuntimeError
(
f
'unsupported operation type:
{
self
.
type
}
?
{
self
.
_to_class_name
()
}
'
)
raise
RuntimeError
(
f
'unsupported operation type:
{
self
.
type
}
?
{
self
.
_to_class_name
()
}
'
)
class
TensorFlowOperation
(
Operation
):
class
TensorFlowOperation
(
Operation
):
def
_to_class_name
(
self
)
->
str
:
def
_to_class_name
(
self
)
->
str
:
return
'K.layers.'
+
self
.
type
return
'K.layers.'
+
self
.
type
...
@@ -191,6 +194,7 @@ class Cell(PyTorchOperation):
...
@@ -191,6 +194,7 @@ class Cell(PyTorchOperation):
framework
framework
No real usage. Exists for compatibility with base class.
No real usage. Exists for compatibility with base class.
"""
"""
def
__init__
(
self
,
cell_name
:
str
,
parameters
:
Dict
[
str
,
Any
]
=
{}):
def
__init__
(
self
,
cell_name
:
str
,
parameters
:
Dict
[
str
,
Any
]
=
{}):
self
.
type
=
'_cell'
self
.
type
=
'_cell'
self
.
cell_name
=
cell_name
self
.
cell_name
=
cell_name
...
@@ -207,6 +211,7 @@ class _IOPseudoOperation(Operation):
...
@@ -207,6 +211,7 @@ class _IOPseudoOperation(Operation):
The benefit is that users no longer need to verify `Node.operation is not None`,
The benefit is that users no longer need to verify `Node.operation is not None`,
especially in static type checking.
especially in static type checking.
"""
"""
def
__init__
(
self
,
type_name
:
str
,
io_names
:
List
=
None
):
def
__init__
(
self
,
type_name
:
str
,
io_names
:
List
=
None
):
assert
type_name
.
startswith
(
'_'
)
assert
type_name
.
startswith
(
'_'
)
super
(
_IOPseudoOperation
,
self
).
__init__
(
type_name
,
{},
True
)
super
(
_IOPseudoOperation
,
self
).
__init__
(
type_name
,
{},
True
)
...
...
nni/retiarii/operation_def/tf_op_def.py
View file @
59cd3982
from
..operation
import
TensorFlowOperation
from
..operation
import
TensorFlowOperation
class
Conv2D
(
TensorFlowOperation
):
class
Conv2D
(
TensorFlowOperation
):
def
__init__
(
self
,
type_name
,
parameters
,
_internal
):
def
__init__
(
self
,
type_name
,
parameters
,
_internal
):
if
'padding'
not
in
parameters
:
if
'padding'
not
in
parameters
:
parameters
[
'padding'
]
=
'same'
parameters
[
'padding'
]
=
'same'
super
().
__init__
(
type_name
,
parameters
,
_internal
)
super
().
__init__
(
type_name
,
parameters
,
_internal
)
\ No newline at end of file
nni/retiarii/operation_def/torch_op_def.py
View file @
59cd3982
from
..operation
import
PyTorchOperation
from
..operation
import
PyTorchOperation
class
relu
(
PyTorchOperation
):
class
relu
(
PyTorchOperation
):
def
to_init_code
(
self
,
field
):
def
to_init_code
(
self
,
field
):
return
''
return
''
...
@@ -17,6 +18,7 @@ class Flatten(PyTorchOperation):
...
@@ -17,6 +18,7 @@ class Flatten(PyTorchOperation):
assert
len
(
inputs
)
==
1
assert
len
(
inputs
)
==
1
return
f
'
{
output
}
=
{
inputs
[
0
]
}
.view(
{
inputs
[
0
]
}
.size(0), -1)'
return
f
'
{
output
}
=
{
inputs
[
0
]
}
.view(
{
inputs
[
0
]
}
.size(0), -1)'
class
ToDevice
(
PyTorchOperation
):
class
ToDevice
(
PyTorchOperation
):
def
to_init_code
(
self
,
field
):
def
to_init_code
(
self
,
field
):
return
''
return
''
...
...
nni/retiarii/strategies/strategy.py
View file @
59cd3982
import
abc
import
abc
from
typing
import
List
from
typing
import
List
from
..graph
import
Model
from
..mutator
import
Mutator
class
BaseStrategy
(
abc
.
ABC
):
class
BaseStrategy
(
abc
.
ABC
):
@
abc
.
abstractmethod
@
abc
.
abstractmethod
def
run
(
self
,
base_model
:
'
Model
'
,
applied_mutators
:
List
[
'
Mutator
'
])
->
None
:
def
run
(
self
,
base_model
:
Model
,
applied_mutators
:
List
[
Mutator
])
->
None
:
pass
pass
nni/retiarii/strategies/tpe_strategy.py
View file @
59cd3982
import
json
import
logging
import
logging
import
random
import
os
from
..
import
Model
,
submit_models
,
wait_models
from
..
import
Sampler
,
submit_models
,
wait_models
from
..
import
Sampler
from
.strategy
import
BaseStrategy
from
.strategy
import
BaseStrategy
from
...algorithms.hpo.hyperopt_tuner.hyperopt_tuner
import
HyperoptTuner
from
...algorithms.hpo.hyperopt_tuner.hyperopt_tuner
import
HyperoptTuner
_logger
=
logging
.
getLogger
(
__name__
)
_logger
=
logging
.
getLogger
(
__name__
)
class
TPESampler
(
Sampler
):
class
TPESampler
(
Sampler
):
def
__init__
(
self
,
optimize_mode
=
'minimize'
):
def
__init__
(
self
,
optimize_mode
=
'minimize'
):
self
.
tpe_tuner
=
HyperoptTuner
(
'tpe'
,
optimize_mode
)
self
.
tpe_tuner
=
HyperoptTuner
(
'tpe'
,
optimize_mode
)
...
@@ -37,6 +34,7 @@ class TPESampler(Sampler):
...
@@ -37,6 +34,7 @@ class TPESampler(Sampler):
self
.
index
+=
1
self
.
index
+=
1
return
chosen
return
chosen
class
TPEStrategy
(
BaseStrategy
):
class
TPEStrategy
(
BaseStrategy
):
def
__init__
(
self
):
def
__init__
(
self
):
self
.
tpe_sampler
=
TPESampler
()
self
.
tpe_sampler
=
TPESampler
()
...
@@ -55,7 +53,7 @@ class TPEStrategy(BaseStrategy):
...
@@ -55,7 +53,7 @@ class TPEStrategy(BaseStrategy):
while
True
:
while
True
:
model
=
base_model
model
=
base_model
_logger
.
info
(
'apply mutators...'
)
_logger
.
info
(
'apply mutators...'
)
_logger
.
info
(
'mutators:
{}'
.
format
(
applied_mutators
))
_logger
.
info
(
'mutators:
%s'
,
str
(
applied_mutators
))
self
.
tpe_sampler
.
generate_samples
(
self
.
model_id
)
self
.
tpe_sampler
.
generate_samples
(
self
.
model_id
)
for
mutator
in
applied_mutators
:
for
mutator
in
applied_mutators
:
_logger
.
info
(
'mutate model...'
)
_logger
.
info
(
'mutate model...'
)
...
@@ -66,6 +64,6 @@ class TPEStrategy(BaseStrategy):
...
@@ -66,6 +64,6 @@ class TPEStrategy(BaseStrategy):
wait_models
(
model
)
wait_models
(
model
)
self
.
tpe_sampler
.
receive_result
(
self
.
model_id
,
model
.
metric
)
self
.
tpe_sampler
.
receive_result
(
self
.
model_id
,
model
.
metric
)
self
.
model_id
+=
1
self
.
model_id
+=
1
_logger
.
info
(
'Strategy says:'
,
model
.
metric
)
_logger
.
info
(
'Strategy says:
%s
'
,
model
.
metric
)
except
Exception
as
e
:
except
Exception
:
_logger
.
error
(
logging
.
exception
(
'message'
))
_logger
.
error
(
logging
.
exception
(
'message'
))
nni/retiarii/trainer/interface.py
View file @
59cd3982
import
abc
import
abc
import
inspect
from
typing
import
Any
from
typing
import
*
class
BaseTrainer
(
abc
.
ABC
):
class
BaseTrainer
(
abc
.
ABC
):
...
...
nni/retiarii/trainer/pytorch/base.py
View file @
59cd3982
import
abc
from
typing
import
Any
,
List
,
Dict
,
Tuple
from
typing
import
*
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
...
@@ -42,6 +41,7 @@ def get_default_transform(dataset: str) -> Any:
...
@@ -42,6 +41,7 @@ def get_default_transform(dataset: str) -> Any:
# unsupported dataset, return None
# unsupported dataset, return None
return
None
return
None
@
register_trainer
()
@
register_trainer
()
class
PyTorchImageClassificationTrainer
(
BaseTrainer
):
class
PyTorchImageClassificationTrainer
(
BaseTrainer
):
"""
"""
...
@@ -94,7 +94,7 @@ class PyTorchImageClassificationTrainer(BaseTrainer):
...
@@ -94,7 +94,7 @@ class PyTorchImageClassificationTrainer(BaseTrainer):
self
.
_dataloader
=
DataLoader
(
self
.
_dataloader
=
DataLoader
(
self
.
_dataset
,
**
(
dataloader_kwargs
or
{}))
self
.
_dataset
,
**
(
dataloader_kwargs
or
{}))
def
_accuracy
(
self
,
input
,
target
):
def
_accuracy
(
self
,
input
,
target
):
# pylint: disable=redefined-builtin
_
,
predict
=
torch
.
max
(
input
.
data
,
1
)
_
,
predict
=
torch
.
max
(
input
.
data
,
1
)
correct
=
predict
.
eq
(
target
.
data
).
cpu
().
sum
().
item
()
correct
=
predict
.
eq
(
target
.
data
).
cpu
().
sum
().
item
()
return
correct
/
input
.
size
(
0
)
return
correct
/
input
.
size
(
0
)
...
@@ -176,7 +176,7 @@ class PyTorchMultiModelTrainer(BaseTrainer):
...
@@ -176,7 +176,7 @@ class PyTorchMultiModelTrainer(BaseTrainer):
dataloader
=
DataLoader
(
dataset
,
**
(
dataloader_kwargs
or
{}))
dataloader
=
DataLoader
(
dataset
,
**
(
dataloader_kwargs
or
{}))
self
.
_datasets
.
append
(
dataset
)
self
.
_datasets
.
append
(
dataset
)
self
.
_dataloaders
.
append
(
dataloader
)
self
.
_dataloaders
.
append
(
dataloader
)
if
m
[
'use_output'
]:
if
m
[
'use_output'
]:
optimizer_cls
=
m
[
'optimizer_cls'
]
optimizer_cls
=
m
[
'optimizer_cls'
]
optimizer_kwargs
=
m
[
'optimizer_kwargs'
]
optimizer_kwargs
=
m
[
'optimizer_kwargs'
]
...
@@ -186,7 +186,7 @@ class PyTorchMultiModelTrainer(BaseTrainer):
...
@@ -186,7 +186,7 @@ class PyTorchMultiModelTrainer(BaseTrainer):
name_prefix
=
'_'
.
join
(
name
.
split
(
'_'
)[:
2
])
name_prefix
=
'_'
.
join
(
name
.
split
(
'_'
)[:
2
])
if
m_header
==
name_prefix
:
if
m_header
==
name_prefix
:
one_model_params
.
append
(
param
)
one_model_params
.
append
(
param
)
optimizer
=
getattr
(
torch
.
optim
,
optimizer_cls
)(
one_model_params
,
**
(
optimizer_kwargs
or
{}))
optimizer
=
getattr
(
torch
.
optim
,
optimizer_cls
)(
one_model_params
,
**
(
optimizer_kwargs
or
{}))
self
.
_optimizers
.
append
(
optimizer
)
self
.
_optimizers
.
append
(
optimizer
)
...
@@ -206,7 +206,7 @@ class PyTorchMultiModelTrainer(BaseTrainer):
...
@@ -206,7 +206,7 @@ class PyTorchMultiModelTrainer(BaseTrainer):
x
,
y
=
self
.
training_step_before_model
(
batch
,
batch_idx
,
f
'cuda:
{
idx
}
'
)
x
,
y
=
self
.
training_step_before_model
(
batch
,
batch_idx
,
f
'cuda:
{
idx
}
'
)
xs
.
append
(
x
)
xs
.
append
(
x
)
ys
.
append
(
y
)
ys
.
append
(
y
)
y_hats
=
self
.
multi_model
(
*
xs
)
y_hats
=
self
.
multi_model
(
*
xs
)
if
len
(
ys
)
!=
len
(
xs
):
if
len
(
ys
)
!=
len
(
xs
):
raise
ValueError
(
'len(ys) should be equal to len(xs)'
)
raise
ValueError
(
'len(ys) should be equal to len(xs)'
)
...
@@ -230,13 +230,12 @@ class PyTorchMultiModelTrainer(BaseTrainer):
...
@@ -230,13 +230,12 @@ class PyTorchMultiModelTrainer(BaseTrainer):
if
self
.
max_steps
and
batch_idx
>=
self
.
max_steps
:
if
self
.
max_steps
and
batch_idx
>=
self
.
max_steps
:
return
return
def
training_step
(
self
,
batch
:
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
],
batch_idx
:
int
)
->
Dict
[
str
,
Any
]:
def
training_step
(
self
,
batch
:
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
],
batch_idx
:
int
)
->
Dict
[
str
,
Any
]:
x
,
y
=
self
.
training_step_before_model
(
batch
,
batch_idx
)
x
,
y
=
self
.
training_step_before_model
(
batch
,
batch_idx
)
y_hat
=
self
.
model
(
x
)
y_hat
=
self
.
model
(
x
)
return
self
.
training_step_after_model
(
x
,
y
,
y_hat
)
return
self
.
training_step_after_model
(
x
,
y
,
y_hat
)
def
training_step_before_model
(
self
,
batch
:
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
],
batch_idx
:
int
,
device
=
None
):
def
training_step_before_model
(
self
,
batch
:
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
],
batch_idx
:
int
,
device
=
None
):
x
,
y
=
batch
x
,
y
=
batch
if
device
:
if
device
:
x
,
y
=
x
.
cuda
(
torch
.
device
(
device
)),
y
.
cuda
(
torch
.
device
(
device
))
x
,
y
=
x
.
cuda
(
torch
.
device
(
device
)),
y
.
cuda
(
torch
.
device
(
device
))
...
@@ -259,4 +258,4 @@ class PyTorchMultiModelTrainer(BaseTrainer):
...
@@ -259,4 +258,4 @@ class PyTorchMultiModelTrainer(BaseTrainer):
def
validation_step_after_model
(
self
,
x
,
y
,
y_hat
):
def
validation_step_after_model
(
self
,
x
,
y
,
y_hat
):
acc
=
self
.
_accuracy
(
y_hat
,
y
)
acc
=
self
.
_accuracy
(
y_hat
,
y
)
return
{
'val_acc'
:
acc
}
return
{
'val_acc'
:
acc
}
\ No newline at end of file
nni/retiarii/trainer/pytorch/darts.py
View file @
59cd3982
...
@@ -6,7 +6,6 @@ import logging
...
@@ -6,7 +6,6 @@ import logging
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
from
nni.nas.pytorch.mutables
import
LayerChoice
from
..interface
import
BaseOneShotTrainer
from
..interface
import
BaseOneShotTrainer
from
.utils
import
AverageMeterGroup
,
replace_layer_choice
,
replace_input_choice
from
.utils
import
AverageMeterGroup
,
replace_layer_choice
,
replace_input_choice
...
...
nni/retiarii/trainer/pytorch/enas.py
View file @
59cd3982
...
@@ -86,8 +86,8 @@ class ReinforceController(nn.Module):
...
@@ -86,8 +86,8 @@ class ReinforceController(nn.Module):
self
.
attn_query
=
nn
.
Linear
(
self
.
lstm_size
,
self
.
lstm_size
,
bias
=
False
)
self
.
attn_query
=
nn
.
Linear
(
self
.
lstm_size
,
self
.
lstm_size
,
bias
=
False
)
self
.
v_attn
=
nn
.
Linear
(
self
.
lstm_size
,
1
,
bias
=
False
)
self
.
v_attn
=
nn
.
Linear
(
self
.
lstm_size
,
1
,
bias
=
False
)
self
.
g_emb
=
nn
.
Parameter
(
torch
.
randn
(
1
,
self
.
lstm_size
)
*
0.1
)
self
.
g_emb
=
nn
.
Parameter
(
torch
.
randn
(
1
,
self
.
lstm_size
)
*
0.1
)
self
.
skip_targets
=
nn
.
Parameter
(
torch
.
tensor
([
1.0
-
self
.
skip_target
,
self
.
skip_target
]),
self
.
skip_targets
=
nn
.
Parameter
(
torch
.
tensor
([
1.0
-
self
.
skip_target
,
self
.
skip_target
]),
# pylint: disable=not-callable
requires_grad
=
False
)
# pylint: disable=not-callable
requires_grad
=
False
)
assert
entropy_reduction
in
[
'sum'
,
'mean'
],
'Entropy reduction must be one of sum and mean.'
assert
entropy_reduction
in
[
'sum'
,
'mean'
],
'Entropy reduction must be one of sum and mean.'
self
.
entropy_reduction
=
torch
.
sum
if
entropy_reduction
==
'sum'
else
torch
.
mean
self
.
entropy_reduction
=
torch
.
sum
if
entropy_reduction
==
'sum'
else
torch
.
mean
self
.
cross_entropy_loss
=
nn
.
CrossEntropyLoss
(
reduction
=
'none'
)
self
.
cross_entropy_loss
=
nn
.
CrossEntropyLoss
(
reduction
=
'none'
)
...
...
nni/retiarii/trainer/pytorch/random.py
View file @
59cd3982
...
@@ -16,7 +16,7 @@ _logger = logging.getLogger(__name__)
...
@@ -16,7 +16,7 @@ _logger = logging.getLogger(__name__)
def
_get_mask
(
sampled
,
total
):
def
_get_mask
(
sampled
,
total
):
multihot
=
[
i
==
sampled
or
(
isinstance
(
sampled
,
list
)
and
i
in
sampled
)
for
i
in
range
(
total
)]
multihot
=
[
i
==
sampled
or
(
isinstance
(
sampled
,
list
)
and
i
in
sampled
)
for
i
in
range
(
total
)]
return
torch
.
tensor
(
multihot
,
dtype
=
torch
.
bool
)
return
torch
.
tensor
(
multihot
,
dtype
=
torch
.
bool
)
# pylint: disable=not-callable
class
PathSamplingLayerChoice
(
nn
.
Module
):
class
PathSamplingLayerChoice
(
nn
.
Module
):
...
@@ -44,9 +44,9 @@ class PathSamplingLayerChoice(nn.Module):
...
@@ -44,9 +44,9 @@ class PathSamplingLayerChoice(nn.Module):
def
forward
(
self
,
*
args
,
**
kwargs
):
def
forward
(
self
,
*
args
,
**
kwargs
):
assert
self
.
sampled
is
not
None
,
'At least one path needs to be sampled before fprop.'
assert
self
.
sampled
is
not
None
,
'At least one path needs to be sampled before fprop.'
if
isinstance
(
self
.
sampled
,
list
):
if
isinstance
(
self
.
sampled
,
list
):
return
sum
([
getattr
(
self
,
self
.
op_names
[
i
])(
*
args
,
**
kwargs
)
for
i
in
self
.
sampled
])
return
sum
([
getattr
(
self
,
self
.
op_names
[
i
])(
*
args
,
**
kwargs
)
for
i
in
self
.
sampled
])
# pylint: disable=not-an-iterable
else
:
else
:
return
getattr
(
self
,
self
.
op_names
[
self
.
sampled
])(
*
args
,
**
kwargs
)
return
getattr
(
self
,
self
.
op_names
[
self
.
sampled
])(
*
args
,
**
kwargs
)
# pylint: disable=invalid-sequence-index
def
__len__
(
self
):
def
__len__
(
self
):
return
len
(
self
.
op_names
)
return
len
(
self
.
op_names
)
...
@@ -76,7 +76,7 @@ class PathSamplingInputChoice(nn.Module):
...
@@ -76,7 +76,7 @@ class PathSamplingInputChoice(nn.Module):
def
forward
(
self
,
input_tensors
):
def
forward
(
self
,
input_tensors
):
if
isinstance
(
self
.
sampled
,
list
):
if
isinstance
(
self
.
sampled
,
list
):
return
sum
([
input_tensors
[
t
]
for
t
in
self
.
sampled
])
return
sum
([
input_tensors
[
t
]
for
t
in
self
.
sampled
])
# pylint: disable=not-an-iterable
else
:
else
:
return
input_tensors
[
self
.
sampled
]
return
input_tensors
[
self
.
sampled
]
...
...
nni/retiarii/trainer/pytorch/utils.py
View file @
59cd3982
...
@@ -123,13 +123,13 @@ class AverageMeter:
...
@@ -123,13 +123,13 @@ class AverageMeter:
return
fmtstr
.
format
(
**
self
.
__dict__
)
return
fmtstr
.
format
(
**
self
.
__dict__
)
def
_replace_module_with_type
(
root_module
,
init_fn
,
type
,
modules
):
def
_replace_module_with_type
(
root_module
,
init_fn
,
type
_name
,
modules
):
if
modules
is
None
:
if
modules
is
None
:
modules
=
[]
modules
=
[]
def
apply
(
m
):
def
apply
(
m
):
for
name
,
child
in
m
.
named_children
():
for
name
,
child
in
m
.
named_children
():
if
isinstance
(
child
,
type
):
if
isinstance
(
child
,
type
_name
):
setattr
(
m
,
name
,
init_fn
(
child
))
setattr
(
m
,
name
,
init_fn
(
child
))
modules
.
append
((
child
.
key
,
getattr
(
m
,
name
)))
modules
.
append
((
child
.
key
,
getattr
(
m
,
name
)))
else
:
else
:
...
...
nni/retiarii/utils.py
View file @
59cd3982
from
collections
import
defaultdict
import
inspect
import
inspect
from
collections
import
defaultdict
from
typing
import
Any
def
import_
(
target
:
str
,
allow_none
:
bool
=
False
)
->
'
Any
'
:
def
import_
(
target
:
str
,
allow_none
:
bool
=
False
)
->
Any
:
if
target
is
None
:
if
target
is
None
:
return
None
return
None
path
,
identifier
=
target
.
rsplit
(
'.'
,
1
)
path
,
identifier
=
target
.
rsplit
(
'.'
,
1
)
module
=
__import__
(
path
,
globals
(),
locals
(),
[
identifier
])
module
=
__import__
(
path
,
globals
(),
locals
(),
[
identifier
])
return
getattr
(
module
,
identifier
)
return
getattr
(
module
,
identifier
)
_records
=
{}
_records
=
{}
def
get_records
():
def
get_records
():
global
_records
global
_records
return
_records
return
_records
def
add_record
(
key
,
value
):
def
add_record
(
key
,
value
):
"""
"""
"""
"""
...
@@ -22,6 +27,7 @@ def add_record(key, value):
...
@@ -22,6 +27,7 @@ def add_record(key, value):
assert
key
not
in
_records
,
'{} already in _records'
.
format
(
key
)
assert
key
not
in
_records
,
'{} already in _records'
.
format
(
key
)
_records
[
key
]
=
value
_records
[
key
]
=
value
def
_register_module
(
original_class
):
def
_register_module
(
original_class
):
orig_init
=
original_class
.
__init__
orig_init
=
original_class
.
__init__
argname_list
=
list
(
inspect
.
signature
(
original_class
).
parameters
.
keys
())
argname_list
=
list
(
inspect
.
signature
(
original_class
).
parameters
.
keys
())
...
@@ -31,14 +37,15 @@ def _register_module(original_class):
...
@@ -31,14 +37,15 @@ def _register_module(original_class):
full_args
=
{}
full_args
=
{}
full_args
.
update
(
kws
)
full_args
.
update
(
kws
)
for
i
,
arg
in
enumerate
(
args
):
for
i
,
arg
in
enumerate
(
args
):
full_args
[
argname_list
[
i
]]
=
arg
s
[
i
]
full_args
[
argname_list
[
i
]]
=
arg
add_record
(
id
(
self
),
full_args
)
add_record
(
id
(
self
),
full_args
)
orig_init
(
self
,
*
args
,
**
kws
)
# Call the original __init__
orig_init
(
self
,
*
args
,
**
kws
)
# Call the original __init__
original_class
.
__init__
=
__init__
# Set the class' __init__ to the new one
original_class
.
__init__
=
__init__
# Set the class' __init__ to the new one
return
original_class
return
original_class
def
register_module
():
def
register_module
():
"""
"""
Register a module.
Register a module.
...
@@ -68,14 +75,15 @@ def _register_trainer(original_class):
...
@@ -68,14 +75,15 @@ def _register_trainer(original_class):
if
isinstance
(
args
[
i
],
Module
):
if
isinstance
(
args
[
i
],
Module
):
# ignore the base model object
# ignore the base model object
continue
continue
full_args
[
argname_list
[
i
]]
=
arg
s
[
i
]
full_args
[
argname_list
[
i
]]
=
arg
add_record
(
id
(
self
),
{
'modulename'
:
full_class_name
,
'args'
:
full_args
})
add_record
(
id
(
self
),
{
'modulename'
:
full_class_name
,
'args'
:
full_args
})
orig_init
(
self
,
*
args
,
**
kws
)
# Call the original __init__
orig_init
(
self
,
*
args
,
**
kws
)
# Call the original __init__
original_class
.
__init__
=
__init__
# Set the class' __init__ to the new one
original_class
.
__init__
=
__init__
# Set the class' __init__ to the new one
return
original_class
return
original_class
def
register_trainer
():
def
register_trainer
():
def
_register
(
cls
):
def
_register
(
cls
):
m
=
_register_trainer
(
m
=
_register_trainer
(
...
@@ -84,8 +92,10 @@ def register_trainer():
...
@@ -84,8 +92,10 @@ def register_trainer():
return
_register
return
_register
_last_uid
=
defaultdict
(
int
)
_last_uid
=
defaultdict
(
int
)
def
uid
(
namespace
:
str
=
'default'
)
->
int
:
def
uid
(
namespace
:
str
=
'default'
)
->
int
:
_last_uid
[
namespace
]
+=
1
_last_uid
[
namespace
]
+=
1
return
_last_uid
[
namespace
]
return
_last_uid
[
namespace
]
pipelines/fast-test.yml
View file @
59cd3982
...
@@ -41,7 +41,7 @@ jobs:
...
@@ -41,7 +41,7 @@ jobs:
python3 -m pip install --upgrade pygments
python3 -m pip install --upgrade pygments
python3 -m pip install --upgrade torch>=1.7.0+cpu torchvision>=0.8.1+cpu -f https://download.pytorch.org/whl/torch_stable.html
python3 -m pip install --upgrade torch>=1.7.0+cpu torchvision>=0.8.1+cpu -f https://download.pytorch.org/whl/torch_stable.html
python3 -m pip install --upgrade tensorflow
python3 -m pip install --upgrade tensorflow
python3 -m pip install --upgrade gym onnx peewee thop
python3 -m pip install --upgrade gym onnx peewee thop
graphviz
python3 -m pip install sphinx==1.8.3 sphinx-argparse==0.2.5 sphinx-markdown-tables==0.0.9 sphinx-rtd-theme==0.4.2 sphinxcontrib-websupport==1.1.0 recommonmark==0.5.0 nbsphinx
python3 -m pip install sphinx==1.8.3 sphinx-argparse==0.2.5 sphinx-markdown-tables==0.0.9 sphinx-rtd-theme==0.4.2 sphinxcontrib-websupport==1.1.0 recommonmark==0.5.0 nbsphinx
sudo apt-get install swig -y
sudo apt-get install swig -y
python3 -m pip install -e .[SMAC,BOHB]
python3 -m pip install -e .[SMAC,BOHB]
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment