Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
1cada380
Commit
1cada380
authored
Nov 18, 2019
by
Yuge Zhang
Committed by
QuanluZhang
Nov 18, 2019
Browse files
Extract base mutator/trainer and support ENAS micro search space (#1739)
parent
3ddab980
Changes
26
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
187 additions
and
126 deletions
+187
-126
src/sdk/pynni/nni/nas/pytorch/fixed.py
src/sdk/pynni/nni/nas/pytorch/fixed.py
+58
-0
src/sdk/pynni/nni/nas/pytorch/mutables.py
src/sdk/pynni/nni/nas/pytorch/mutables.py
+37
-21
src/sdk/pynni/nni/nas/pytorch/mutator.py
src/sdk/pynni/nni/nas/pytorch/mutator.py
+30
-85
src/sdk/pynni/nni/nas/pytorch/pdarts/trainer.py
src/sdk/pynni/nni/nas/pytorch/pdarts/trainer.py
+3
-3
src/sdk/pynni/nni/nas/pytorch/trainer.py
src/sdk/pynni/nni/nas/pytorch/trainer.py
+59
-6
src/sdk/pynni/nni/nas/utils.py
src/sdk/pynni/nni/nas/utils.py
+0
-11
No files found.
src/sdk/pynni/nni/nas/pytorch/fixed.py
0 → 100644
View file @
1cada380
import
json
import
torch
from
nni.nas.pytorch.mutator
import
Mutator
class
FixedArchitecture
(
Mutator
):
def
__init__
(
self
,
model
,
fixed_arc
,
strict
=
True
):
"""
Initialize a fixed architecture mutator.
Parameters
----------
model: nn.Module
A mutable network.
fixed_arc: str or dict
Path to the architecture checkpoint (a string), or preloaded architecture object (a dict).
strict: bool
Force everything that appears in `fixed_arc` to be used at least once.
"""
super
().
__init__
(
model
)
if
isinstance
(
fixed_arc
,
str
):
with
open
(
fixed_arc
,
"r"
)
as
f
:
fixed_arc
=
json
.
load
(
f
.
read
())
self
.
_fixed_arc
=
fixed_arc
self
.
_strict
=
strict
def
_encode_tensor
(
self
,
data
):
if
isinstance
(
data
,
list
):
if
all
(
map
(
lambda
o
:
isinstance
(
o
,
bool
),
data
)):
return
torch
.
tensor
(
data
,
dtype
=
torch
.
bool
)
# pylint: disable=not-callable
else
:
return
torch
.
tensor
(
data
,
dtype
=
torch
.
float
)
# pylint: disable=not-callable
if
isinstance
(
data
,
dict
):
return
{
k
:
self
.
_encode_tensor
(
v
)
for
k
,
v
in
data
.
items
()}
return
data
def
before_pass
(
self
):
self
.
_unused_key
=
set
(
self
.
_fixed_arc
.
keys
())
def
after_pass
(
self
):
if
self
.
_strict
:
if
self
.
_unused_key
:
raise
ValueError
(
"{} are never used by the network. "
"Set strict=False if you want to disable this check."
.
format
(
self
.
_unused_key
))
def
_check_key
(
self
,
key
):
if
key
not
in
self
.
_fixed_arc
:
raise
ValueError
(
"
\"
{}
\"
is demanded by the network, but not found in saved architecture."
.
format
(
key
))
def
on_calc_layer_choice_mask
(
self
,
mutable
):
self
.
_check_key
(
mutable
.
key
)
return
self
.
_fixed_arc
[
mutable
.
key
]
def
on_calc_input_choice_mask
(
self
,
mutable
,
tags
):
self
.
_check_key
(
mutable
.
key
)
return
self
.
_fixed_arc
[
mutable
.
key
]
src/sdk/pynni/nni/nas/pytorch/mutables.py
View file @
1cada380
...
@@ -3,7 +3,7 @@ import torch.nn as nn
...
@@ -3,7 +3,7 @@ import torch.nn as nn
from
nni.nas.utils
import
global_mutable_counting
from
nni.nas.utils
import
global_mutable_counting
class
PyTorch
Mutable
(
nn
.
Module
):
class
Mutable
(
nn
.
Module
):
"""
"""
Mutable is designed to function as a normal layer, with all necessary operators' weights.
Mutable is designed to function as a normal layer, with all necessary operators' weights.
States and weights of architectures should be included in mutator, instead of the layer itself.
States and weights of architectures should be included in mutator, instead of the layer itself.
...
@@ -24,15 +24,11 @@ class PyTorchMutable(nn.Module):
...
@@ -24,15 +24,11 @@ class PyTorchMutable(nn.Module):
self
.
_key
=
key
self
.
_key
=
key
else
:
else
:
self
.
_key
=
self
.
__class__
.
__name__
+
str
(
global_mutable_counting
())
self
.
_key
=
self
.
__class__
.
__name__
+
str
(
global_mutable_counting
())
self
.
name
=
self
.
key
self
.
init_hook
=
self
.
forward_hook
=
None
def
__deepcopy__
(
self
,
memodict
=
None
):
def
__deepcopy__
(
self
,
memodict
=
None
):
raise
NotImplementedError
(
"Deep copy doesn't work for mutables."
)
raise
NotImplementedError
(
"Deep copy doesn't work for mutables."
)
def
__enter__
(
self
):
self
.
_check_built
()
return
super
().
__enter__
()
def
__call__
(
self
,
*
args
,
**
kwargs
):
def
__call__
(
self
,
*
args
,
**
kwargs
):
self
.
_check_built
()
self
.
_check_built
()
return
super
().
__call__
(
*
args
,
**
kwargs
)
return
super
().
__call__
(
*
args
,
**
kwargs
)
...
@@ -47,8 +43,16 @@ class PyTorchMutable(nn.Module):
...
@@ -47,8 +43,16 @@ class PyTorchMutable(nn.Module):
def
key
(
self
):
def
key
(
self
):
return
self
.
_key
return
self
.
_key
@
property
def
name
(
self
):
return
self
.
_name
if
hasattr
(
self
,
"_name"
)
else
"_key"
@
name
.
setter
def
name
(
self
,
name
):
self
.
_name
=
name
def
similar
(
self
,
other
):
def
similar
(
self
,
other
):
return
self
==
other
return
type
(
self
)
==
type
(
other
)
def
_check_built
(
self
):
def
_check_built
(
self
):
if
not
hasattr
(
self
,
"mutator"
):
if
not
hasattr
(
self
,
"mutator"
):
...
@@ -56,8 +60,11 @@ class PyTorchMutable(nn.Module):
...
@@ -56,8 +60,11 @@ class PyTorchMutable(nn.Module):
"Mutator not set for {}. Did you initialize a mutable on the fly in forward pass? Move to __init__"
"Mutator not set for {}. Did you initialize a mutable on the fly in forward pass? Move to __init__"
"so that trainer can locate all your mutables. See NNI docs for more details."
.
format
(
self
))
"so that trainer can locate all your mutables. See NNI docs for more details."
.
format
(
self
))
def
__repr__
(
self
):
return
"{} ({})"
.
format
(
self
.
name
,
self
.
key
)
class
MutableScope
(
PyTorch
Mutable
):
class
MutableScope
(
Mutable
):
"""
"""
Mutable scope labels a subgraph to help mutators make better decisions. Mutators get notified when a mutable scope
Mutable scope labels a subgraph to help mutators make better decisions. Mutators get notified when a mutable scope
is entered and exited. Mutators can override ``enter_mutable_scope`` and ``exit_mutable_scope`` to catch
is entered and exited. Mutators can override ``enter_mutable_scope`` and ``exit_mutable_scope`` to catch
...
@@ -67,14 +74,18 @@ class MutableScope(PyTorchMutable):
...
@@ -67,14 +74,18 @@ class MutableScope(PyTorchMutable):
def
__init__
(
self
,
key
):
def
__init__
(
self
,
key
):
super
().
__init__
(
key
=
key
)
super
().
__init__
(
key
=
key
)
def
__enter__
(
self
):
def
build
(
self
):
self
.
mutator
.
enter
_mutable_scope
(
self
)
self
.
mutator
.
on_init
_mutable_scope
(
self
)
def
__exit__
(
self
,
exc_type
,
exc_val
,
exc_tb
):
def
__call__
(
self
,
*
args
,
**
kwargs
):
try
:
self
.
mutator
.
enter_mutable_scope
(
self
)
return
super
().
__call__
(
*
args
,
**
kwargs
)
finally
:
self
.
mutator
.
exit_mutable_scope
(
self
)
self
.
mutator
.
exit_mutable_scope
(
self
)
class
LayerChoice
(
PyTorch
Mutable
):
class
LayerChoice
(
Mutable
):
def
__init__
(
self
,
op_candidates
,
reduction
=
"mean"
,
return_mask
=
False
,
key
=
None
):
def
__init__
(
self
,
op_candidates
,
reduction
=
"mean"
,
return_mask
=
False
,
key
=
None
):
super
().
__init__
(
key
=
key
)
super
().
__init__
(
key
=
key
)
self
.
length
=
len
(
op_candidates
)
self
.
length
=
len
(
op_candidates
)
...
@@ -83,10 +94,10 @@ class LayerChoice(PyTorchMutable):
...
@@ -83,10 +94,10 @@ class LayerChoice(PyTorchMutable):
self
.
return_mask
=
return_mask
self
.
return_mask
=
return_mask
def
__len__
(
self
):
def
__len__
(
self
):
return
self
.
length
return
len
(
self
.
choices
)
def
forward
(
self
,
*
inputs
):
def
forward
(
self
,
*
inputs
):
out
,
mask
=
self
.
mutator
.
on_forward
(
self
,
*
inputs
)
out
,
mask
=
self
.
mutator
.
on_forward
_layer_choice
(
self
,
*
inputs
)
if
self
.
return_mask
:
if
self
.
return_mask
:
return
out
,
mask
return
out
,
mask
return
out
return
out
...
@@ -95,7 +106,7 @@ class LayerChoice(PyTorchMutable):
...
@@ -95,7 +106,7 @@ class LayerChoice(PyTorchMutable):
return
type
(
self
)
==
type
(
other
)
and
self
.
length
==
other
.
length
return
type
(
self
)
==
type
(
other
)
and
self
.
length
==
other
.
length
class
InputChoice
(
PyTorch
Mutable
):
class
InputChoice
(
Mutable
):
def
__init__
(
self
,
n_candidates
,
n_selected
=
None
,
reduction
=
"mean"
,
return_mask
=
False
,
key
=
None
):
def
__init__
(
self
,
n_candidates
,
n_selected
=
None
,
reduction
=
"mean"
,
return_mask
=
False
,
key
=
None
):
super
().
__init__
(
key
=
key
)
super
().
__init__
(
key
=
key
)
assert
n_candidates
>
0
,
"Number of candidates must be greater than 0."
assert
n_candidates
>
0
,
"Number of candidates must be greater than 0."
...
@@ -104,12 +115,17 @@ class InputChoice(PyTorchMutable):
...
@@ -104,12 +115,17 @@ class InputChoice(PyTorchMutable):
self
.
reduction
=
reduction
self
.
reduction
=
reduction
self
.
return_mask
=
return_mask
self
.
return_mask
=
return_mask
def
forward
(
self
,
optional_inputs
,
semantic_labels
=
None
):
def
build
(
self
):
self
.
mutator
.
on_init_input_choice
(
self
)
def
forward
(
self
,
optional_inputs
,
tags
=
None
):
assert
len
(
optional_inputs
)
==
self
.
n_candidates
,
\
assert
len
(
optional_inputs
)
==
self
.
n_candidates
,
\
"Length of the input list must be equal to number of candidates."
"Length of the input list must be equal to number of candidates."
if
semantic_labels
is
None
:
if
tags
is
None
:
semantic_labels
=
[
"default_label"
]
*
self
.
n_candidates
tags
=
[
""
]
*
self
.
n_candidates
out
,
mask
=
self
.
mutator
.
on_forward
(
self
,
optional_inputs
,
semantic_labels
)
else
:
assert
len
(
tags
)
==
self
.
n_candidates
,
"Length of tags must be equal to number of candidates."
out
,
mask
=
self
.
mutator
.
on_forward_input_choice
(
self
,
optional_inputs
,
tags
)
if
self
.
return_mask
:
if
self
.
return_mask
:
return
out
,
mask
return
out
,
mask
return
out
return
out
...
...
src/sdk/pynni/nni/nas/pytorch/mutator.py
View file @
1cada380
import
logging
from
contextlib
import
contextmanager
from
contextlib
import
contextmanager
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
from
nni.nas.pytorch.mutables
import
PyTorchMutable
from
nni.nas.pytorch.base_mutator
import
BaseMutator
from
nni.nas.utils
import
to_snake_case
logger
=
logging
.
getLogger
(
__name__
)
class
Mutator
(
BaseMutator
,
nn
.
Module
):
class
PyTorchMutator
(
nn
.
Module
):
def
export
(
self
):
def
__init__
(
self
,
model
):
if
self
.
_in_forward_pass
:
super
().
__init__
()
raise
RuntimeError
(
"Still in forward pass. Exporting might induce incompleteness."
)
self
.
before_build
(
model
)
if
not
self
.
_cache
:
self
.
parse_search_space
(
model
)
raise
RuntimeError
(
"No running history found. You need to call your model at least once before exporting. "
self
.
after_build
(
model
)
"You might also want to check if there are no valid mutables in your model."
)
return
self
.
_cache
def
before_build
(
self
,
model
):
pass
def
after_build
(
self
,
model
):
pass
def
named_mutables
(
self
,
model
):
# if distinct is true, the method will filter out those with duplicated keys
key2module
=
dict
()
for
name
,
module
in
model
.
named_modules
():
if
isinstance
(
module
,
PyTorchMutable
):
distinct
=
False
if
module
.
key
in
key2module
:
assert
key2module
[
module
.
key
].
similar
(
module
),
\
"Mutable
\"
{}
\"
that share the same key must be similar to each other"
.
format
(
module
.
key
)
else
:
distinct
=
True
key2module
[
module
.
key
]
=
module
yield
name
,
module
,
distinct
def
__setattr__
(
self
,
key
,
value
):
if
key
in
[
"model"
,
"net"
,
"network"
]:
logger
.
warning
(
"Think twice if you are including the network into mutator."
)
return
super
().
__setattr__
(
key
,
value
)
def
parse_search_space
(
self
,
model
):
for
name
,
mutable
,
distinct
in
self
.
named_mutables
(
model
):
mutable
.
name
=
name
mutable
.
set_mutator
(
self
)
if
not
distinct
:
continue
init_method_name
=
"on_init_{}"
.
format
(
to_snake_case
(
mutable
.
__class__
.
__name__
))
if
hasattr
(
self
,
init_method_name
)
and
callable
(
getattr
(
self
,
init_method_name
)):
getattr
(
self
,
init_method_name
)(
mutable
)
else
:
# fallback to general init
self
.
on_init_general
(
mutable
)
def
on_init_general
(
self
,
mutable
):
pass
@
contextmanager
@
contextmanager
def
forward_pass
(
self
):
def
forward_pass
(
self
):
self
.
_in_forward_pass
=
True
self
.
_cache
=
dict
()
self
.
before_pass
()
self
.
before_pass
()
try
:
try
:
yield
self
yield
self
finally
:
finally
:
self
.
after_pass
()
self
.
after_pass
()
def
before_pass
(
self
):
self
.
_in_forward_pass
=
True
self
.
_cache
=
dict
()
def
after_pass
(
self
):
self
.
_in_forward_pass
=
False
self
.
_in_forward_pass
=
False
def
enter_mutable_scope
(
self
,
mutable_scope
):
def
before_pass
(
self
):
pass
pass
def
exit_mutable_scope
(
self
,
mutable_scope
):
def
after_pass
(
self
):
pass
pass
def
forward
(
self
,
*
inputs
):
def
_check_in_forward_pass
(
self
):
raise
NotImplementedError
(
"Mutator is not forward-able"
)
def
on_forward
(
self
,
mutable
,
*
inputs
):
"""Callback on forwarding a mutable"""
if
not
hasattr
(
self
,
"_in_forward_pass"
)
or
not
self
.
_in_forward_pass
:
if
not
hasattr
(
self
,
"_in_forward_pass"
)
or
not
self
.
_in_forward_pass
:
raise
ValueError
(
"Not in forward pass. Did you forget to call mutator.forward_pass(), or forget to call "
raise
ValueError
(
"Not in forward pass. Did you forget to call mutator.forward_pass(), or forget to call "
"super().before_pass() and after_pass() in your override method?"
)
"super().before_pass() and after_pass() in your override method?"
)
forward_method_name
=
"on_forward_{}"
.
format
(
to_snake_case
(
mutable
.
__class__
.
__name__
))
if
hasattr
(
self
,
forward_method_name
)
and
callable
(
getattr
(
self
,
forward_method_name
)):
return
getattr
(
self
,
forward_method_name
)(
mutable
,
*
inputs
)
else
:
# fallback to general forward
return
self
.
on_forward_general
(
mutable
,
*
inputs
)
def
on_forward_general
(
self
,
mutable
,
*
inputs
):
raise
NotImplementedError
(
"Forward has to be implemented"
)
def
on_forward_layer_choice
(
self
,
mutable
,
*
inputs
):
def
on_forward_layer_choice
(
self
,
mutable
,
*
inputs
):
"""
"""
Callback of layer choice forward. Override if you are an advanced user.
Callback of layer choice forward. Override if you are an advanced user.
On default, this method calls :meth:`on_calc_layer_choice_mask` to get a mask on how to choose between layers
On default, this method calls :meth:`on_calc_layer_choice_mask` to get a mask on how to choose between layers
(either by switch or by weights), then it will reduce the list of all tensor outputs with the policy spe
i
cified
(either by switch or by weights), then it will reduce the list of all tensor outputs with the policy specified
in `mutable.reduction`. It will also cache the mask with corresponding `mutable.key`.
in `mutable.reduction`. It will also cache the mask with corresponding `mutable.key`.
Parameters
Parameters
...
@@ -111,33 +52,38 @@ class PyTorchMutator(nn.Module):
...
@@ -111,33 +52,38 @@ class PyTorchMutator(nn.Module):
Returns
Returns
-------
-------
torch.Tensor
tuple of torch.Tensor and
torch.Tensor
"""
"""
self
.
_check_in_forward_pass
()
def
_map_fn
(
op
,
*
inputs
):
def
_map_fn
(
op
,
*
inputs
):
return
op
(
*
inputs
)
return
op
(
*
inputs
)
mask
=
self
.
_cache
.
setdefault
(
mutable
.
key
,
self
.
on_calc_layer_choice_mask
(
mutable
))
mask
=
self
.
_cache
.
setdefault
(
mutable
.
key
,
self
.
on_calc_layer_choice_mask
(
mutable
))
out
=
self
.
_select_with_mask
(
_map_fn
,
[(
choice
,
*
inputs
)
for
choice
in
mutable
.
choices
],
mask
)
out
=
self
.
_select_with_mask
(
_map_fn
,
[(
choice
,
*
inputs
)
for
choice
in
mutable
.
choices
],
mask
)
return
self
.
_tensor_reduction
(
mutable
.
reduction
,
out
),
mask
return
self
.
_tensor_reduction
(
mutable
.
reduction
,
out
),
mask
def
on_forward_input_choice
(
self
,
mutable
,
tensor_list
,
semantic_label
s
):
def
on_forward_input_choice
(
self
,
mutable
,
tensor_list
,
tag
s
):
"""
"""
Callback of input choice forward. Override if you are an advanced user.
Callback of input choice forward. Override if you are an advanced user.
On default, this method calls :meth:`on_calc_input_choice_mask` with `
semantic_label
s`
On default, this method calls :meth:`on_calc_input_choice_mask` with `
tag
s`
to get a mask on how to choose between inputs (either by switch or by weights), then it will reduce
to get a mask on how to choose between inputs (either by switch or by weights), then it will reduce
the list of all tensor outputs with the policy spe
i
cified in `mutable.reduction`. It will also cache the
the list of all tensor outputs with the policy specified in `mutable.reduction`. It will also cache the
mask with corresponding `mutable.key`.
mask with corresponding `mutable.key`.
Parameters
Parameters
----------
----------
mutable: InputChoice
mutable: InputChoice
inputs: list of torch.Tensor
tensor_list: list of torch.Tensor
tags: list of string
Returns
Returns
-------
-------
torch.Tensor
tuple of torch.Tensor and
torch.Tensor
"""
"""
mask
=
self
.
_cache
.
setdefault
(
mutable
.
key
,
self
.
on_calc_input_choice_mask
(
mutable
,
semantic_labels
))
self
.
_check_in_forward_pass
()
out
=
self
.
_select_with_mask
(
lambda
x
:
x
,
[(
t
,
)
for
t
in
tensor_list
],
mask
)
mask
=
self
.
_cache
.
setdefault
(
mutable
.
key
,
self
.
on_calc_input_choice_mask
(
mutable
,
tags
))
out
=
self
.
_select_with_mask
(
lambda
x
:
x
,
[(
t
,)
for
t
in
tensor_list
],
mask
)
return
self
.
_tensor_reduction
(
mutable
.
reduction
,
out
),
mask
return
self
.
_tensor_reduction
(
mutable
.
reduction
,
out
),
mask
def
on_calc_layer_choice_mask
(
self
,
mutable
):
def
on_calc_layer_choice_mask
(
self
,
mutable
):
...
@@ -157,7 +103,7 @@ class PyTorchMutator(nn.Module):
...
@@ -157,7 +103,7 @@ class PyTorchMutator(nn.Module):
"""
"""
raise
NotImplementedError
(
"Layer choice mask calculation must be implemented"
)
raise
NotImplementedError
(
"Layer choice mask calculation must be implemented"
)
def
on_calc_input_choice_mask
(
self
,
mutable
,
semantic_label
s
):
def
on_calc_input_choice_mask
(
self
,
mutable
,
tag
s
):
"""
"""
Recommended to override. Calculate a mask tensor for a input choice.
Recommended to override. Calculate a mask tensor for a input choice.
...
@@ -165,7 +111,7 @@ class PyTorchMutator(nn.Module):
...
@@ -165,7 +111,7 @@ class PyTorchMutator(nn.Module):
----------
----------
mutable: InputChoice
mutable: InputChoice
Corresponding input choice object.
Corresponding input choice object.
semantic_label
s: list of string
tag
s: list of string
The name of labels of input tensors given by user. Usually it's a
The name of labels of input tensors given by user. Usually it's a
:class:`~nni.nas.pytorch.mutables.MutableScope` marked by user.
:class:`~nni.nas.pytorch.mutables.MutableScope` marked by user.
...
@@ -179,7 +125,6 @@ class PyTorchMutator(nn.Module):
...
@@ -179,7 +125,6 @@ class PyTorchMutator(nn.Module):
def
_select_with_mask
(
self
,
map_fn
,
candidates
,
mask
):
def
_select_with_mask
(
self
,
map_fn
,
candidates
,
mask
):
if
"BoolTensor"
in
mask
.
type
():
if
"BoolTensor"
in
mask
.
type
():
# print(candidates[0], len(mask))
out
=
[
map_fn
(
*
cand
)
for
cand
,
m
in
zip
(
candidates
,
mask
)
if
m
]
out
=
[
map_fn
(
*
cand
)
for
cand
,
m
in
zip
(
candidates
,
mask
)
if
m
]
elif
"FloatTensor"
in
mask
.
type
():
elif
"FloatTensor"
in
mask
.
type
():
out
=
[
map_fn
(
*
cand
)
*
m
for
cand
,
m
in
zip
(
candidates
,
mask
)]
out
=
[
map_fn
(
*
cand
)
*
m
for
cand
,
m
in
zip
(
candidates
,
mask
)]
...
...
src/sdk/pynni/nni/nas/pytorch/pdarts/trainer.py
View file @
1cada380
...
@@ -33,13 +33,13 @@ class PdartsTrainer(Trainer):
...
@@ -33,13 +33,13 @@ class PdartsTrainer(Trainer):
for
epoch
in
range
(
self
.
pdarts_epoch
):
for
epoch
in
range
(
self
.
pdarts_epoch
):
layers
=
self
.
layers
+
self
.
pdarts_num_layers
[
epoch
]
layers
=
self
.
layers
+
self
.
pdarts_num_layers
[
epoch
]
model
,
loss
,
model_optim
,
lr_scheduler
=
self
.
model_creator
(
model
,
loss
,
model_optim
,
_
=
self
.
model_creator
(
layers
,
n_nodes
)
layers
,
n_nodes
)
mutator
=
PdartsMutator
(
mutator
=
PdartsMutator
(
model
,
epoch
,
self
.
pdarts_num_to_drop
,
switches
)
model
,
epoch
,
self
.
pdarts_num_to_drop
,
switches
)
self
.
trainer
=
DartsTrainer
(
model
,
loss
=
loss
,
model_
optim
=
model_optim
,
self
.
trainer
=
DartsTrainer
(
model
,
loss
=
loss
,
optim
izer
=
model_optim
,
lr_scheduler
=
lr_scheduler
,
mutator
=
mutator
,
**
self
.
darts_parameters
)
mutator
=
mutator
,
**
self
.
darts_parameters
)
print
(
"start pdrats training %s..."
%
epoch
)
print
(
"start pdrats training %s..."
%
epoch
)
self
.
trainer
.
train
()
self
.
trainer
.
train
()
...
...
src/sdk/pynni/nni/nas/pytorch/trainer.py
View file @
1cada380
from
abc
import
ABC
,
abstractmethod
from
abc
import
abstractmethod
import
torch
class
Trainer
(
ABC
):
from
.base_trainer
import
BaseTrainer
class
Trainer
(
BaseTrainer
):
def
__init__
(
self
,
model
,
loss
,
metrics
,
optimizer
,
num_epochs
,
dataset_train
,
dataset_valid
,
batch_size
,
workers
,
device
,
log_frequency
,
mutator
,
callbacks
):
self
.
device
=
torch
.
device
(
"cuda"
if
torch
.
cuda
.
is_available
()
else
"cpu"
)
if
device
is
None
else
device
self
.
model
=
model
self
.
loss
=
loss
self
.
metrics
=
metrics
self
.
optimizer
=
optimizer
self
.
mutator
=
mutator
self
.
model
.
to
(
self
.
device
)
self
.
loss
.
to
(
self
.
device
)
self
.
mutator
.
to
(
self
.
device
)
self
.
num_epochs
=
num_epochs
self
.
dataset_train
=
dataset_train
self
.
dataset_valid
=
dataset_valid
self
.
batch_size
=
batch_size
self
.
workers
=
workers
self
.
log_frequency
=
log_frequency
self
.
callbacks
=
callbacks
if
callbacks
is
not
None
else
[]
for
callback
in
self
.
callbacks
:
callback
.
build
(
self
.
model
,
self
.
mutator
,
self
)
@
abstractmethod
@
abstractmethod
def
train
(
self
):
def
train
_one_epoch
(
self
,
epoch
):
raise
NotImplementedError
pass
@
abstractmethod
@
abstractmethod
def
export
(
self
):
def
validate_one_epoch
(
self
,
epoch
):
raise
NotImplementedError
pass
def
_train
(
self
,
validate
):
for
epoch
in
range
(
self
.
num_epochs
):
for
callback
in
self
.
callbacks
:
callback
.
on_epoch_begin
(
epoch
)
# training
print
(
"Epoch {} Training"
.
format
(
epoch
))
self
.
train_one_epoch
(
epoch
)
if
validate
:
# validation
print
(
"Epoch {} Validating"
.
format
(
epoch
))
self
.
validate_one_epoch
(
epoch
)
for
callback
in
self
.
callbacks
:
callback
.
on_epoch_end
(
epoch
)
def
train_and_validate
(
self
):
self
.
_train
(
True
)
def
train
(
self
):
self
.
_train
(
False
)
def
validate
(
self
):
self
.
validate_one_epoch
(
-
1
)
src/sdk/pynni/nni/nas/utils.py
View file @
1cada380
import
re
from
collections
import
OrderedDict
from
collections
import
OrderedDict
import
torch
_counter
=
0
_counter
=
0
...
@@ -12,14 +9,6 @@ def global_mutable_counting():
...
@@ -12,14 +9,6 @@ def global_mutable_counting():
return
_counter
return
_counter
def
to_snake_case
(
camel_case
):
return
re
.
sub
(
'(?!^)([A-Z]+)'
,
r
'_\1'
,
camel_case
).
lower
()
def
auto_device
():
return
torch
.
device
(
"cuda"
if
torch
.
cuda
.
is_available
()
else
"cpu"
)
class
AverageMeterGroup
(
object
):
class
AverageMeterGroup
(
object
):
def
__init__
(
self
):
def
__init__
(
self
):
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment