Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
59cd3982
Unverified
Commit
59cd3982
authored
Dec 14, 2020
by
Yuge Zhang
Committed by
GitHub
Dec 14, 2020
Browse files
[Retiarii] Coding style improvements for pylint and flake8 (#3190)
parent
593a275c
Changes
34
Hide whitespace changes
Inline
Side-by-side
Showing
14 changed files
with
73 additions
and
47 deletions
+73
-47
nni/retiarii/nn/pytorch/nn.py
nni/retiarii/nn/pytorch/nn.py
+17
-8
nni/retiarii/operation.py
nni/retiarii/operation.py
+5
-0
nni/retiarii/operation_def/tf_op_def.py
nni/retiarii/operation_def/tf_op_def.py
+2
-1
nni/retiarii/operation_def/torch_op_def.py
nni/retiarii/operation_def/torch_op_def.py
+2
-0
nni/retiarii/strategies/strategy.py
nni/retiarii/strategies/strategy.py
+5
-1
nni/retiarii/strategies/tpe_strategy.py
nni/retiarii/strategies/tpe_strategy.py
+6
-8
nni/retiarii/trainer/interface.py
nni/retiarii/trainer/interface.py
+1
-2
nni/retiarii/trainer/pytorch/base.py
nni/retiarii/trainer/pytorch/base.py
+8
-9
nni/retiarii/trainer/pytorch/darts.py
nni/retiarii/trainer/pytorch/darts.py
+0
-1
nni/retiarii/trainer/pytorch/enas.py
nni/retiarii/trainer/pytorch/enas.py
+2
-2
nni/retiarii/trainer/pytorch/random.py
nni/retiarii/trainer/pytorch/random.py
+4
-4
nni/retiarii/trainer/pytorch/utils.py
nni/retiarii/trainer/pytorch/utils.py
+2
-2
nni/retiarii/utils.py
nni/retiarii/utils.py
+18
-8
pipelines/fast-test.yml
pipelines/fast-test.yml
+1
-1
No files found.
nni/retiarii/nn/pytorch/nn.py
View file @
59cd3982
import
inspect
import
inspect
import
logging
import
logging
from
typing
import
Any
,
List
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
from
typing
import
(
Any
,
Tuple
,
List
,
Optional
)
from
...utils
import
add_record
from
...utils
import
add_record
...
@@ -10,7 +11,7 @@ _logger = logging.getLogger(__name__)
...
@@ -10,7 +11,7 @@ _logger = logging.getLogger(__name__)
__all__
=
[
__all__
=
[
'LayerChoice'
,
'InputChoice'
,
'Placeholder'
,
'LayerChoice'
,
'InputChoice'
,
'Placeholder'
,
'Module'
,
'Sequential'
,
'ModuleList'
,
# TODO: 'ModuleDict', 'ParameterList', 'ParameterDict',
'Module'
,
'Sequential'
,
'ModuleList'
,
# TODO: 'ModuleDict', 'ParameterList', 'ParameterDict',
'Identity'
,
'Linear'
,
'Conv1d'
,
'Conv2d'
,
'Conv3d'
,
'ConvTranspose1d'
,
'Identity'
,
'Linear'
,
'Conv1d'
,
'Conv2d'
,
'Conv3d'
,
'ConvTranspose1d'
,
'ConvTranspose2d'
,
'ConvTranspose3d'
,
'Threshold'
,
'ReLU'
,
'Hardtanh'
,
'ReLU6'
,
'ConvTranspose2d'
,
'ConvTranspose3d'
,
'Threshold'
,
'ReLU'
,
'Hardtanh'
,
'ReLU6'
,
'Sigmoid'
,
'Tanh'
,
'Softmax'
,
'Softmax2d'
,
'LogSoftmax'
,
'ELU'
,
'SELU'
,
'CELU'
,
'GLU'
,
'GELU'
,
'Hardshrink'
,
'Sigmoid'
,
'Tanh'
,
'Softmax'
,
'Softmax2d'
,
'LogSoftmax'
,
'ELU'
,
'SELU'
,
'CELU'
,
'GLU'
,
'GELU'
,
'Hardshrink'
,
...
@@ -30,7 +31,7 @@ __all__ = [
...
@@ -30,7 +31,7 @@ __all__ = [
'TransformerEncoderLayer'
,
'TransformerDecoderLayer'
,
'Transformer'
,
'TransformerEncoderLayer'
,
'TransformerDecoderLayer'
,
'Transformer'
,
#'LazyLinear', 'LazyConv1d', 'LazyConv2d', 'LazyConv3d',
#'LazyLinear', 'LazyConv1d', 'LazyConv2d', 'LazyConv3d',
#'LazyConvTranspose1d', 'LazyConvTranspose2d', 'LazyConvTranspose3d',
#'LazyConvTranspose1d', 'LazyConvTranspose2d', 'LazyConvTranspose3d',
#'Unflatten', 'SiLU', 'TripletMarginWithDistanceLoss', 'ChannelShuffle',
#'Unflatten', 'SiLU', 'TripletMarginWithDistanceLoss', 'ChannelShuffle',
'Flatten'
,
'Hardsigmoid'
,
'Hardswish'
'Flatten'
,
'Hardsigmoid'
,
'Hardswish'
]
]
...
@@ -57,9 +58,10 @@ class InputChoice(nn.Module):
...
@@ -57,9 +58,10 @@ class InputChoice(nn.Module):
if
n_candidates
or
choose_from
or
return_mask
:
if
n_candidates
or
choose_from
or
return_mask
:
_logger
.
warning
(
'input arguments `n_candidates`, `choose_from` and `return_mask` are deprecated!'
)
_logger
.
warning
(
'input arguments `n_candidates`, `choose_from` and `return_mask` are deprecated!'
)
def
forward
(
self
,
candidate_inputs
:
List
[
'
Tensor
'
])
->
'
Tensor
'
:
def
forward
(
self
,
candidate_inputs
:
List
[
torch
.
Tensor
])
->
torch
.
Tensor
:
# fake return
# fake return
return
torch
.
tensor
(
candidate_inputs
)
return
torch
.
tensor
(
candidate_inputs
)
# pylint: disable=not-callable
class
ValueChoice
:
class
ValueChoice
:
"""
"""
...
@@ -67,6 +69,7 @@ class ValueChoice:
...
@@ -67,6 +69,7 @@ class ValueChoice:
when instantiating a pytorch module.
when instantiating a pytorch module.
TODO: can also be used in training approach
TODO: can also be used in training approach
"""
"""
def
__init__
(
self
,
candidate_values
:
List
[
Any
]):
def
__init__
(
self
,
candidate_values
:
List
[
Any
]):
self
.
candidate_values
=
candidate_values
self
.
candidate_values
=
candidate_values
...
@@ -81,6 +84,7 @@ class Placeholder(nn.Module):
...
@@ -81,6 +84,7 @@ class Placeholder(nn.Module):
def
forward
(
self
,
x
):
def
forward
(
self
,
x
):
return
x
return
x
class
ChosenInputs
(
nn
.
Module
):
class
ChosenInputs
(
nn
.
Module
):
def
__init__
(
self
,
chosen
:
int
):
def
__init__
(
self
,
chosen
:
int
):
super
().
__init__
()
super
().
__init__
()
...
@@ -92,20 +96,24 @@ class ChosenInputs(nn.Module):
...
@@ -92,20 +96,24 @@ class ChosenInputs(nn.Module):
# the following are pytorch modules
# the following are pytorch modules
class
Module
(
nn
.
Module
):
class
Module
(
nn
.
Module
):
def
__init__
(
self
):
def
__init__
(
self
):
super
(
Module
,
self
).
__init__
()
super
(
Module
,
self
).
__init__
()
class
Sequential
(
nn
.
Sequential
):
class
Sequential
(
nn
.
Sequential
):
def
__init__
(
self
,
*
args
):
def
__init__
(
self
,
*
args
):
add_record
(
id
(
self
),
{})
add_record
(
id
(
self
),
{})
super
(
Sequential
,
self
).
__init__
(
*
args
)
super
(
Sequential
,
self
).
__init__
(
*
args
)
class
ModuleList
(
nn
.
ModuleList
):
class
ModuleList
(
nn
.
ModuleList
):
def
__init__
(
self
,
*
args
):
def
__init__
(
self
,
*
args
):
add_record
(
id
(
self
),
{})
add_record
(
id
(
self
),
{})
super
(
ModuleList
,
self
).
__init__
(
*
args
)
super
(
ModuleList
,
self
).
__init__
(
*
args
)
def
wrap_module
(
original_class
):
def
wrap_module
(
original_class
):
orig_init
=
original_class
.
__init__
orig_init
=
original_class
.
__init__
argname_list
=
list
(
inspect
.
signature
(
original_class
).
parameters
.
keys
())
argname_list
=
list
(
inspect
.
signature
(
original_class
).
parameters
.
keys
())
...
@@ -115,14 +123,15 @@ def wrap_module(original_class):
...
@@ -115,14 +123,15 @@ def wrap_module(original_class):
full_args
=
{}
full_args
=
{}
full_args
.
update
(
kws
)
full_args
.
update
(
kws
)
for
i
,
arg
in
enumerate
(
args
):
for
i
,
arg
in
enumerate
(
args
):
full_args
[
argname_list
[
i
]]
=
arg
s
[
i
]
full_args
[
argname_list
[
i
]]
=
arg
add_record
(
id
(
self
),
full_args
)
add_record
(
id
(
self
),
full_args
)
orig_init
(
self
,
*
args
,
**
kws
)
# Call the original __init__
orig_init
(
self
,
*
args
,
**
kws
)
# Call the original __init__
original_class
.
__init__
=
__init__
# Set the class' __init__ to the new one
original_class
.
__init__
=
__init__
# Set the class' __init__ to the new one
return
original_class
return
original_class
# TODO: support different versions of pytorch
# TODO: support different versions of pytorch
Identity
=
wrap_module
(
nn
.
Identity
)
Identity
=
wrap_module
(
nn
.
Identity
)
Linear
=
wrap_module
(
nn
.
Linear
)
Linear
=
wrap_module
(
nn
.
Linear
)
...
...
nni/retiarii/operation.py
View file @
59cd3982
...
@@ -4,12 +4,14 @@ from . import debug_configs
...
@@ -4,12 +4,14 @@ from . import debug_configs
__all__
=
[
'Operation'
,
'Cell'
]
__all__
=
[
'Operation'
,
'Cell'
]
def
_convert_name
(
name
:
str
)
->
str
:
def
_convert_name
(
name
:
str
)
->
str
:
"""
"""
Convert the names using separator '.' to valid variable name in code
Convert the names using separator '.' to valid variable name in code
"""
"""
return
name
.
replace
(
'.'
,
'__'
)
return
name
.
replace
(
'.'
,
'__'
)
class
Operation
:
class
Operation
:
"""
"""
Calculation logic of a graph node.
Calculation logic of a graph node.
...
@@ -152,6 +154,7 @@ class PyTorchOperation(Operation):
...
@@ -152,6 +154,7 @@ class PyTorchOperation(Operation):
else
:
else
:
raise
RuntimeError
(
f
'unsupported operation type:
{
self
.
type
}
?
{
self
.
_to_class_name
()
}
'
)
raise
RuntimeError
(
f
'unsupported operation type:
{
self
.
type
}
?
{
self
.
_to_class_name
()
}
'
)
class
TensorFlowOperation
(
Operation
):
class
TensorFlowOperation
(
Operation
):
def
_to_class_name
(
self
)
->
str
:
def
_to_class_name
(
self
)
->
str
:
return
'K.layers.'
+
self
.
type
return
'K.layers.'
+
self
.
type
...
@@ -191,6 +194,7 @@ class Cell(PyTorchOperation):
...
@@ -191,6 +194,7 @@ class Cell(PyTorchOperation):
framework
framework
No real usage. Exists for compatibility with base class.
No real usage. Exists for compatibility with base class.
"""
"""
def
__init__
(
self
,
cell_name
:
str
,
parameters
:
Dict
[
str
,
Any
]
=
{}):
def
__init__
(
self
,
cell_name
:
str
,
parameters
:
Dict
[
str
,
Any
]
=
{}):
self
.
type
=
'_cell'
self
.
type
=
'_cell'
self
.
cell_name
=
cell_name
self
.
cell_name
=
cell_name
...
@@ -207,6 +211,7 @@ class _IOPseudoOperation(Operation):
...
@@ -207,6 +211,7 @@ class _IOPseudoOperation(Operation):
The benefit is that users no longer need to verify `Node.operation is not None`,
The benefit is that users no longer need to verify `Node.operation is not None`,
especially in static type checking.
especially in static type checking.
"""
"""
def
__init__
(
self
,
type_name
:
str
,
io_names
:
List
=
None
):
def
__init__
(
self
,
type_name
:
str
,
io_names
:
List
=
None
):
assert
type_name
.
startswith
(
'_'
)
assert
type_name
.
startswith
(
'_'
)
super
(
_IOPseudoOperation
,
self
).
__init__
(
type_name
,
{},
True
)
super
(
_IOPseudoOperation
,
self
).
__init__
(
type_name
,
{},
True
)
...
...
nni/retiarii/operation_def/tf_op_def.py
View file @
59cd3982
from
..operation
import
TensorFlowOperation
from
..operation
import
TensorFlowOperation
class
Conv2D
(
TensorFlowOperation
):
class
Conv2D
(
TensorFlowOperation
):
def
__init__
(
self
,
type_name
,
parameters
,
_internal
):
def
__init__
(
self
,
type_name
,
parameters
,
_internal
):
if
'padding'
not
in
parameters
:
if
'padding'
not
in
parameters
:
parameters
[
'padding'
]
=
'same'
parameters
[
'padding'
]
=
'same'
super
().
__init__
(
type_name
,
parameters
,
_internal
)
super
().
__init__
(
type_name
,
parameters
,
_internal
)
\ No newline at end of file
nni/retiarii/operation_def/torch_op_def.py
View file @
59cd3982
from
..operation
import
PyTorchOperation
from
..operation
import
PyTorchOperation
class
relu
(
PyTorchOperation
):
class
relu
(
PyTorchOperation
):
def
to_init_code
(
self
,
field
):
def
to_init_code
(
self
,
field
):
return
''
return
''
...
@@ -17,6 +18,7 @@ class Flatten(PyTorchOperation):
...
@@ -17,6 +18,7 @@ class Flatten(PyTorchOperation):
assert
len
(
inputs
)
==
1
assert
len
(
inputs
)
==
1
return
f
'
{
output
}
=
{
inputs
[
0
]
}
.view(
{
inputs
[
0
]
}
.size(0), -1)'
return
f
'
{
output
}
=
{
inputs
[
0
]
}
.view(
{
inputs
[
0
]
}
.size(0), -1)'
class
ToDevice
(
PyTorchOperation
):
class
ToDevice
(
PyTorchOperation
):
def
to_init_code
(
self
,
field
):
def
to_init_code
(
self
,
field
):
return
''
return
''
...
...
nni/retiarii/strategies/strategy.py
View file @
59cd3982
import
abc
import
abc
from
typing
import
List
from
typing
import
List
from
..graph
import
Model
from
..mutator
import
Mutator
class
BaseStrategy
(
abc
.
ABC
):
class
BaseStrategy
(
abc
.
ABC
):
@
abc
.
abstractmethod
@
abc
.
abstractmethod
def
run
(
self
,
base_model
:
'
Model
'
,
applied_mutators
:
List
[
'
Mutator
'
])
->
None
:
def
run
(
self
,
base_model
:
Model
,
applied_mutators
:
List
[
Mutator
])
->
None
:
pass
pass
nni/retiarii/strategies/tpe_strategy.py
View file @
59cd3982
import
json
import
logging
import
logging
import
random
import
os
from
..
import
Model
,
submit_models
,
wait_models
from
..
import
Sampler
,
submit_models
,
wait_models
from
..
import
Sampler
from
.strategy
import
BaseStrategy
from
.strategy
import
BaseStrategy
from
...algorithms.hpo.hyperopt_tuner.hyperopt_tuner
import
HyperoptTuner
from
...algorithms.hpo.hyperopt_tuner.hyperopt_tuner
import
HyperoptTuner
_logger
=
logging
.
getLogger
(
__name__
)
_logger
=
logging
.
getLogger
(
__name__
)
class
TPESampler
(
Sampler
):
class
TPESampler
(
Sampler
):
def
__init__
(
self
,
optimize_mode
=
'minimize'
):
def
__init__
(
self
,
optimize_mode
=
'minimize'
):
self
.
tpe_tuner
=
HyperoptTuner
(
'tpe'
,
optimize_mode
)
self
.
tpe_tuner
=
HyperoptTuner
(
'tpe'
,
optimize_mode
)
...
@@ -37,6 +34,7 @@ class TPESampler(Sampler):
...
@@ -37,6 +34,7 @@ class TPESampler(Sampler):
self
.
index
+=
1
self
.
index
+=
1
return
chosen
return
chosen
class
TPEStrategy
(
BaseStrategy
):
class
TPEStrategy
(
BaseStrategy
):
def
__init__
(
self
):
def
__init__
(
self
):
self
.
tpe_sampler
=
TPESampler
()
self
.
tpe_sampler
=
TPESampler
()
...
@@ -55,7 +53,7 @@ class TPEStrategy(BaseStrategy):
...
@@ -55,7 +53,7 @@ class TPEStrategy(BaseStrategy):
while
True
:
while
True
:
model
=
base_model
model
=
base_model
_logger
.
info
(
'apply mutators...'
)
_logger
.
info
(
'apply mutators...'
)
_logger
.
info
(
'mutators:
{}'
.
format
(
applied_mutators
))
_logger
.
info
(
'mutators:
%s'
,
str
(
applied_mutators
))
self
.
tpe_sampler
.
generate_samples
(
self
.
model_id
)
self
.
tpe_sampler
.
generate_samples
(
self
.
model_id
)
for
mutator
in
applied_mutators
:
for
mutator
in
applied_mutators
:
_logger
.
info
(
'mutate model...'
)
_logger
.
info
(
'mutate model...'
)
...
@@ -66,6 +64,6 @@ class TPEStrategy(BaseStrategy):
...
@@ -66,6 +64,6 @@ class TPEStrategy(BaseStrategy):
wait_models
(
model
)
wait_models
(
model
)
self
.
tpe_sampler
.
receive_result
(
self
.
model_id
,
model
.
metric
)
self
.
tpe_sampler
.
receive_result
(
self
.
model_id
,
model
.
metric
)
self
.
model_id
+=
1
self
.
model_id
+=
1
_logger
.
info
(
'Strategy says:'
,
model
.
metric
)
_logger
.
info
(
'Strategy says:
%s
'
,
model
.
metric
)
except
Exception
as
e
:
except
Exception
:
_logger
.
error
(
logging
.
exception
(
'message'
))
_logger
.
error
(
logging
.
exception
(
'message'
))
nni/retiarii/trainer/interface.py
View file @
59cd3982
import
abc
import
abc
import
inspect
from
typing
import
Any
from
typing
import
*
class
BaseTrainer
(
abc
.
ABC
):
class
BaseTrainer
(
abc
.
ABC
):
...
...
nni/retiarii/trainer/pytorch/base.py
View file @
59cd3982
import
abc
from
typing
import
Any
,
List
,
Dict
,
Tuple
from
typing
import
*
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
...
@@ -42,6 +41,7 @@ def get_default_transform(dataset: str) -> Any:
...
@@ -42,6 +41,7 @@ def get_default_transform(dataset: str) -> Any:
# unsupported dataset, return None
# unsupported dataset, return None
return
None
return
None
@
register_trainer
()
@
register_trainer
()
class
PyTorchImageClassificationTrainer
(
BaseTrainer
):
class
PyTorchImageClassificationTrainer
(
BaseTrainer
):
"""
"""
...
@@ -94,7 +94,7 @@ class PyTorchImageClassificationTrainer(BaseTrainer):
...
@@ -94,7 +94,7 @@ class PyTorchImageClassificationTrainer(BaseTrainer):
self
.
_dataloader
=
DataLoader
(
self
.
_dataloader
=
DataLoader
(
self
.
_dataset
,
**
(
dataloader_kwargs
or
{}))
self
.
_dataset
,
**
(
dataloader_kwargs
or
{}))
def
_accuracy
(
self
,
input
,
target
):
def
_accuracy
(
self
,
input
,
target
):
# pylint: disable=redefined-builtin
_
,
predict
=
torch
.
max
(
input
.
data
,
1
)
_
,
predict
=
torch
.
max
(
input
.
data
,
1
)
correct
=
predict
.
eq
(
target
.
data
).
cpu
().
sum
().
item
()
correct
=
predict
.
eq
(
target
.
data
).
cpu
().
sum
().
item
()
return
correct
/
input
.
size
(
0
)
return
correct
/
input
.
size
(
0
)
...
@@ -176,7 +176,7 @@ class PyTorchMultiModelTrainer(BaseTrainer):
...
@@ -176,7 +176,7 @@ class PyTorchMultiModelTrainer(BaseTrainer):
dataloader
=
DataLoader
(
dataset
,
**
(
dataloader_kwargs
or
{}))
dataloader
=
DataLoader
(
dataset
,
**
(
dataloader_kwargs
or
{}))
self
.
_datasets
.
append
(
dataset
)
self
.
_datasets
.
append
(
dataset
)
self
.
_dataloaders
.
append
(
dataloader
)
self
.
_dataloaders
.
append
(
dataloader
)
if
m
[
'use_output'
]:
if
m
[
'use_output'
]:
optimizer_cls
=
m
[
'optimizer_cls'
]
optimizer_cls
=
m
[
'optimizer_cls'
]
optimizer_kwargs
=
m
[
'optimizer_kwargs'
]
optimizer_kwargs
=
m
[
'optimizer_kwargs'
]
...
@@ -186,7 +186,7 @@ class PyTorchMultiModelTrainer(BaseTrainer):
...
@@ -186,7 +186,7 @@ class PyTorchMultiModelTrainer(BaseTrainer):
name_prefix
=
'_'
.
join
(
name
.
split
(
'_'
)[:
2
])
name_prefix
=
'_'
.
join
(
name
.
split
(
'_'
)[:
2
])
if
m_header
==
name_prefix
:
if
m_header
==
name_prefix
:
one_model_params
.
append
(
param
)
one_model_params
.
append
(
param
)
optimizer
=
getattr
(
torch
.
optim
,
optimizer_cls
)(
one_model_params
,
**
(
optimizer_kwargs
or
{}))
optimizer
=
getattr
(
torch
.
optim
,
optimizer_cls
)(
one_model_params
,
**
(
optimizer_kwargs
or
{}))
self
.
_optimizers
.
append
(
optimizer
)
self
.
_optimizers
.
append
(
optimizer
)
...
@@ -206,7 +206,7 @@ class PyTorchMultiModelTrainer(BaseTrainer):
...
@@ -206,7 +206,7 @@ class PyTorchMultiModelTrainer(BaseTrainer):
x
,
y
=
self
.
training_step_before_model
(
batch
,
batch_idx
,
f
'cuda:
{
idx
}
'
)
x
,
y
=
self
.
training_step_before_model
(
batch
,
batch_idx
,
f
'cuda:
{
idx
}
'
)
xs
.
append
(
x
)
xs
.
append
(
x
)
ys
.
append
(
y
)
ys
.
append
(
y
)
y_hats
=
self
.
multi_model
(
*
xs
)
y_hats
=
self
.
multi_model
(
*
xs
)
if
len
(
ys
)
!=
len
(
xs
):
if
len
(
ys
)
!=
len
(
xs
):
raise
ValueError
(
'len(ys) should be equal to len(xs)'
)
raise
ValueError
(
'len(ys) should be equal to len(xs)'
)
...
@@ -230,13 +230,12 @@ class PyTorchMultiModelTrainer(BaseTrainer):
...
@@ -230,13 +230,12 @@ class PyTorchMultiModelTrainer(BaseTrainer):
if
self
.
max_steps
and
batch_idx
>=
self
.
max_steps
:
if
self
.
max_steps
and
batch_idx
>=
self
.
max_steps
:
return
return
def
training_step
(
self
,
batch
:
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
],
batch_idx
:
int
)
->
Dict
[
str
,
Any
]:
def
training_step
(
self
,
batch
:
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
],
batch_idx
:
int
)
->
Dict
[
str
,
Any
]:
x
,
y
=
self
.
training_step_before_model
(
batch
,
batch_idx
)
x
,
y
=
self
.
training_step_before_model
(
batch
,
batch_idx
)
y_hat
=
self
.
model
(
x
)
y_hat
=
self
.
model
(
x
)
return
self
.
training_step_after_model
(
x
,
y
,
y_hat
)
return
self
.
training_step_after_model
(
x
,
y
,
y_hat
)
def
training_step_before_model
(
self
,
batch
:
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
],
batch_idx
:
int
,
device
=
None
):
def
training_step_before_model
(
self
,
batch
:
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
],
batch_idx
:
int
,
device
=
None
):
x
,
y
=
batch
x
,
y
=
batch
if
device
:
if
device
:
x
,
y
=
x
.
cuda
(
torch
.
device
(
device
)),
y
.
cuda
(
torch
.
device
(
device
))
x
,
y
=
x
.
cuda
(
torch
.
device
(
device
)),
y
.
cuda
(
torch
.
device
(
device
))
...
@@ -259,4 +258,4 @@ class PyTorchMultiModelTrainer(BaseTrainer):
...
@@ -259,4 +258,4 @@ class PyTorchMultiModelTrainer(BaseTrainer):
def
validation_step_after_model
(
self
,
x
,
y
,
y_hat
):
def
validation_step_after_model
(
self
,
x
,
y
,
y_hat
):
acc
=
self
.
_accuracy
(
y_hat
,
y
)
acc
=
self
.
_accuracy
(
y_hat
,
y
)
return
{
'val_acc'
:
acc
}
return
{
'val_acc'
:
acc
}
\ No newline at end of file
nni/retiarii/trainer/pytorch/darts.py
View file @
59cd3982
...
@@ -6,7 +6,6 @@ import logging
...
@@ -6,7 +6,6 @@ import logging
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
from
nni.nas.pytorch.mutables
import
LayerChoice
from
..interface
import
BaseOneShotTrainer
from
..interface
import
BaseOneShotTrainer
from
.utils
import
AverageMeterGroup
,
replace_layer_choice
,
replace_input_choice
from
.utils
import
AverageMeterGroup
,
replace_layer_choice
,
replace_input_choice
...
...
nni/retiarii/trainer/pytorch/enas.py
View file @
59cd3982
...
@@ -86,8 +86,8 @@ class ReinforceController(nn.Module):
...
@@ -86,8 +86,8 @@ class ReinforceController(nn.Module):
self
.
attn_query
=
nn
.
Linear
(
self
.
lstm_size
,
self
.
lstm_size
,
bias
=
False
)
self
.
attn_query
=
nn
.
Linear
(
self
.
lstm_size
,
self
.
lstm_size
,
bias
=
False
)
self
.
v_attn
=
nn
.
Linear
(
self
.
lstm_size
,
1
,
bias
=
False
)
self
.
v_attn
=
nn
.
Linear
(
self
.
lstm_size
,
1
,
bias
=
False
)
self
.
g_emb
=
nn
.
Parameter
(
torch
.
randn
(
1
,
self
.
lstm_size
)
*
0.1
)
self
.
g_emb
=
nn
.
Parameter
(
torch
.
randn
(
1
,
self
.
lstm_size
)
*
0.1
)
self
.
skip_targets
=
nn
.
Parameter
(
torch
.
tensor
([
1.0
-
self
.
skip_target
,
self
.
skip_target
]),
self
.
skip_targets
=
nn
.
Parameter
(
torch
.
tensor
([
1.0
-
self
.
skip_target
,
self
.
skip_target
]),
# pylint: disable=not-callable
requires_grad
=
False
)
# pylint: disable=not-callable
requires_grad
=
False
)
assert
entropy_reduction
in
[
'sum'
,
'mean'
],
'Entropy reduction must be one of sum and mean.'
assert
entropy_reduction
in
[
'sum'
,
'mean'
],
'Entropy reduction must be one of sum and mean.'
self
.
entropy_reduction
=
torch
.
sum
if
entropy_reduction
==
'sum'
else
torch
.
mean
self
.
entropy_reduction
=
torch
.
sum
if
entropy_reduction
==
'sum'
else
torch
.
mean
self
.
cross_entropy_loss
=
nn
.
CrossEntropyLoss
(
reduction
=
'none'
)
self
.
cross_entropy_loss
=
nn
.
CrossEntropyLoss
(
reduction
=
'none'
)
...
...
nni/retiarii/trainer/pytorch/random.py
View file @
59cd3982
...
@@ -16,7 +16,7 @@ _logger = logging.getLogger(__name__)
...
@@ -16,7 +16,7 @@ _logger = logging.getLogger(__name__)
def
_get_mask
(
sampled
,
total
):
def
_get_mask
(
sampled
,
total
):
multihot
=
[
i
==
sampled
or
(
isinstance
(
sampled
,
list
)
and
i
in
sampled
)
for
i
in
range
(
total
)]
multihot
=
[
i
==
sampled
or
(
isinstance
(
sampled
,
list
)
and
i
in
sampled
)
for
i
in
range
(
total
)]
return
torch
.
tensor
(
multihot
,
dtype
=
torch
.
bool
)
return
torch
.
tensor
(
multihot
,
dtype
=
torch
.
bool
)
# pylint: disable=not-callable
class
PathSamplingLayerChoice
(
nn
.
Module
):
class
PathSamplingLayerChoice
(
nn
.
Module
):
...
@@ -44,9 +44,9 @@ class PathSamplingLayerChoice(nn.Module):
...
@@ -44,9 +44,9 @@ class PathSamplingLayerChoice(nn.Module):
def
forward
(
self
,
*
args
,
**
kwargs
):
def
forward
(
self
,
*
args
,
**
kwargs
):
assert
self
.
sampled
is
not
None
,
'At least one path needs to be sampled before fprop.'
assert
self
.
sampled
is
not
None
,
'At least one path needs to be sampled before fprop.'
if
isinstance
(
self
.
sampled
,
list
):
if
isinstance
(
self
.
sampled
,
list
):
return
sum
([
getattr
(
self
,
self
.
op_names
[
i
])(
*
args
,
**
kwargs
)
for
i
in
self
.
sampled
])
return
sum
([
getattr
(
self
,
self
.
op_names
[
i
])(
*
args
,
**
kwargs
)
for
i
in
self
.
sampled
])
# pylint: disable=not-an-iterable
else
:
else
:
return
getattr
(
self
,
self
.
op_names
[
self
.
sampled
])(
*
args
,
**
kwargs
)
return
getattr
(
self
,
self
.
op_names
[
self
.
sampled
])(
*
args
,
**
kwargs
)
# pylint: disable=invalid-sequence-index
def
__len__
(
self
):
def
__len__
(
self
):
return
len
(
self
.
op_names
)
return
len
(
self
.
op_names
)
...
@@ -76,7 +76,7 @@ class PathSamplingInputChoice(nn.Module):
...
@@ -76,7 +76,7 @@ class PathSamplingInputChoice(nn.Module):
def
forward
(
self
,
input_tensors
):
def
forward
(
self
,
input_tensors
):
if
isinstance
(
self
.
sampled
,
list
):
if
isinstance
(
self
.
sampled
,
list
):
return
sum
([
input_tensors
[
t
]
for
t
in
self
.
sampled
])
return
sum
([
input_tensors
[
t
]
for
t
in
self
.
sampled
])
# pylint: disable=not-an-iterable
else
:
else
:
return
input_tensors
[
self
.
sampled
]
return
input_tensors
[
self
.
sampled
]
...
...
nni/retiarii/trainer/pytorch/utils.py
View file @
59cd3982
...
@@ -123,13 +123,13 @@ class AverageMeter:
...
@@ -123,13 +123,13 @@ class AverageMeter:
return
fmtstr
.
format
(
**
self
.
__dict__
)
return
fmtstr
.
format
(
**
self
.
__dict__
)
def
_replace_module_with_type
(
root_module
,
init_fn
,
type
,
modules
):
def
_replace_module_with_type
(
root_module
,
init_fn
,
type
_name
,
modules
):
if
modules
is
None
:
if
modules
is
None
:
modules
=
[]
modules
=
[]
def
apply
(
m
):
def
apply
(
m
):
for
name
,
child
in
m
.
named_children
():
for
name
,
child
in
m
.
named_children
():
if
isinstance
(
child
,
type
):
if
isinstance
(
child
,
type
_name
):
setattr
(
m
,
name
,
init_fn
(
child
))
setattr
(
m
,
name
,
init_fn
(
child
))
modules
.
append
((
child
.
key
,
getattr
(
m
,
name
)))
modules
.
append
((
child
.
key
,
getattr
(
m
,
name
)))
else
:
else
:
...
...
nni/retiarii/utils.py
View file @
59cd3982
from
collections
import
defaultdict
import
inspect
import
inspect
from
collections
import
defaultdict
from
typing
import
Any
def
import_
(
target
:
str
,
allow_none
:
bool
=
False
)
->
'
Any
'
:
def
import_
(
target
:
str
,
allow_none
:
bool
=
False
)
->
Any
:
if
target
is
None
:
if
target
is
None
:
return
None
return
None
path
,
identifier
=
target
.
rsplit
(
'.'
,
1
)
path
,
identifier
=
target
.
rsplit
(
'.'
,
1
)
module
=
__import__
(
path
,
globals
(),
locals
(),
[
identifier
])
module
=
__import__
(
path
,
globals
(),
locals
(),
[
identifier
])
return
getattr
(
module
,
identifier
)
return
getattr
(
module
,
identifier
)
_records
=
{}
_records
=
{}
def
get_records
():
def
get_records
():
global
_records
global
_records
return
_records
return
_records
def
add_record
(
key
,
value
):
def
add_record
(
key
,
value
):
"""
"""
"""
"""
...
@@ -22,6 +27,7 @@ def add_record(key, value):
...
@@ -22,6 +27,7 @@ def add_record(key, value):
assert
key
not
in
_records
,
'{} already in _records'
.
format
(
key
)
assert
key
not
in
_records
,
'{} already in _records'
.
format
(
key
)
_records
[
key
]
=
value
_records
[
key
]
=
value
def
_register_module
(
original_class
):
def
_register_module
(
original_class
):
orig_init
=
original_class
.
__init__
orig_init
=
original_class
.
__init__
argname_list
=
list
(
inspect
.
signature
(
original_class
).
parameters
.
keys
())
argname_list
=
list
(
inspect
.
signature
(
original_class
).
parameters
.
keys
())
...
@@ -31,14 +37,15 @@ def _register_module(original_class):
...
@@ -31,14 +37,15 @@ def _register_module(original_class):
full_args
=
{}
full_args
=
{}
full_args
.
update
(
kws
)
full_args
.
update
(
kws
)
for
i
,
arg
in
enumerate
(
args
):
for
i
,
arg
in
enumerate
(
args
):
full_args
[
argname_list
[
i
]]
=
arg
s
[
i
]
full_args
[
argname_list
[
i
]]
=
arg
add_record
(
id
(
self
),
full_args
)
add_record
(
id
(
self
),
full_args
)
orig_init
(
self
,
*
args
,
**
kws
)
# Call the original __init__
orig_init
(
self
,
*
args
,
**
kws
)
# Call the original __init__
original_class
.
__init__
=
__init__
# Set the class' __init__ to the new one
original_class
.
__init__
=
__init__
# Set the class' __init__ to the new one
return
original_class
return
original_class
def
register_module
():
def
register_module
():
"""
"""
Register a module.
Register a module.
...
@@ -68,14 +75,15 @@ def _register_trainer(original_class):
...
@@ -68,14 +75,15 @@ def _register_trainer(original_class):
if
isinstance
(
args
[
i
],
Module
):
if
isinstance
(
args
[
i
],
Module
):
# ignore the base model object
# ignore the base model object
continue
continue
full_args
[
argname_list
[
i
]]
=
arg
s
[
i
]
full_args
[
argname_list
[
i
]]
=
arg
add_record
(
id
(
self
),
{
'modulename'
:
full_class_name
,
'args'
:
full_args
})
add_record
(
id
(
self
),
{
'modulename'
:
full_class_name
,
'args'
:
full_args
})
orig_init
(
self
,
*
args
,
**
kws
)
# Call the original __init__
orig_init
(
self
,
*
args
,
**
kws
)
# Call the original __init__
original_class
.
__init__
=
__init__
# Set the class' __init__ to the new one
original_class
.
__init__
=
__init__
# Set the class' __init__ to the new one
return
original_class
return
original_class
def
register_trainer
():
def
register_trainer
():
def
_register
(
cls
):
def
_register
(
cls
):
m
=
_register_trainer
(
m
=
_register_trainer
(
...
@@ -84,8 +92,10 @@ def register_trainer():
...
@@ -84,8 +92,10 @@ def register_trainer():
return
_register
return
_register
_last_uid
=
defaultdict
(
int
)
_last_uid
=
defaultdict
(
int
)
def
uid
(
namespace
:
str
=
'default'
)
->
int
:
def
uid
(
namespace
:
str
=
'default'
)
->
int
:
_last_uid
[
namespace
]
+=
1
_last_uid
[
namespace
]
+=
1
return
_last_uid
[
namespace
]
return
_last_uid
[
namespace
]
pipelines/fast-test.yml
View file @
59cd3982
...
@@ -41,7 +41,7 @@ jobs:
...
@@ -41,7 +41,7 @@ jobs:
python3 -m pip install --upgrade pygments
python3 -m pip install --upgrade pygments
python3 -m pip install --upgrade torch>=1.7.0+cpu torchvision>=0.8.1+cpu -f https://download.pytorch.org/whl/torch_stable.html
python3 -m pip install --upgrade torch>=1.7.0+cpu torchvision>=0.8.1+cpu -f https://download.pytorch.org/whl/torch_stable.html
python3 -m pip install --upgrade tensorflow
python3 -m pip install --upgrade tensorflow
python3 -m pip install --upgrade gym onnx peewee thop
python3 -m pip install --upgrade gym onnx peewee thop
graphviz
python3 -m pip install sphinx==1.8.3 sphinx-argparse==0.2.5 sphinx-markdown-tables==0.0.9 sphinx-rtd-theme==0.4.2 sphinxcontrib-websupport==1.1.0 recommonmark==0.5.0 nbsphinx
python3 -m pip install sphinx==1.8.3 sphinx-argparse==0.2.5 sphinx-markdown-tables==0.0.9 sphinx-rtd-theme==0.4.2 sphinxcontrib-websupport==1.1.0 recommonmark==0.5.0 nbsphinx
sudo apt-get install swig -y
sudo apt-get install swig -y
python3 -m pip install -e .[SMAC,BOHB]
python3 -m pip install -e .[SMAC,BOHB]
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment