Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
67287997
Unverified
Commit
67287997
authored
Apr 16, 2020
by
SparkSnail
Committed by
GitHub
Apr 16, 2020
Browse files
Merge pull request #241 from microsoft/master
merge master
parents
b4773e1e
f8d42a33
Changes
74
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
14 changed files
with
909 additions
and
668 deletions
+909
-668
src/sdk/pynni/nni/compression/torch/pruners.py
src/sdk/pynni/nni/compression/torch/pruners.py
+5
-5
src/sdk/pynni/nni/nas/pytorch/darts/trainer.py
src/sdk/pynni/nni/nas/pytorch/darts/trainer.py
+1
-0
src/sdk/pynni/nni/nas/pytorch/enas/trainer.py
src/sdk/pynni/nni/nas/pytorch/enas/trainer.py
+2
-0
src/sdk/pynni/nni/nas/pytorch/trainer.py
src/sdk/pynni/nni/nas/pytorch/trainer.py
+24
-0
src/webui/package.json
src/webui/package.json
+3
-0
src/webui/src/App.tsx
src/webui/src/App.tsx
+33
-20
src/webui/yarn.lock
src/webui/yarn.lock
+800
-627
test/config/naive_trial/naive_trial.py
test/config/naive_trial/naive_trial.py
+14
-0
test/config/pr_tests.yml
test/config/pr_tests.yml
+11
-5
test/pipelines/pipelines-it-local-windows.yml
test/pipelines/pipelines-it-local-windows.yml
+1
-1
test/pipelines/pipelines-it-pai-windows.yml
test/pipelines/pipelines-it-pai-windows.yml
+2
-1
test/pipelines/pipelines-it-remote-windows.yml
test/pipelines/pipelines-it-remote-windows.yml
+0
-6
test/scripts/model_compression.sh
test/scripts/model_compression.sh
+9
-0
test/scripts/nas.sh
test/scripts/nas.sh
+4
-3
No files found.
src/sdk/pynni/nni/compression/torch/pruners.py
View file @
67287997
...
...
@@ -212,7 +212,7 @@ class AGP_Pruner(Pruner):
if
epoch
>
0
:
self
.
now_epoch
=
epoch
for
wrapper
in
self
.
get_modules_wrapper
():
wrapper
.
if_calculated
.
copy_
(
torch
.
tensor
(
0
))
# pylint: disable=not-callabl
e
wrapper
.
if_calculated
=
Fals
e
class
SlimPruner
(
Pruner
):
"""
...
...
@@ -329,10 +329,6 @@ class LotteryTicketPruner(Pruner):
reset_weights : bool
Whether reset weights and optimizer at the beginning of each round.
"""
super
().
__init__
(
model
,
config_list
,
optimizer
)
self
.
curr_prune_iteration
=
None
self
.
prune_iterations
=
config_list
[
0
][
'prune_iterations'
]
# save init weights and optimizer
self
.
reset_weights
=
reset_weights
if
self
.
reset_weights
:
...
...
@@ -344,6 +340,10 @@ class LotteryTicketPruner(Pruner):
if
lr_scheduler
is
not
None
:
self
.
_scheduler_state
=
copy
.
deepcopy
(
lr_scheduler
.
state_dict
())
super
().
__init__
(
model
,
config_list
,
optimizer
)
self
.
curr_prune_iteration
=
None
self
.
prune_iterations
=
config_list
[
0
][
'prune_iterations'
]
def
validate_config
(
self
,
model
,
config_list
):
"""
Parameters
...
...
src/sdk/pynni/nni/nas/pytorch/darts/trainer.py
View file @
67287997
...
...
@@ -129,6 +129,7 @@ class DartsTrainer(Trainer):
self
.
mutator
.
reset
()
logits
=
self
.
model
(
X
)
loss
=
self
.
loss
(
logits
,
y
)
self
.
_write_graph_status
()
return
logits
,
loss
def
_backward
(
self
,
val_X
,
val_y
):
...
...
src/sdk/pynni/nni/nas/pytorch/enas/trainer.py
View file @
67287997
...
...
@@ -126,6 +126,7 @@ class EnasTrainer(Trainer):
with
torch
.
no_grad
():
self
.
mutator
.
reset
()
self
.
_write_graph_status
()
logits
=
self
.
model
(
x
)
if
isinstance
(
logits
,
tuple
):
...
...
@@ -159,6 +160,7 @@ class EnasTrainer(Trainer):
self
.
mutator
.
reset
()
with
torch
.
no_grad
():
logits
=
self
.
model
(
x
)
self
.
_write_graph_status
()
metrics
=
self
.
metrics
(
logits
,
y
)
reward
=
self
.
reward_function
(
logits
,
y
)
if
self
.
entropy_weight
:
...
...
src/sdk/pynni/nni/nas/pytorch/trainer.py
View file @
67287997
...
...
@@ -3,6 +3,8 @@
import
json
import
logging
import
os
import
time
from
abc
import
abstractmethod
import
torch
...
...
@@ -90,6 +92,9 @@ class Trainer(BaseTrainer):
self
.
batch_size
=
batch_size
self
.
workers
=
workers
self
.
log_frequency
=
log_frequency
self
.
log_dir
=
os
.
path
.
join
(
"logs"
,
str
(
time
.
time
()))
os
.
makedirs
(
self
.
log_dir
,
exist_ok
=
True
)
self
.
status_writer
=
open
(
os
.
path
.
join
(
self
.
log_dir
,
"log"
),
"w"
)
self
.
callbacks
=
callbacks
if
callbacks
is
not
None
else
[]
for
callback
in
self
.
callbacks
:
callback
.
build
(
self
.
model
,
self
.
mutator
,
self
)
...
...
@@ -168,3 +173,22 @@ class Trainer(BaseTrainer):
Return trainer checkpoint.
"""
raise
NotImplementedError
(
"Not implemented yet"
)
def
enable_visualization
(
self
):
"""
Enable visualization. Write graph and training log to folder ``logs/<timestamp>``.
"""
sample
=
None
for
x
,
_
in
self
.
train_loader
:
sample
=
x
.
to
(
self
.
device
)[:
2
]
break
if
sample
is
None
:
_logger
.
warning
(
"Sample is %s."
,
sample
)
_logger
.
info
(
"Creating graph json, writing to %s. Visualization enabled."
,
self
.
log_dir
)
with
open
(
os
.
path
.
join
(
self
.
log_dir
,
"graph.json"
),
"w"
)
as
f
:
json
.
dump
(
self
.
mutator
.
graph
(
sample
),
f
)
self
.
visualization_enabled
=
True
def
_write_graph_status
(
self
):
if
hasattr
(
self
,
"visualization_enabled"
)
and
self
.
visualization_enabled
:
print
(
json
.
dumps
(
self
.
mutator
.
status
()),
file
=
self
.
status_writer
,
flush
=
True
)
src/webui/package.json
View file @
67287997
...
...
@@ -92,5 +92,8 @@
"presets"
:
[
"react-app"
]
},
"resolutions"
:
{
"npm"
:
">=6.14.4"
}
}
src/webui/src/App.tsx
View file @
67287997
...
...
@@ -17,8 +17,9 @@ interface AppState {
}
class
App
extends
React
.
Component
<
{},
AppState
>
{
private
timerId
!
:
number
|
null
;
private
timerId
!
:
number
|
undefined
;
private
dataFormatimer
!
:
number
;
private
firstLoad
:
boolean
=
false
;
// when click refresh selector options
constructor
(
props
:
{})
{
super
(
props
);
...
...
@@ -66,14 +67,20 @@ class App extends React.Component<{}, AppState> {
}
}
}
changeInterval
=
(
interval
:
number
):
void
=>
{
this
.
setState
({
interval
});
if
(
this
.
timerId
===
null
&&
interval
!==
0
)
{
window
.
setTimeout
(
this
.
refresh
);
}
else
if
(
this
.
timerId
!==
null
&&
interval
===
0
)
{
window
.
clearTimeout
(
this
.
timerId
);
window
.
clearTimeout
(
this
.
timerId
);
if
(
interval
===
0
)
{
return
;
}
// setState will trigger page refresh at once.
// setState is asyc, interval not update to (this.state.interval) at once.
this
.
setState
({
interval
},
()
=>
{
this
.
firstLoad
=
true
;
this
.
refresh
();
});
}
// TODO: use local storage
...
...
@@ -123,24 +130,30 @@ class App extends React.Component<{}, AppState> {
}
private
refresh
=
async
():
Promise
<
void
>
=>
{
const
[
experimentUpdated
,
trialsUpdated
]
=
await
Promise
.
all
([
EXPERIMENT
.
update
(),
TRIALS
.
update
()]);
if
(
experimentUpdated
)
{
this
.
setState
(
state
=>
({
experimentUpdateBroadcast
:
state
.
experimentUpdateBroadcast
+
1
}));
}
if
(
trialsUpdated
)
{
this
.
setState
(
state
=>
({
trialsUpdateBroadcast
:
state
.
trialsUpdateBroadcast
+
1
}));
// resolve this question: 10s -> 20s, page refresh twice.
// only refresh this page after clicking the refresh options
if
(
this
.
firstLoad
!==
true
)
{
const
[
experimentUpdated
,
trialsUpdated
]
=
await
Promise
.
all
([
EXPERIMENT
.
update
(),
TRIALS
.
update
()]);
if
(
experimentUpdated
)
{
this
.
setState
(
state
=>
({
experimentUpdateBroadcast
:
state
.
experimentUpdateBroadcast
+
1
}));
}
if
(
trialsUpdated
)
{
this
.
setState
(
state
=>
({
trialsUpdateBroadcast
:
state
.
trialsUpdateBroadcast
+
1
}));
}
}
else
{
this
.
firstLoad
=
false
;
}
if
([
'
DONE
'
,
'
ERROR
'
,
'
STOPPED
'
].
includes
(
EXPERIMENT
.
status
))
{
// experiment finished, refresh once more to ensure consistency
if
(
this
.
state
.
interval
>
0
)
{
this
.
setState
({
interval
:
0
});
this
.
lastRefresh
();
}
}
else
if
(
this
.
state
.
interval
!==
0
)
{
this
.
timerId
=
window
.
setTimeout
(
this
.
refresh
,
this
.
state
.
interval
*
1000
);
this
.
setState
({
interval
:
0
});
this
.
lastRefresh
();
return
;
}
this
.
timerId
=
window
.
setTimeout
(
this
.
refresh
,
this
.
state
.
interval
*
1000
);
}
public
async
lastRefresh
():
Promise
<
void
>
{
...
...
src/webui/yarn.lock
View file @
67287997
This diff is collapsed.
Click to expand it.
test/config/naive_trial/naive_trial.py
0 → 100644
View file @
67287997
import
time
import
nni
if
__name__
==
'__main__'
:
print
(
'trial start'
)
params
=
nni
.
get_next_parameter
()
print
(
'params:'
,
params
)
epochs
=
2
for
i
in
range
(
epochs
):
nni
.
report_intermediate_result
(
0.1
*
(
i
+
1
))
time
.
sleep
(
1
)
nni
.
report_final_result
(
0.8
)
print
(
'trial done'
)
test/config/pr_tests.yml
View file @
67287997
...
...
@@ -70,12 +70,18 @@ testCases:
config
:
maxTrialNum
:
2
trialConcurrency
:
2
trial
:
codeDir
:
../naive_trial
command
:
python3 naive_trial.py
-
name
:
assessor-medianstop
configFile
:
test/config/assessors/medianstop.yml
config
:
maxTrialNum
:
2
trialConcurrency
:
2
trial
:
codeDir
:
../naive_trial
command
:
python3 naive_trial.py
#########################################################################
# nni tuners test
...
...
@@ -89,7 +95,7 @@ testCases:
searchSpacePath
:
../naive_trial/search_space.json
trial
:
codeDir
:
../naive_trial
command
:
python3 trial.py
command
:
python3
naive_
trial.py
-
name
:
tuner-evolution
configFile
:
test/config/tuners/evolution.yml
...
...
@@ -100,7 +106,7 @@ testCases:
searchSpacePath
:
../naive_trial/search_space.json
trial
:
codeDir
:
../naive_trial
command
:
python3 trial.py
command
:
python3
naive_
trial.py
-
name
:
tuner-random
configFile
:
test/config/tuners/random.yml
...
...
@@ -111,7 +117,7 @@ testCases:
searchSpacePath
:
../naive_trial/search_space.json
trial
:
codeDir
:
../naive_trial
command
:
python3 trial.py
command
:
python3
naive_
trial.py
-
name
:
tuner-tpe
configFile
:
test/config/tuners/tpe.yml
...
...
@@ -122,7 +128,7 @@ testCases:
searchSpacePath
:
../naive_trial/search_space.json
trial
:
codeDir
:
../naive_trial
command
:
python3 trial.py
command
:
python3
naive_
trial.py
-
name
:
tuner-batch
configFile
:
test/config/tuners/batch.yml
...
...
@@ -144,7 +150,7 @@ testCases:
searchSpacePath
:
../naive_trial/search_space.json
trial
:
codeDir
:
../naive_trial
command
:
python3 trial.py
command
:
python3
naive_
trial.py
-
name
:
tuner-grid
configFile
:
test/config/tuners/gridsearch.yml
...
...
test/pipelines/pipelines-it-local-windows.yml
View file @
67287997
...
...
@@ -10,7 +10,7 @@ jobs:
python -m pip install scikit-learn==0.20.0 --user
python -m pip install keras==2.1.6 --user
python -m pip install torchvision===0.4.1 torch===1.3.1 -f https://download.pytorch.org/whl/torch_stable.html --user
python -m pip install tensorflow-gpu==1.1
1.0
--user
python -m pip install tensorflow-gpu==1.1
5.2
--user
displayName
:
'
Install
dependencies
for
integration
tests'
-
script
:
|
cd test
...
...
test/pipelines/pipelines-it-pai-windows.yml
View file @
67287997
...
...
@@ -63,6 +63,7 @@ jobs:
cd test
set PATH=$(ENV_PATH)
python --version
python nni_test/nnitest/generate_ts_config.py --ts pai --pai_host $(pai_host) --pai_user $(pai_user) --pai_pwd $(pai_pwd) --vc $(pai_virtual_cluster) --nni_docker_image $(docker_image) --data_dir $(data_dir) --output_dir $(output_dir) --nni_manager_ip $(nni_manager_ip)
mount -o anon $(pai_nfs_uri) $(local_nfs_uri)
python nni_test/nnitest/generate_ts_config.py --ts pai --pai_token $(pai_token) --pai_host $(pai_host) --pai_user $(pai_user) --nni_docker_image $(docker_image) --pai_storage_plugin $(pai_storage_plugin) --nni_manager_nfs_mount_path $(nni_manager_nfs_mount_path) --container_nfs_mount_path $(container_nfs_mount_path) --nni_manager_ip $(nni_manager_ip)
python nni_test/nnitest/run_tests.py --config config/integration_tests.yml --ts pai --exclude multi-phase
displayName
:
'
Examples
and
advanced
features
tests
on
pai'
\ No newline at end of file
test/pipelines/pipelines-it-remote-windows.yml
View file @
67287997
...
...
@@ -52,9 +52,3 @@ jobs:
runOptions
:
commands
commands
:
python3 /tmp/nnitest/$(Build.BuildId)/nni-remote/test/nni_test/nnitest/remote_docker.py --mode stop --name $(Build.BuildId) --os windows
displayName
:
'
Stop
docker'
-
task
:
SSH@0
inputs
:
sshEndpoint
:
$(end_point)
runOptions
:
commands
commands
:
sudo rm -rf /tmp/nnitest/$(Build.BuildId)
displayName
:
'
Clean
the
remote
files'
test/scripts/model_compression.sh
View file @
67287997
...
...
@@ -19,6 +19,15 @@ do
python3 model_prune_torch.py
--pruner_name
$name
--pretrain_epochs
1
--prune_epochs
1
done
echo
'testing level pruner pruning'
python3 model_prune_torch.py
--pruner_name
level
--pretrain_epochs
1
--prune_epochs
1
echo
'testing agp pruning'
python3 model_prune_torch.py
--pruner_name
agp
--pretrain_epochs
1
--prune_epochs
2
echo
'testing mean_activation pruning'
python3 model_prune_torch.py
--pruner_name
mean_activation
--pretrain_epochs
1
--prune_epochs
1
#echo "testing lottery ticket pruning..."
#python3 lottery_torch_mnist_fc.py
...
...
test/scripts/nas.sh
View file @
67287997
...
...
@@ -28,9 +28,10 @@ cd $EXAMPLE_DIR/enas
python3 search.py
--search-for
macro
--epochs
1
python3 search.py
--search-for
micro
--epochs
1
echo
"testing naive..."
cd
$EXAMPLE_DIR
/naive
python3 train.py
#disabled for now
#echo "testing naive..."
#cd $EXAMPLE_DIR/naive
#python3 train.py
echo
"testing pdarts..."
cd
$EXAMPLE_DIR
/pdarts
...
...
Prev
1
2
3
4
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment