OpenDAS / nni · Commit 70cee7d8 (unverified)

TextNAS without retrain (#1890)

Authored Jan 06, 2020 by Yuge Zhang; committed by GitHub on Jan 06, 2020.
Parent: eb39749f

Showing 11 changed files with 957 additions and 40 deletions.
examples/nas/textnas/README.md                   +45   −0
examples/nas/textnas/dataloader.py               +334  −0
examples/nas/textnas/model.py                    +108  −0
examples/nas/textnas/ops.py                      +205  −0
examples/nas/textnas/search.py                   +89   −0
examples/nas/textnas/utils.py                    +67   −0
src/sdk/pynni/nni/nas/pytorch/enas/mutator.py    +16   −5
src/sdk/pynni/nni/nas/pytorch/enas/trainer.py    +65   −31
src/sdk/pynni/nni/nas/pytorch/mutables.py        +1    −1
src/sdk/pynni/nni/nas/pytorch/trainer.py         +2    −2
src/sdk/pynni/nni/nas/pytorch/utils.py           +25   −1
examples/nas/textnas/README.md (new file, mode 100644)

# TextNAS: A Neural Architecture Search Space tailored for Text Representation

TextNAS by MSRA. Official release. [Paper link](https://arxiv.org/abs/1912.10729)

## Preparation

Prepare the word vectors and the SST dataset, and organize them in the `data` directory as shown below:

```
textnas
├── data
│   ├── sst
│   │   └── trees
│   │       ├── dev.txt
│   │       ├── test.txt
│   │       └── train.txt
│   └── glove.840B.300d.txt
├── dataloader.py
├── model.py
├── ops.py
├── README.md
├── search.py
└── utils.py
```

The following links might be helpful for finding and downloading the corresponding datasets:

* [GloVe: Global Vectors for Word Representation](https://nlp.stanford.edu/projects/glove/)
* [Recursive Deep Models for Semantic Compositionality Over a Sentiment Treebank](https://nlp.stanford.edu/sentiment/)

## Search

```
python search.py
```

After each search epoch, 10 sampled architectures are tested directly; their accuracy is expected to reach 40%–42% after 10 epochs. By default, 20 sampled architectures are exported into the `checkpoints` directory for the next step.

## Retrain

Not ready.
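Although the retrain step is not shipped in this commit, here is a minimal sketch of how an exported architecture could be reused. It assumes NNI's fixed-architecture helper `nni.nas.pytorch.fixed.apply_fixed_architecture` is available, as in other NNI NAS examples of this era; the helper name and the checkpoint filename are assumptions, not part of this commit.

```python
# Hypothetical retrain sketch -- not part of this commit.
# Assumes nni.nas.pytorch.fixed.apply_fixed_architecture exists.
from nni.nas.pytorch.fixed import apply_fixed_architecture

from dataloader import read_data_sst
from model import Model

train_dataset, valid_dataset, test_dataset, embedding = read_data_sst("data")
model = Model(embedding)
# Freeze the search space to one of the architectures exported by search.py.
apply_fixed_architecture(model, "checkpoints/architecture_00.json")
# `model` can now be trained like an ordinary PyTorch module.
```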
examples/nas/textnas/dataloader.py (new file, mode 100644)

```python
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import logging
import os
import pickle
from collections import Counter

import numpy as np
import torch
from torch.utils import data

logger = logging.getLogger("nni.textnas")


class PTBTree:
    WORD_TO_WORD_MAPPING = {
        "{": "-LCB-",
        "}": "-RCB-"
    }

    def __init__(self):
        self.subtrees = []
        self.word = None
        self.label = ""
        self.parent = None
        self.span = (-1, -1)
        self.word_vector = None  # HOS, store dx1 RNN word vector
        self.prediction = None  # HOS, store Kx1 prediction vector

    def is_leaf(self):
        return len(self.subtrees) == 0

    def set_by_text(self, text, pos=0, left=0):
        depth = 0
        right = left
        for i in range(pos + 1, len(text)):
            char = text[i]
            # update the depth
            if char == "(":
                depth += 1
                if depth == 1:
                    subtree = PTBTree()
                    subtree.parent = self
                    subtree.set_by_text(text, i, right)
                    right = subtree.span[1]
                    self.span = (left, right)
                    self.subtrees.append(subtree)
            elif char == ")":
                depth -= 1
                if len(self.subtrees) == 0:
                    pos = i
                    for j in range(i, 0, -1):
                        if text[j] == " ":
                            pos = j
                            break
                    self.word = text[pos + 1:i]
                    self.span = (left, left + 1)

            # we've reached the end of the category that is the root of this subtree
            if depth == 0 and char == " " and self.label == "":
                self.label = text[pos + 1:i]
            # we've reached the end of the scope for this bracket
            if depth < 0:
                break

        # Fix some issues with variation in output, and one error in the treebank
        # for a word with a punctuation POS
        self.standardise_node()

    def standardise_node(self):
        if self.word in self.WORD_TO_WORD_MAPPING:
            self.word = self.WORD_TO_WORD_MAPPING[self.word]

    def __repr__(self, single_line=True, depth=0):
        ans = ""
        if not single_line and depth > 0:
            ans = "\n" + depth * "\t"
        ans += "(" + self.label
        if self.word is not None:
            ans += " " + self.word
        for subtree in self.subtrees:
            if single_line:
                ans += " "
            ans += subtree.__repr__(single_line, depth + 1)
        ans += ")"
        return ans


def read_tree(source):
    cur_text = []
    depth = 0
    while True:
        line = source.readline()
        # Check if we are out of input
        if line == "":
            return None
        # strip whitespace and only use if this contains something
        line = line.strip()
        if line == "":
            continue
        cur_text.append(line)
        # Update depth
        for char in line:
            if char == "(":
                depth += 1
            elif char == ")":
                depth -= 1
        # At depth 0 we have a complete tree
        if depth == 0:
            tree = PTBTree()
            tree.set_by_text(" ".join(cur_text))
            return tree
    return None


def read_trees(source, max_sents=-1):
    with open(source) as fp:
        trees = []
        while True:
            tree = read_tree(fp)
            if tree is None:
                break
            trees.append(tree)
            if len(trees) >= max_sents > 0:
                break
        return trees


class SSTDataset(data.Dataset):
    def __init__(self, sents, mask, labels):
        self.sents = sents
        self.labels = labels
        self.mask = mask

    def __getitem__(self, index):
        return (self.sents[index], self.mask[index]), self.labels[index]

    def __len__(self):
        return len(self.sents)


def sst_get_id_input(content, word_id_dict, max_input_length):
    words = content.split(" ")
    sentence = [word_id_dict["<pad>"]] * max_input_length
    mask = [0] * max_input_length
    unknown = word_id_dict["<unknown>"]
    for i, word in enumerate(words[:max_input_length]):
        sentence[i] = word_id_dict.get(word, unknown)
        mask[i] = 1
    return sentence, mask


def sst_get_phrases(trees, sample_ratio=1.0, is_binary=False, only_sentence=False):
    all_phrases = []
    for tree in trees:
        if only_sentence:
            sentence = get_sentence_by_tree(tree)
            label = int(tree.label)
            pair = (sentence, label)
            all_phrases.append(pair)
        else:
            phrases = get_phrases_by_tree(tree)
            sentence = get_sentence_by_tree(tree)
            pair = (sentence, int(tree.label))
            all_phrases.append(pair)
            all_phrases += phrases
    if sample_ratio < 1.:
        np.random.shuffle(all_phrases)
    result_phrases = []
    for pair in all_phrases:
        if is_binary:
            phrase, label = pair
            if label <= 1:
                pair = (phrase, 0)
            elif label >= 3:
                pair = (phrase, 1)
            else:
                continue
        if sample_ratio == 1.:
            result_phrases.append(pair)
        else:
            rand_portion = np.random.random()
            if rand_portion < sample_ratio:
                result_phrases.append(pair)
    return result_phrases


def get_phrases_by_tree(tree):
    phrases = []
    if tree is None:
        return phrases
    if tree.is_leaf():
        pair = (tree.word, int(tree.label))
        phrases.append(pair)
        return phrases
    left_child_phrases = get_phrases_by_tree(tree.subtrees[0])
    right_child_phrases = get_phrases_by_tree(tree.subtrees[1])
    phrases.extend(left_child_phrases)
    phrases.extend(right_child_phrases)
    sentence = get_sentence_by_tree(tree)
    pair = (sentence, int(tree.label))
    phrases.append(pair)
    return phrases


def get_sentence_by_tree(tree):
    if tree is None:
        return ""
    if tree.is_leaf():
        return tree.word
    left_sentence = get_sentence_by_tree(tree.subtrees[0])
    right_sentence = get_sentence_by_tree(tree.subtrees[1])
    sentence = left_sentence + " " + right_sentence
    return sentence.strip()


def get_word_id_dict(word_num_dict, word_id_dict, min_count):
    z = [k for k in sorted(word_num_dict.keys())]
    for word in z:
        count = word_num_dict[word]
        if count >= min_count:
            index = len(word_id_dict)
            if word not in word_id_dict:
                word_id_dict[word] = index
    return word_id_dict


def load_word_num_dict(phrases, word_num_dict):
    for sentence, _ in phrases:
        words = sentence.split(" ")
        for cur_word in words:
            word = cur_word.strip()
            word_num_dict[word] += 1
    return word_num_dict


def init_trainable_embedding(embedding_path, word_id_dict, embed_dim=300):
    word_embed_model = load_glove_model(embedding_path, embed_dim)
    assert word_embed_model["pool"].shape[1] == embed_dim
    embedding = np.random.random([len(word_id_dict), embed_dim]).astype(np.float32) / 2.0 - 0.25
    embedding[0] = np.zeros(embed_dim)  # PAD
    embedding[1] = (np.random.rand(embed_dim) - 0.5) / 2  # UNK
    for word, idx in word_id_dict.items():
        if idx == 0 or idx == 1:
            continue
        if word in word_embed_model["mapping"]:
            embedding[idx] = word_embed_model["pool"][word_embed_model["mapping"][word]]
        else:
            embedding[idx] = np.random.rand(embed_dim) / 2.0 - 0.25
    return embedding


def sst_get_trainable_data(phrases, word_id_dict, max_input_length):
    texts, labels, mask = [], [], []

    for phrase, label in phrases:
        if not phrase.split():
            continue
        phrase_split, mask_split = sst_get_id_input(phrase, word_id_dict, max_input_length)
        texts.append(phrase_split)
        labels.append(int(label))
        mask.append(mask_split)
    # field_input is mask
    labels = np.array(labels, dtype=np.int64)
    texts = np.reshape(texts, [-1, max_input_length]).astype(np.int32)
    mask = np.reshape(mask, [-1, max_input_length]).astype(np.int32)

    return SSTDataset(texts, mask, labels)


def load_glove_model(filename, embed_dim):
    if os.path.exists(filename + ".cache"):
        logger.info("Found cache. Loading...")
        with open(filename + ".cache", "rb") as fp:
            return pickle.load(fp)
    embedding = {"mapping": dict(), "pool": []}
    with open(filename) as f:
        for i, line in enumerate(f):
            line = line.rstrip("\n")
            vocab_word, *vec = line.rsplit(" ", maxsplit=embed_dim)
            assert len(vec) == 300, "Unexpected line: '%s'" % line
            embedding["pool"].append(np.array(list(map(float, vec)), dtype=np.float32))
            embedding["mapping"][vocab_word] = i
    embedding["pool"] = np.stack(embedding["pool"])
    with open(filename + ".cache", "wb") as fp:
        pickle.dump(embedding, fp)
    return embedding


def read_data_sst(data_path, max_input_length=64, min_count=1, train_with_valid=False,
                  train_ratio=1., valid_ratio=1., is_binary=False, only_sentence=False):
    word_id_dict = dict()
    word_num_dict = Counter()

    sst_path = os.path.join(data_path, "sst")
    logger.info("Reading SST data...")

    train_file_name = os.path.join(sst_path, "trees", "train.txt")
    valid_file_name = os.path.join(sst_path, "trees", "dev.txt")
    test_file_name = os.path.join(sst_path, "trees", "test.txt")
    train_trees = read_trees(train_file_name)
    train_phrases = sst_get_phrases(train_trees, train_ratio, is_binary, only_sentence)
    logger.info("Finish load train phrases.")
    valid_trees = read_trees(valid_file_name)
    valid_phrases = sst_get_phrases(valid_trees, valid_ratio, is_binary, only_sentence)
    logger.info("Finish load valid phrases.")
    if train_with_valid:
        train_phrases += valid_phrases
    test_trees = read_trees(test_file_name)
    test_phrases = sst_get_phrases(test_trees, valid_ratio, is_binary, only_sentence=True)
    logger.info("Finish load test phrases.")

    # get word_id_dict
    word_id_dict["<pad>"] = 0
    word_id_dict["<unknown>"] = 1
    load_word_num_dict(train_phrases, word_num_dict)
    logger.info("Finish load train words: %d.", len(word_num_dict))
    load_word_num_dict(valid_phrases, word_num_dict)
    load_word_num_dict(test_phrases, word_num_dict)
    logger.info("Finish load valid+test words: %d.", len(word_num_dict))
    word_id_dict = get_word_id_dict(word_num_dict, word_id_dict, min_count)
    logger.info("After trim vocab length: %d.", len(word_id_dict))

    logger.info("Loading embedding...")
    embedding = init_trainable_embedding(os.path.join(data_path, "glove.840B.300d.txt"), word_id_dict)
    logger.info("Finish initialize word embedding.")

    dataset_train = sst_get_trainable_data(train_phrases, word_id_dict, max_input_length)
    logger.info("Loaded %d training samples.", len(dataset_train))
    dataset_valid = sst_get_trainable_data(valid_phrases, word_id_dict, max_input_length)
    logger.info("Loaded %d validation samples.", len(dataset_valid))
    dataset_test = sst_get_trainable_data(test_phrases, word_id_dict, max_input_length)
    logger.info("Loaded %d test samples.", len(dataset_test))
    return dataset_train, dataset_valid, dataset_test, torch.from_numpy(embedding)
```
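For orientation, a minimal sketch of how the loader above is consumed (mirroring what `search.py` does later in this commit); the batch size here is an arbitrary illustrative value:

```python
# Minimal usage sketch for dataloader.read_data_sst (illustrative values only).
import torch
from dataloader import read_data_sst

train_set, valid_set, test_set, embedding = read_data_sst("data", max_input_length=64)
loader = torch.utils.data.DataLoader(train_set, batch_size=128, shuffle=True)
(sents, mask), labels = next(iter(loader))
# sents, mask: (batch, 64) integer tensors; labels: (batch,) int64 sentiment labels.
```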
examples/nas/textnas/model.py (new file, mode 100644)

```python
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import numpy as np
import torch
import torch.nn as nn

from nni.nas.pytorch import mutables
from ops import ConvBN, LinearCombine, AvgPool, MaxPool, RNN, Attention, BatchNorm
from utils import GlobalMaxPool, GlobalAvgPool


class Layer(mutables.MutableScope):
    def __init__(self, key, prev_keys, hidden_units, choose_from_k, cnn_keep_prob,
                 lstm_keep_prob, att_keep_prob, att_mask):
        super(Layer, self).__init__(key)

        def conv_shortcut(kernel_size):
            return ConvBN(kernel_size, hidden_units, hidden_units, cnn_keep_prob, False, True)

        self.n_candidates = len(prev_keys)
        if self.n_candidates:
            self.prec = mutables.InputChoice(choose_from=prev_keys[-choose_from_k:], n_chosen=1)
        else:
            # first layer, skip input choice
            self.prec = None
        self.op = mutables.LayerChoice([
            conv_shortcut(1),
            conv_shortcut(3),
            conv_shortcut(5),
            conv_shortcut(7),
            AvgPool(3, False, True),
            MaxPool(3, False, True),
            RNN(hidden_units, lstm_keep_prob),
            Attention(hidden_units, 4, att_keep_prob, att_mask)
        ])
        if self.n_candidates:
            self.skipconnect = mutables.InputChoice(choose_from=prev_keys)
        else:
            self.skipconnect = None
        self.bn = BatchNorm(hidden_units, False, True)

    def forward(self, last_layer, prev_layers, mask):
        # pass an extra last_layer to deal with layer 0 (prev_layers is empty)
        if self.prec is None:
            prec = last_layer
        else:
            prec = self.prec(prev_layers[-self.prec.n_candidates:])  # skip first
        out = self.op(prec, mask)
        if self.skipconnect is not None:
            connection = self.skipconnect(prev_layers[-self.skipconnect.n_candidates:])
            if connection is not None:
                out += connection
        out = self.bn(out, mask)
        return out


class Model(nn.Module):
    def __init__(self, embedding, hidden_units=256, num_layers=24, num_classes=5, choose_from_k=5,
                 lstm_keep_prob=0.5, cnn_keep_prob=0.5, att_keep_prob=0.5, att_mask=True,
                 embed_keep_prob=0.5, final_output_keep_prob=1.0, global_pool="avg"):
        super(Model, self).__init__()

        self.embedding = nn.Embedding.from_pretrained(embedding, freeze=False)
        self.hidden_units = hidden_units
        self.num_layers = num_layers
        self.num_classes = num_classes

        self.init_conv = ConvBN(1, self.embedding.embedding_dim, hidden_units, cnn_keep_prob, False, True)
        self.layers = nn.ModuleList()
        candidate_keys_pool = []
        for layer_id in range(self.num_layers):
            k = "layer_{}".format(layer_id)
            self.layers.append(Layer(k, candidate_keys_pool, hidden_units, choose_from_k,
                                     cnn_keep_prob, lstm_keep_prob, att_keep_prob, att_mask))
            candidate_keys_pool.append(k)
        self.linear_combine = LinearCombine(self.num_layers)
        self.linear_out = nn.Linear(self.hidden_units, self.num_classes)

        self.embed_dropout = nn.Dropout(p=1 - embed_keep_prob)
        self.output_dropout = nn.Dropout(p=1 - final_output_keep_prob)

        assert global_pool in ["max", "avg"]
        if global_pool == "max":
            self.global_pool = GlobalMaxPool()
        elif global_pool == "avg":
            self.global_pool = GlobalAvgPool()

    def forward(self, inputs):
        sent_ids, mask = inputs
        seq = self.embedding(sent_ids.long())
        seq = self.embed_dropout(seq)

        seq = torch.transpose(seq, 1, 2)  # from (N, L, C) -> (N, C, L)
        x = self.init_conv(seq, mask)

        prev_layers = []
        for layer in self.layers:
            x = layer(x, prev_layers, mask)
            prev_layers.append(x)

        x = self.linear_combine(torch.stack(prev_layers))
        x = self.global_pool(x, mask)
        x = self.output_dropout(x)
        x = self.linear_out(x)
        return x
```
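A small sketch of exercising the super-network above: the `LayerChoice`/`InputChoice` mutables can only be resolved through a mutator, so the forward pass is driven the same way the trainer does it (reset the mutator, then call the model). Shapes and sizes below are illustrative only, and this assumes running from the `textnas` example directory.

```python
# Illustrative forward pass through the TextNAS super-network (random weights, dummy data).
import torch
from nni.nas.pytorch.enas import EnasMutator
from model import Model

embedding = torch.randn(1000, 300)      # dummy 1000-word, 300-d embedding table
model = Model(embedding, num_layers=4)  # fewer layers than the default 24, for a quick check
mutator = EnasMutator(model)            # the mutables above need a mutator to resolve choices
mutator.reset()                         # sample one architecture
sent_ids = torch.randint(0, 1000, (2, 16))
mask = torch.ones(2, 16, dtype=torch.long)
logits = model((sent_ids, mask))        # (2, num_classes)
```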
examples/nas/textnas/ops.py (new file, mode 100644)

```python
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import torch
import torch.nn.functional as F
from torch import nn

from utils import get_length, INF


class Mask(nn.Module):
    def forward(self, seq, mask):
        # seq: (N, C, L)
        # mask: (N, L)
        seq_mask = torch.unsqueeze(mask, 2)
        seq_mask = torch.transpose(seq_mask.repeat(1, 1, seq.size()[1]), 1, 2)
        return seq.where(torch.eq(seq_mask, 1), torch.zeros_like(seq))


class BatchNorm(nn.Module):
    def __init__(self, num_features, pre_mask, post_mask, eps=1e-5, decay=0.9, affine=True):
        super(BatchNorm, self).__init__()
        self.mask_opt = Mask()
        self.pre_mask = pre_mask
        self.post_mask = post_mask
        self.bn = nn.BatchNorm1d(num_features, eps=eps, momentum=1.0 - decay, affine=affine)

    def forward(self, seq, mask):
        if self.pre_mask:
            seq = self.mask_opt(seq, mask)
        seq = self.bn(seq)
        if self.post_mask:
            seq = self.mask_opt(seq, mask)
        return seq


class ConvBN(nn.Module):
    def __init__(self, kernal_size, in_channels, out_channels, cnn_keep_prob,
                 pre_mask, post_mask, with_bn=True, with_relu=True):
        super(ConvBN, self).__init__()
        self.mask_opt = Mask()
        self.pre_mask = pre_mask
        self.post_mask = post_mask
        self.with_bn = with_bn
        self.with_relu = with_relu
        self.conv = nn.Conv1d(in_channels, out_channels, kernal_size, 1, bias=True,
                              padding=(kernal_size - 1) // 2)
        self.dropout = nn.Dropout(p=(1 - cnn_keep_prob))

        if with_bn:
            self.bn = BatchNorm(out_channels, not post_mask, True)

        if with_relu:
            self.relu = nn.ReLU()

    def forward(self, seq, mask):
        if self.pre_mask:
            seq = self.mask_opt(seq, mask)
        seq = self.conv(seq)
        if self.post_mask:
            seq = self.mask_opt(seq, mask)
        if self.with_bn:
            seq = self.bn(seq, mask)
        if self.with_relu:
            seq = self.relu(seq)
        seq = self.dropout(seq)
        return seq


class AvgPool(nn.Module):
    def __init__(self, kernal_size, pre_mask, post_mask):
        super(AvgPool, self).__init__()
        self.avg_pool = nn.AvgPool1d(kernal_size, 1, padding=(kernal_size - 1) // 2)
        self.pre_mask = pre_mask
        self.post_mask = post_mask
        self.mask_opt = Mask()

    def forward(self, seq, mask):
        if self.pre_mask:
            seq = self.mask_opt(seq, mask)
        seq = self.avg_pool(seq)
        if self.post_mask:
            seq = self.mask_opt(seq, mask)
        return seq


class MaxPool(nn.Module):
    def __init__(self, kernal_size, pre_mask, post_mask):
        super(MaxPool, self).__init__()
        self.max_pool = nn.MaxPool1d(kernal_size, 1, padding=(kernal_size - 1) // 2)
        self.pre_mask = pre_mask
        self.post_mask = post_mask
        self.mask_opt = Mask()

    def forward(self, seq, mask):
        if self.pre_mask:
            seq = self.mask_opt(seq, mask)
        seq = self.max_pool(seq)
        if self.post_mask:
            seq = self.mask_opt(seq, mask)
        return seq


class Attention(nn.Module):
    def __init__(self, num_units, num_heads, keep_prob, is_mask):
        super(Attention, self).__init__()
        self.num_heads = num_heads
        self.keep_prob = keep_prob

        self.linear_q = nn.Linear(num_units, num_units)
        self.linear_k = nn.Linear(num_units, num_units)
        self.linear_v = nn.Linear(num_units, num_units)

        self.bn = BatchNorm(num_units, True, is_mask)
        self.dropout = nn.Dropout(p=1 - self.keep_prob)

    def forward(self, seq, mask):
        in_c = seq.size()[1]
        seq = torch.transpose(seq, 1, 2)  # (N, L, C)
        queries = seq
        keys = seq
        num_heads = self.num_heads

        # T_q = T_k = L
        Q = F.relu(self.linear_q(seq))  # (N, T_q, C)
        K = F.relu(self.linear_k(seq))  # (N, T_k, C)
        V = F.relu(self.linear_v(seq))  # (N, T_k, C)

        # Split and concat
        Q_ = torch.cat(torch.split(Q, in_c // num_heads, dim=2), dim=0)  # (h*N, T_q, C/h)
        K_ = torch.cat(torch.split(K, in_c // num_heads, dim=2), dim=0)  # (h*N, T_k, C/h)
        V_ = torch.cat(torch.split(V, in_c // num_heads, dim=2), dim=0)  # (h*N, T_k, C/h)

        # Multiplication
        outputs = torch.matmul(Q_, K_.transpose(1, 2))  # (h*N, T_q, T_k)
        # Scale
        outputs = outputs / (K_.size()[-1] ** 0.5)

        # Key Masking
        key_masks = mask.repeat(num_heads, 1)  # (h*N, T_k)
        key_masks = torch.unsqueeze(key_masks, 1)  # (h*N, 1, T_k)
        key_masks = key_masks.repeat(1, queries.size()[1], 1)  # (h*N, T_q, T_k)

        paddings = torch.ones_like(outputs) * (-INF)  # extremely small value
        outputs = torch.where(torch.eq(key_masks, 0), paddings, outputs)

        query_masks = mask.repeat(num_heads, 1)  # (h*N, T_q)
        query_masks = torch.unsqueeze(query_masks, -1)  # (h*N, T_q, 1)
        query_masks = query_masks.repeat(1, 1, keys.size()[1]).float()  # (h*N, T_q, T_k)

        att_scores = F.softmax(outputs, dim=-1) * query_masks  # (h*N, T_q, T_k)
        att_scores = self.dropout(att_scores)

        # Weighted sum
        x_outputs = torch.matmul(att_scores, V_)  # (h*N, T_q, C/h)
        # Restore shape
        x_outputs = torch.cat(
            torch.split(x_outputs, x_outputs.size()[0] // num_heads, dim=0),
            dim=2)  # (N, T_q, C)

        x = torch.transpose(x_outputs, 1, 2)  # (N, C, L)
        x = self.bn(x, mask)
        return x


class RNN(nn.Module):
    def __init__(self, hidden_size, output_keep_prob):
        super(RNN, self).__init__()
        self.hidden_size = hidden_size
        self.bid_rnn = nn.GRU(hidden_size, hidden_size, batch_first=True, bidirectional=True)
        self.output_keep_prob = output_keep_prob

        self.out_dropout = nn.Dropout(p=(1 - self.output_keep_prob))

    def forward(self, seq, mask):
        # seq: (N, C, L)
        # mask: (N, L)
        max_len = seq.size()[2]
        length = get_length(mask)
        seq = torch.transpose(seq, 1, 2)  # to (N, L, C)
        packed_seq = nn.utils.rnn.pack_padded_sequence(seq, length, batch_first=True,
                                                       enforce_sorted=False)
        outputs, _ = self.bid_rnn(packed_seq)
        outputs = nn.utils.rnn.pad_packed_sequence(outputs, batch_first=True,
                                                   total_length=max_len)[0]
        outputs = outputs.view(-1, max_len, 2, self.hidden_size).sum(2)  # (N, L, C)
        outputs = self.out_dropout(outputs)  # output dropout
        return torch.transpose(outputs, 1, 2)  # back to: (N, C, L)


class LinearCombine(nn.Module):
    def __init__(self, layers_num, trainable=True, input_aware=False, word_level=False):
        super(LinearCombine, self).__init__()
        self.input_aware = input_aware
        self.word_level = word_level

        if input_aware:
            raise NotImplementedError("Input aware is not supported.")
        self.w = nn.Parameter(torch.full((layers_num, 1, 1, 1), 1.0 / layers_num),
                              requires_grad=trainable)

    def forward(self, seq):
        nw = F.softmax(self.w, dim=0)
        seq = torch.mul(seq, nw)
        seq = torch.sum(seq, dim=0)
        return seq
```
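All of the candidate operators above share one calling convention: sequences are laid out as (N, C, L) and padding masks as (N, L). A quick shape sketch with dummy data (assumes running from the `textnas` example directory; sizes are illustrative):

```python
# Shape sketch for the mask-aware ops (dummy data).
import torch
from ops import ConvBN, Attention

seq = torch.randn(4, 256, 32)                # (N, C, L)
mask = torch.ones(4, 32, dtype=torch.long)   # (N, L), 1 marks a real token
conv = ConvBN(3, 256, 256, cnn_keep_prob=0.5, pre_mask=False, post_mask=True)
att = Attention(256, 4, keep_prob=0.5, is_mask=True)
out = att(conv(seq, mask), mask)             # still (N, C, L)
```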
examples/nas/textnas/search.py (new file, mode 100644)

```python
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import logging
import os
import random
from argparse import ArgumentParser
from itertools import cycle

import numpy as np
import torch
import torch.nn as nn
from nni.nas.pytorch.enas import EnasMutator, EnasTrainer
from nni.nas.pytorch.callbacks import LRSchedulerCallback

from dataloader import read_data_sst
from model import Model
from utils import accuracy

logger = logging.getLogger("nni.textnas")


class TextNASTrainer(EnasTrainer):
    def __init__(self, *args, train_loader=None, valid_loader=None, test_loader=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.train_loader = train_loader
        self.valid_loader = valid_loader
        self.test_loader = test_loader

    def init_dataloader(self):
        pass


if __name__ == "__main__":
    parser = ArgumentParser("textnas")
    parser.add_argument("--batch-size", default=128, type=int)
    parser.add_argument("--log-frequency", default=50, type=int)
    parser.add_argument("--seed", default=1234, type=int)
    parser.add_argument("--epochs", default=10, type=int)
    parser.add_argument("--lr", default=5e-3, type=float)
    args = parser.parse_args()

    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    np.random.seed(args.seed)
    random.seed(args.seed)
    torch.backends.cudnn.deterministic = True

    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    train_dataset, valid_dataset, test_dataset, embedding = read_data_sst("data")
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, num_workers=4, shuffle=True)
    valid_loader = torch.utils.data.DataLoader(valid_dataset, batch_size=args.batch_size, num_workers=4, shuffle=True)
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=args.batch_size, num_workers=4)
    train_loader, valid_loader = cycle(train_loader), cycle(valid_loader)

    model = Model(embedding)

    mutator = EnasMutator(model, temperature=None, tanh_constant=None, entropy_reduction="mean")

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, eps=1e-3, weight_decay=2e-6)
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs, eta_min=1e-5)

    trainer = TextNASTrainer(model,
                             loss=criterion,
                             metrics=lambda output, target: {"acc": accuracy(output, target)},
                             reward_function=accuracy,
                             optimizer=optimizer,
                             callbacks=[LRSchedulerCallback(lr_scheduler)],
                             batch_size=args.batch_size,
                             num_epochs=args.epochs,
                             dataset_train=None,
                             dataset_valid=None,
                             train_loader=train_loader,
                             valid_loader=valid_loader,
                             test_loader=test_loader,
                             log_frequency=args.log_frequency,
                             mutator=mutator,
                             mutator_lr=2e-3,
                             mutator_steps=500,
                             mutator_steps_aggregate=1,
                             child_steps=3000,
                             baseline_decay=0.99,
                             test_arc_per_epoch=10)
    trainer.train()

    os.makedirs("checkpoints", exist_ok=True)
    for i in range(20):
        trainer.export(os.path.join("checkpoints", "architecture_%02d.json" % i))
```
examples/nas/textnas/utils.py (new file, mode 100644)

```python
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import logging

import torch
import torch.nn as nn

INF = 1E10
EPS = 1E-12

logger = logging.getLogger("nni.textnas")


def get_length(mask):
    length = torch.sum(mask, 1)
    length = length.long()
    return length


class GlobalAvgPool(nn.Module):
    def forward(self, x, mask):
        x = torch.sum(x, 2)
        length = torch.sum(mask, 1, keepdim=True).float()
        length += torch.eq(length, 0.0).float() * EPS
        length = length.repeat(1, x.size()[1])
        x /= length
        return x


class GlobalMaxPool(nn.Module):
    def forward(self, x, mask):
        mask = torch.eq(mask.float(), 0.0).long()
        mask = torch.unsqueeze(mask, dim=1).repeat(1, x.size()[1], 1)
        mask *= -INF
        x += mask
        x, _ = torch.max(x + mask, 2)
        return x


class IteratorWrapper:
    def __init__(self, loader):
        self.loader = loader
        self.iterator = None

    def __iter__(self):
        self.iterator = iter(self.loader)
        return self

    def __len__(self):
        return len(self.loader)

    def __next__(self):
        data = next(self.iterator)
        text, length = data.text
        max_length = text.size(1)
        label = data.label - 1
        bs = label.size(0)
        mask = torch.arange(max_length, device=length.device).unsqueeze(0).repeat(bs, 1)
        mask = mask < length.unsqueeze(-1).repeat(1, max_length)
        return (text, mask), label


def accuracy(output, target):
    batch_size = target.size(0)
    _, predicted = torch.max(output.data, 1)
    return (predicted == target).sum().item() / batch_size
```
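For reference, `accuracy` above returns a plain Python float in [0, 1], which is what the ENAS reward function expects, and `get_length` turns a padding mask into per-sentence lengths:

```python
# Tiny sanity-check for utils.accuracy and utils.get_length (made-up values).
import torch
from utils import accuracy, get_length

output = torch.tensor([[0.1, 0.9], [0.8, 0.2], [0.3, 0.7]])
target = torch.tensor([1, 0, 0])
print(accuracy(output, target))                            # 0.666...
print(get_length(torch.tensor([[1, 1, 0], [1, 0, 0]])))    # tensor([2, 1])
```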
src/sdk/pynni/nni/nas/pytorch/enas/mutator.py

```diff
@@ -30,7 +30,7 @@ class StackedLSTMCell(nn.Module):
 class EnasMutator(Mutator):
 
     def __init__(self, model, lstm_size=64, lstm_num_layers=1, tanh_constant=1.5, cell_exit_extra_step=False,
-                 skip_target=0.4, branch_bias=0.25):
+                 skip_target=0.4, temperature=None, branch_bias=0.25, entropy_reduction="sum"):
         """
         Initialize a EnasMutator.
@@ -49,17 +49,22 @@ class EnasMutator(Mutator):
             and mark it as the hidden state of this MutableScope. This is to align with the original implementation of paper.
         skip_target : float
             Target probability that skipconnect will appear.
+        temperature : float
+            Temperature constant that divides the logits.
         branch_bias : float
             Manual bias applied to make some operations more likely to be chosen.
             Currently this is implemented with a hardcoded match rule that aligns with original repo.
             If a mutable has a ``reduce`` in its key, all its op choices
             that contains `conv` in their typename will receive a bias of ``+self.branch_bias`` initially; while others
             receive a bias of ``-self.branch_bias``.
+        entropy_reduction : str
+            Can be one of ``sum`` and ``mean``. How the entropy of multi-input-choice is reduced.
         """
         super().__init__(model)
         self.lstm_size = lstm_size
         self.lstm_num_layers = lstm_num_layers
         self.tanh_constant = tanh_constant
+        self.temperature = temperature
         self.cell_exit_extra_step = cell_exit_extra_step
         self.skip_target = skip_target
         self.branch_bias = branch_bias
@@ -70,6 +75,8 @@ class EnasMutator(Mutator):
         self.v_attn = nn.Linear(self.lstm_size, 1, bias=False)
         self.g_emb = nn.Parameter(torch.randn(1, self.lstm_size) * 0.1)
         self.skip_targets = nn.Parameter(torch.tensor([1.0 - self.skip_target, self.skip_target]), requires_grad=False)  # pylint: disable=not-callable
+        assert entropy_reduction in ["sum", "mean"], "Entropy reduction must be one of sum and mean."
+        self.entropy_reduction = torch.sum if entropy_reduction == "sum" else torch.mean
         self.cross_entropy_loss = nn.CrossEntropyLoss(reduction="none")
         self.bias_dict = nn.ParameterDict()
@@ -135,15 +142,17 @@ class EnasMutator(Mutator):
     def _sample_layer_choice(self, mutable):
         self._lstm_next_step()
         logit = self.soft(self._h[-1])
+        if self.temperature is not None:
+            logit /= self.temperature
         if self.tanh_constant is not None:
             logit = self.tanh_constant * torch.tanh(logit)
         if mutable.key in self.bias_dict:
             logit += self.bias_dict[mutable.key]
         branch_id = torch.multinomial(F.softmax(logit, dim=-1), 1).view(-1)
         log_prob = self.cross_entropy_loss(logit, branch_id)
-        self.sample_log_prob += torch.sum(log_prob)
+        self.sample_log_prob += self.entropy_reduction(log_prob)
         entropy = (log_prob * torch.exp(-log_prob)).detach()  # pylint: disable=invalid-unary-operand-type
-        self.sample_entropy += torch.sum(entropy)
+        self.sample_entropy += self.entropy_reduction(entropy)
         self._inputs = self.embedding(branch_id)
         return F.one_hot(branch_id, num_classes=self.max_layer_choice).bool().view(-1)
@@ -158,6 +167,8 @@ class EnasMutator(Mutator):
         query = torch.cat(query, 0)
         query = torch.tanh(query + self.attn_query(self._h[-1]))
         query = self.v_attn(query)
+        if self.temperature is not None:
+            query /= self.temperature
         if self.tanh_constant is not None:
             query = self.tanh_constant * torch.tanh(query)
@@ -178,7 +189,7 @@ class EnasMutator(Mutator):
         log_prob = self.cross_entropy_loss(logit, index)
         self._inputs = anchors[index.item()]
-        self.sample_log_prob += torch.sum(log_prob)
+        self.sample_log_prob += self.entropy_reduction(log_prob)
         entropy = (log_prob * torch.exp(-log_prob)).detach()  # pylint: disable=invalid-unary-operand-type
-        self.sample_entropy += torch.sum(entropy)
+        self.sample_entropy += self.entropy_reduction(entropy)
         return skip.bool()
```
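The key behavioural change above is that per-choice log-probabilities and entropies are now folded in with a configurable reduction instead of always `torch.sum`. A tiny illustration of the difference between the two options (values are made up):

```python
# Illustration of entropy_reduction="sum" vs "mean" (made-up log-probabilities).
import torch

log_prob = torch.tensor([0.7, 0.2, 0.1])
reduce_sum = torch.sum    # entropy_reduction="sum"  (previous behaviour, still the default)
reduce_mean = torch.mean  # entropy_reduction="mean" (what the TextNAS example selects)
print(reduce_sum(log_prob), reduce_mean(log_prob))  # tensor(1.0000) tensor(0.3333)
```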
src/sdk/pynni/nni/nas/pytorch/enas/trainer.py

```diff
@@ -2,11 +2,14 @@
 # Licensed under the MIT license.
 
 import logging
+from itertools import cycle
 
 import torch
+import torch.nn as nn
 import torch.optim as optim
 
 from nni.nas.pytorch.trainer import Trainer
-from nni.nas.pytorch.utils import AverageMeterGroup
+from nni.nas.pytorch.utils import AverageMeterGroup, to_device
 from .mutator import EnasMutator
 
 logger = logging.getLogger(__name__)
@@ -16,8 +19,9 @@ class EnasTrainer(Trainer):
     def __init__(self, model, loss, metrics, reward_function,
                  optimizer, num_epochs, dataset_train, dataset_valid,
                  mutator=None, batch_size=64, workers=4, device=None, log_frequency=None, callbacks=None,
-                 entropy_weight=0.0001, skip_weight=0.8, baseline_decay=0.999,
-                 mutator_lr=0.00035, mutator_steps_aggregate=20, mutator_steps=50, aux_weight=0.4):
+                 entropy_weight=0.0001, skip_weight=0.8, baseline_decay=0.999, child_steps=500,
+                 mutator_lr=0.00035, mutator_steps_aggregate=20, mutator_steps=50, aux_weight=0.4,
+                 test_arc_per_epoch=1):
         """
         Initialize an EnasTrainer.
@@ -57,6 +61,8 @@ class EnasTrainer(Trainer):
             Weight of skip penalty loss.
         baseline_decay : float
             Decay factor of baseline. New baseline will be equal to ``baseline_decay * baseline_old + reward * (1 - baseline_decay)``.
+        child_steps : int
+            How many mini-batches for model training per epoch.
         mutator_lr : float
             Learning rate for RL controller.
         mutator_steps_aggregate : int
@@ -65,12 +71,16 @@ class EnasTrainer(Trainer):
             Number of mini-batches for each epoch of RL controller learning.
         aux_weight : float
             Weight of auxiliary head loss. ``aux_weight * aux_loss`` will be added to total loss.
+        test_arc_per_epoch : int
+            How many architectures are chosen for direct test after each epoch.
         """
         super().__init__(model, mutator if mutator is not None else EnasMutator(model),
                          loss, metrics, optimizer, num_epochs, dataset_train, dataset_valid,
                          batch_size, workers, device, log_frequency, callbacks)
         self.reward_function = reward_function
         self.mutator_optim = optim.Adam(self.mutator.parameters(), lr=mutator_lr)
+        self.batch_size = batch_size
+        self.workers = workers
         self.entropy_weight = entropy_weight
         self.skip_weight = skip_weight
@@ -78,32 +88,40 @@ class EnasTrainer(Trainer):
         self.baseline = 0.
         self.mutator_steps_aggregate = mutator_steps_aggregate
         self.mutator_steps = mutator_steps
+        self.child_steps = child_steps
         self.aux_weight = aux_weight
+        self.test_arc_per_epoch = test_arc_per_epoch
 
+        self.init_dataloader()
+
+    def init_dataloader(self):
         n_train = len(self.dataset_train)
         split = n_train // 10
         indices = list(range(n_train))
         train_sampler = torch.utils.data.sampler.SubsetRandomSampler(indices[:-split])
         valid_sampler = torch.utils.data.sampler.SubsetRandomSampler(indices[-split:])
         self.train_loader = torch.utils.data.DataLoader(self.dataset_train,
-                                                        batch_size=batch_size,
+                                                        batch_size=self.batch_size,
                                                         sampler=train_sampler,
-                                                        num_workers=workers)
+                                                        num_workers=self.workers)
         self.valid_loader = torch.utils.data.DataLoader(self.dataset_train,
-                                                        batch_size=batch_size,
+                                                        batch_size=self.batch_size,
                                                         sampler=valid_sampler,
-                                                        num_workers=workers)
+                                                        num_workers=self.workers)
         self.test_loader = torch.utils.data.DataLoader(self.dataset_valid,
-                                                       batch_size=batch_size,
-                                                       num_workers=workers)
+                                                       batch_size=self.batch_size,
+                                                       num_workers=self.workers)
+        self.train_loader = cycle(self.train_loader)
+        self.valid_loader = cycle(self.valid_loader)
 
     def train_one_epoch(self, epoch):
         # Sample model and train
         self.model.train()
         self.mutator.eval()
         meters = AverageMeterGroup()
-        for step, (x, y) in enumerate(self.train_loader):
-            x, y = x.to(self.device), y.to(self.device)
+        for step in range(1, self.child_steps + 1):
+            x, y = next(self.train_loader)
+            x, y = to_device(x, self.device), to_device(y, self.device)
             self.optimizer.zero_grad()
             with torch.no_grad():
@@ -119,55 +137,71 @@ class EnasTrainer(Trainer):
             loss = self.loss(logits, y)
             loss = loss + self.aux_weight * aux_loss
             loss.backward()
+            nn.utils.clip_grad_norm_(self.model.parameters(), 5.)
             self.optimizer.step()
             metrics["loss"] = loss.item()
             meters.update(metrics)
 
             if self.log_frequency is not None and step % self.log_frequency == 0:
-                logger.info("Model Epoch [%s/%s] Step [%s/%s] %s", epoch + 1,
-                            self.num_epochs, step + 1, len(self.train_loader), meters)
+                logger.info("Model Epoch [%d/%d] Step [%d/%d] %s", epoch + 1,
+                            self.num_epochs, step, self.child_steps, meters)
 
         # Train sampler (mutator)
         self.model.eval()
         self.mutator.train()
         meters = AverageMeterGroup()
-        mutator_step, total_mutator_steps = 0, self.mutator_steps * self.mutator_steps_aggregate
-        while mutator_step < total_mutator_steps:
-            for step, (x, y) in enumerate(self.valid_loader):
-                x, y = x.to(self.device), y.to(self.device)
+        for mutator_step in range(1, self.mutator_steps + 1):
+            self.mutator_optim.zero_grad()
+            for step in range(1, self.mutator_steps_aggregate + 1):
+                x, y = next(self.valid_loader)
+                x, y = to_device(x, self.device), to_device(y, self.device)
 
                 self.mutator.reset()
                 with torch.no_grad():
                     logits = self.model(x)
                 metrics = self.metrics(logits, y)
                 reward = self.reward_function(logits, y)
-                if self.entropy_weight is not None:
-                    reward += self.entropy_weight * self.mutator.sample_entropy
+                if self.entropy_weight:
+                    reward += self.entropy_weight * self.mutator.sample_entropy.item()
                 self.baseline = self.baseline * self.baseline_decay + reward * (1 - self.baseline_decay)
-                self.baseline = self.baseline.detach().item()
                 loss = self.mutator.sample_log_prob * (reward - self.baseline)
                 if self.skip_weight:
                     loss += self.skip_weight * self.mutator.sample_skip_penalty
                 metrics["reward"] = reward
                 metrics["loss"] = loss.item()
                 metrics["ent"] = self.mutator.sample_entropy.item()
+                metrics["log_prob"] = self.mutator.sample_log_prob.item()
                 metrics["baseline"] = self.baseline
                 metrics["skip"] = self.mutator.sample_skip_penalty
 
-                loss = loss / self.mutator_steps_aggregate
+                loss /= self.mutator_steps_aggregate
                 loss.backward()
                 meters.update(metrics)
 
-                if mutator_step % self.mutator_steps_aggregate == 0:
-                    self.mutator_optim.step()
-                    self.mutator_optim.zero_grad()
-                    if self.log_frequency is not None and step % self.log_frequency == 0:
-                        logger.info("RL Epoch [%s/%s] Step [%s/%s] %s", epoch + 1, self.num_epochs,
-                                    mutator_step // self.mutator_steps_aggregate + 1,
-                                    self.mutator_steps, meters)
-                mutator_step += 1
-                if mutator_step >= total_mutator_steps:
-                    break
+                cur_step = step + (mutator_step - 1) * self.mutator_steps_aggregate
+                if self.log_frequency is not None and cur_step % self.log_frequency == 0:
+                    logger.info("RL Epoch [%d/%d] Step [%d/%d] [%d/%d] %s", epoch + 1, self.num_epochs,
+                                mutator_step, self.mutator_steps, step, self.mutator_steps_aggregate,
+                                meters)
+
+            nn.utils.clip_grad_norm_(self.mutator.parameters(), 5.)
+            self.mutator_optim.step()
 
     def validate_one_epoch(self, epoch):
-        pass
+        with torch.no_grad():
+            for arc_id in range(self.test_arc_per_epoch):
+                meters = AverageMeterGroup()
+                for x, y in self.test_loader:
+                    x, y = to_device(x, self.device), to_device(y, self.device)
+                    self.mutator.reset()
+                    logits = self.model(x)
+                    if isinstance(logits, tuple):
+                        logits, _ = logits
+                    metrics = self.metrics(logits, y)
+                    loss = self.loss(logits, y)
+                    metrics["loss"] = loss.item()
+                    meters.update(metrics)
+                logger.info("Test Epoch [%d/%d] Arc [%d/%d] Summary %s",
+                            epoch + 1, self.num_epochs, arc_id + 1, self.test_arc_per_epoch,
+                            meters.summary())
```
src/sdk/pynni/nni/nas/pytorch/mutables.py

```diff
@@ -159,7 +159,7 @@ class InputChoice(Mutable):
             "than number of candidates."
         self.n_candidates = n_candidates
-        self.choose_from = choose_from
+        self.choose_from = choose_from.copy()
         self.n_chosen = n_chosen
         self.reduction = reduction
         self.return_mask = return_mask
```
src/sdk/pynni/nni/nas/pytorch/trainer.py

```diff
@@ -96,12 +96,12 @@ class Trainer(BaseTrainer):
                 callback.on_epoch_begin(epoch)
 
             # training
-            _logger.info("Epoch %d Training", epoch)
+            _logger.info("Epoch %d Training", epoch + 1)
             self.train_one_epoch(epoch)
 
             if validate:
                 # validation
-                _logger.info("Epoch %d Validating", epoch)
+                _logger.info("Epoch %d Validating", epoch + 1)
                 self.validate_one_epoch(epoch)
 
             for callback in self.callbacks:
```
src/sdk/pynni/nni/nas/pytorch/utils.py

```diff
@@ -4,6 +4,8 @@
 import logging
 from collections import OrderedDict
 
+import torch
+
 _counter = 0
 
 _logger = logging.getLogger(__name__)
@@ -15,7 +17,22 @@ def global_mutable_counting():
     return _counter
 
 
+def to_device(obj, device):
+    if torch.is_tensor(obj):
+        return obj.to(device)
+    if isinstance(obj, tuple):
+        return tuple(to_device(t, device) for t in obj)
+    if isinstance(obj, list):
+        return [to_device(t, device) for t in obj]
+    if isinstance(obj, dict):
+        return {k: to_device(v, device) for k, v in obj.items()}
+    if isinstance(obj, (int, float, str)):
+        return obj
+    raise ValueError("'%s' has unsupported type '%s'" % (obj, type(obj)))
+
+
 class AverageMeterGroup:
+    """Average meter group for multiple average meters"""
+
     def __init__(self):
         self.meters = OrderedDict()
@@ -33,7 +50,10 @@ class AverageMeterGroup:
         return self.meters[item]
 
     def __str__(self):
-        return " ".join(str(v) for _, v in self.meters.items())
+        return " ".join(str(v) for v in self.meters.values())
+
+    def summary(self):
+        return " ".join(v.summary() for v in self.meters.values())
 
 
 class AverageMeter:
@@ -72,6 +92,10 @@ class AverageMeter:
         fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
         return fmtstr.format(**self.__dict__)
 
+    def summary(self):
+        fmtstr = '{name}: {avg' + self.fmt + '}'
+        return fmtstr.format(**self.__dict__)
+
 
 class StructuredMutableTreeNode:
     """
```
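The new `to_device` helper is what lets the ENAS trainer accept nested batch structures such as the `(text, mask)` tuples produced by the TextNAS dataloader; a quick sketch of its behaviour:

```python
# Usage sketch for the to_device helper added above (dummy tensors).
import torch
from nni.nas.pytorch.utils import to_device

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batch = (torch.zeros(4, 64), {"mask": torch.ones(4, 64)})
moved = to_device(batch, device)  # recurses into tuples, lists and dicts
```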