Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
67287997
Unverified
Commit
67287997
authored
Apr 16, 2020
by
SparkSnail
Committed by
GitHub
Apr 16, 2020
Browse files
Merge pull request #241 from microsoft/master
merge master
parents
b4773e1e
f8d42a33
Changes
74
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
14 changed files
with
909 additions
and
668 deletions
+909
-668
src/sdk/pynni/nni/compression/torch/pruners.py
src/sdk/pynni/nni/compression/torch/pruners.py
+5
-5
src/sdk/pynni/nni/nas/pytorch/darts/trainer.py
src/sdk/pynni/nni/nas/pytorch/darts/trainer.py
+1
-0
src/sdk/pynni/nni/nas/pytorch/enas/trainer.py
src/sdk/pynni/nni/nas/pytorch/enas/trainer.py
+2
-0
src/sdk/pynni/nni/nas/pytorch/trainer.py
src/sdk/pynni/nni/nas/pytorch/trainer.py
+24
-0
src/webui/package.json
src/webui/package.json
+3
-0
src/webui/src/App.tsx
src/webui/src/App.tsx
+33
-20
src/webui/yarn.lock
src/webui/yarn.lock
+800
-627
test/config/naive_trial/naive_trial.py
test/config/naive_trial/naive_trial.py
+14
-0
test/config/pr_tests.yml
test/config/pr_tests.yml
+11
-5
test/pipelines/pipelines-it-local-windows.yml
test/pipelines/pipelines-it-local-windows.yml
+1
-1
test/pipelines/pipelines-it-pai-windows.yml
test/pipelines/pipelines-it-pai-windows.yml
+2
-1
test/pipelines/pipelines-it-remote-windows.yml
test/pipelines/pipelines-it-remote-windows.yml
+0
-6
test/scripts/model_compression.sh
test/scripts/model_compression.sh
+9
-0
test/scripts/nas.sh
test/scripts/nas.sh
+4
-3
No files found.
src/sdk/pynni/nni/compression/torch/pruners.py
View file @
67287997
...
...
@@ -212,7 +212,7 @@ class AGP_Pruner(Pruner):
if
epoch
>
0
:
self
.
now_epoch
=
epoch
for
wrapper
in
self
.
get_modules_wrapper
():
wrapper
.
if_calculated
.
copy_
(
torch
.
tensor
(
0
))
# pylint: disable=not-callabl
e
wrapper
.
if_calculated
=
Fals
e
class
SlimPruner
(
Pruner
):
"""
...
...
@@ -329,10 +329,6 @@ class LotteryTicketPruner(Pruner):
reset_weights : bool
Whether reset weights and optimizer at the beginning of each round.
"""
super
().
__init__
(
model
,
config_list
,
optimizer
)
self
.
curr_prune_iteration
=
None
self
.
prune_iterations
=
config_list
[
0
][
'prune_iterations'
]
# save init weights and optimizer
self
.
reset_weights
=
reset_weights
if
self
.
reset_weights
:
...
...
@@ -344,6 +340,10 @@ class LotteryTicketPruner(Pruner):
if
lr_scheduler
is
not
None
:
self
.
_scheduler_state
=
copy
.
deepcopy
(
lr_scheduler
.
state_dict
())
super
().
__init__
(
model
,
config_list
,
optimizer
)
self
.
curr_prune_iteration
=
None
self
.
prune_iterations
=
config_list
[
0
][
'prune_iterations'
]
def
validate_config
(
self
,
model
,
config_list
):
"""
Parameters
...
...
src/sdk/pynni/nni/nas/pytorch/darts/trainer.py
View file @
67287997
...
...
@@ -129,6 +129,7 @@ class DartsTrainer(Trainer):
self
.
mutator
.
reset
()
logits
=
self
.
model
(
X
)
loss
=
self
.
loss
(
logits
,
y
)
self
.
_write_graph_status
()
return
logits
,
loss
def
_backward
(
self
,
val_X
,
val_y
):
...
...
src/sdk/pynni/nni/nas/pytorch/enas/trainer.py
View file @
67287997
...
...
@@ -126,6 +126,7 @@ class EnasTrainer(Trainer):
with
torch
.
no_grad
():
self
.
mutator
.
reset
()
self
.
_write_graph_status
()
logits
=
self
.
model
(
x
)
if
isinstance
(
logits
,
tuple
):
...
...
@@ -159,6 +160,7 @@ class EnasTrainer(Trainer):
self
.
mutator
.
reset
()
with
torch
.
no_grad
():
logits
=
self
.
model
(
x
)
self
.
_write_graph_status
()
metrics
=
self
.
metrics
(
logits
,
y
)
reward
=
self
.
reward_function
(
logits
,
y
)
if
self
.
entropy_weight
:
...
...
src/sdk/pynni/nni/nas/pytorch/trainer.py
View file @
67287997
...
...
@@ -3,6 +3,8 @@
import
json
import
logging
import
os
import
time
from
abc
import
abstractmethod
import
torch
...
...
@@ -90,6 +92,9 @@ class Trainer(BaseTrainer):
self
.
batch_size
=
batch_size
self
.
workers
=
workers
self
.
log_frequency
=
log_frequency
self
.
log_dir
=
os
.
path
.
join
(
"logs"
,
str
(
time
.
time
()))
os
.
makedirs
(
self
.
log_dir
,
exist_ok
=
True
)
self
.
status_writer
=
open
(
os
.
path
.
join
(
self
.
log_dir
,
"log"
),
"w"
)
self
.
callbacks
=
callbacks
if
callbacks
is
not
None
else
[]
for
callback
in
self
.
callbacks
:
callback
.
build
(
self
.
model
,
self
.
mutator
,
self
)
...
...
@@ -168,3 +173,22 @@ class Trainer(BaseTrainer):
Return trainer checkpoint.
"""
raise
NotImplementedError
(
"Not implemented yet"
)
def
enable_visualization
(
self
):
"""
Enable visualization. Write graph and training log to folder ``logs/<timestamp>``.
"""
sample
=
None
for
x
,
_
in
self
.
train_loader
:
sample
=
x
.
to
(
self
.
device
)[:
2
]
break
if
sample
is
None
:
_logger
.
warning
(
"Sample is %s."
,
sample
)
_logger
.
info
(
"Creating graph json, writing to %s. Visualization enabled."
,
self
.
log_dir
)
with
open
(
os
.
path
.
join
(
self
.
log_dir
,
"graph.json"
),
"w"
)
as
f
:
json
.
dump
(
self
.
mutator
.
graph
(
sample
),
f
)
self
.
visualization_enabled
=
True
def
_write_graph_status
(
self
):
if
hasattr
(
self
,
"visualization_enabled"
)
and
self
.
visualization_enabled
:
print
(
json
.
dumps
(
self
.
mutator
.
status
()),
file
=
self
.
status_writer
,
flush
=
True
)
src/webui/package.json
View file @
67287997
...
...
@@ -92,5 +92,8 @@
"presets"
:
[
"react-app"
]
},
"resolutions"
:
{
"npm"
:
">=6.14.4"
}
}
src/webui/src/App.tsx
View file @
67287997
...
...
@@ -17,8 +17,9 @@ interface AppState {
}
class
App
extends
React
.
Component
<
{},
AppState
>
{
private
timerId
!
:
number
|
null
;
private
timerId
!
:
number
|
undefined
;
private
dataFormatimer
!
:
number
;
private
firstLoad
:
boolean
=
false
;
// when click refresh selector options
constructor
(
props
:
{})
{
super
(
props
);
...
...
@@ -66,14 +67,20 @@ class App extends React.Component<{}, AppState> {
}
}
}
changeInterval
=
(
interval
:
number
):
void
=>
{
this
.
setState
({
interval
});
if
(
this
.
timerId
===
null
&&
interval
!==
0
)
{
window
.
setTimeout
(
this
.
refresh
);
}
else
if
(
this
.
timerId
!==
null
&&
interval
===
0
)
{
window
.
clearTimeout
(
this
.
timerId
);
window
.
clearTimeout
(
this
.
timerId
);
if
(
interval
===
0
)
{
return
;
}
// setState will trigger page refresh at once.
// setState is asyc, interval not update to (this.state.interval) at once.
this
.
setState
({
interval
},
()
=>
{
this
.
firstLoad
=
true
;
this
.
refresh
();
});
}
// TODO: use local storage
...
...
@@ -123,24 +130,30 @@ class App extends React.Component<{}, AppState> {
}
private
refresh
=
async
():
Promise
<
void
>
=>
{
const
[
experimentUpdated
,
trialsUpdated
]
=
await
Promise
.
all
([
EXPERIMENT
.
update
(),
TRIALS
.
update
()]);
if
(
experimentUpdated
)
{
this
.
setState
(
state
=>
({
experimentUpdateBroadcast
:
state
.
experimentUpdateBroadcast
+
1
}));
}
if
(
trialsUpdated
)
{
this
.
setState
(
state
=>
({
trialsUpdateBroadcast
:
state
.
trialsUpdateBroadcast
+
1
}));
// resolve this question: 10s -> 20s, page refresh twice.
// only refresh this page after clicking the refresh options
if
(
this
.
firstLoad
!==
true
)
{
const
[
experimentUpdated
,
trialsUpdated
]
=
await
Promise
.
all
([
EXPERIMENT
.
update
(),
TRIALS
.
update
()]);
if
(
experimentUpdated
)
{
this
.
setState
(
state
=>
({
experimentUpdateBroadcast
:
state
.
experimentUpdateBroadcast
+
1
}));
}
if
(
trialsUpdated
)
{
this
.
setState
(
state
=>
({
trialsUpdateBroadcast
:
state
.
trialsUpdateBroadcast
+
1
}));
}
}
else
{
this
.
firstLoad
=
false
;
}
if
([
'
DONE
'
,
'
ERROR
'
,
'
STOPPED
'
].
includes
(
EXPERIMENT
.
status
))
{
// experiment finished, refresh once more to ensure consistency
if
(
this
.
state
.
interval
>
0
)
{
this
.
setState
({
interval
:
0
});
this
.
lastRefresh
();
}
}
else
if
(
this
.
state
.
interval
!==
0
)
{
this
.
timerId
=
window
.
setTimeout
(
this
.
refresh
,
this
.
state
.
interval
*
1000
);
this
.
setState
({
interval
:
0
});
this
.
lastRefresh
();
return
;
}
this
.
timerId
=
window
.
setTimeout
(
this
.
refresh
,
this
.
state
.
interval
*
1000
);
}
public
async
lastRefresh
():
Promise
<
void
>
{
...
...
src/webui/yarn.lock
View file @
67287997
This diff is collapsed.
Click to expand it.
test/config/naive_trial/naive_trial.py
0 → 100644
View file @
67287997
import
time
import
nni
if
__name__
==
'__main__'
:
print
(
'trial start'
)
params
=
nni
.
get_next_parameter
()
print
(
'params:'
,
params
)
epochs
=
2
for
i
in
range
(
epochs
):
nni
.
report_intermediate_result
(
0.1
*
(
i
+
1
))
time
.
sleep
(
1
)
nni
.
report_final_result
(
0.8
)
print
(
'trial done'
)
test/config/pr_tests.yml
View file @
67287997
...
...
@@ -70,12 +70,18 @@ testCases:
config
:
maxTrialNum
:
2
trialConcurrency
:
2
trial
:
codeDir
:
../naive_trial
command
:
python3 naive_trial.py
-
name
:
assessor-medianstop
configFile
:
test/config/assessors/medianstop.yml
config
:
maxTrialNum
:
2
trialConcurrency
:
2
trial
:
codeDir
:
../naive_trial
command
:
python3 naive_trial.py
#########################################################################
# nni tuners test
...
...
@@ -89,7 +95,7 @@ testCases:
searchSpacePath
:
../naive_trial/search_space.json
trial
:
codeDir
:
../naive_trial
command
:
python3 trial.py
command
:
python3
naive_
trial.py
-
name
:
tuner-evolution
configFile
:
test/config/tuners/evolution.yml
...
...
@@ -100,7 +106,7 @@ testCases:
searchSpacePath
:
../naive_trial/search_space.json
trial
:
codeDir
:
../naive_trial
command
:
python3 trial.py
command
:
python3
naive_
trial.py
-
name
:
tuner-random
configFile
:
test/config/tuners/random.yml
...
...
@@ -111,7 +117,7 @@ testCases:
searchSpacePath
:
../naive_trial/search_space.json
trial
:
codeDir
:
../naive_trial
command
:
python3 trial.py
command
:
python3
naive_
trial.py
-
name
:
tuner-tpe
configFile
:
test/config/tuners/tpe.yml
...
...
@@ -122,7 +128,7 @@ testCases:
searchSpacePath
:
../naive_trial/search_space.json
trial
:
codeDir
:
../naive_trial
command
:
python3 trial.py
command
:
python3
naive_
trial.py
-
name
:
tuner-batch
configFile
:
test/config/tuners/batch.yml
...
...
@@ -144,7 +150,7 @@ testCases:
searchSpacePath
:
../naive_trial/search_space.json
trial
:
codeDir
:
../naive_trial
command
:
python3 trial.py
command
:
python3
naive_
trial.py
-
name
:
tuner-grid
configFile
:
test/config/tuners/gridsearch.yml
...
...
test/pipelines/pipelines-it-local-windows.yml
View file @
67287997
...
...
@@ -10,7 +10,7 @@ jobs:
python -m pip install scikit-learn==0.20.0 --user
python -m pip install keras==2.1.6 --user
python -m pip install torchvision===0.4.1 torch===1.3.1 -f https://download.pytorch.org/whl/torch_stable.html --user
python -m pip install tensorflow-gpu==1.1
1.0
--user
python -m pip install tensorflow-gpu==1.1
5.2
--user
displayName
:
'
Install
dependencies
for
integration
tests'
-
script
:
|
cd test
...
...
test/pipelines/pipelines-it-pai-windows.yml
View file @
67287997
...
...
@@ -63,6 +63,7 @@ jobs:
cd test
set PATH=$(ENV_PATH)
python --version
python nni_test/nnitest/generate_ts_config.py --ts pai --pai_host $(pai_host) --pai_user $(pai_user) --pai_pwd $(pai_pwd) --vc $(pai_virtual_cluster) --nni_docker_image $(docker_image) --data_dir $(data_dir) --output_dir $(output_dir) --nni_manager_ip $(nni_manager_ip)
mount -o anon $(pai_nfs_uri) $(local_nfs_uri)
python nni_test/nnitest/generate_ts_config.py --ts pai --pai_token $(pai_token) --pai_host $(pai_host) --pai_user $(pai_user) --nni_docker_image $(docker_image) --pai_storage_plugin $(pai_storage_plugin) --nni_manager_nfs_mount_path $(nni_manager_nfs_mount_path) --container_nfs_mount_path $(container_nfs_mount_path) --nni_manager_ip $(nni_manager_ip)
python nni_test/nnitest/run_tests.py --config config/integration_tests.yml --ts pai --exclude multi-phase
displayName
:
'
Examples
and
advanced
features
tests
on
pai'
\ No newline at end of file
test/pipelines/pipelines-it-remote-windows.yml
View file @
67287997
...
...
@@ -52,9 +52,3 @@ jobs:
runOptions
:
commands
commands
:
python3 /tmp/nnitest/$(Build.BuildId)/nni-remote/test/nni_test/nnitest/remote_docker.py --mode stop --name $(Build.BuildId) --os windows
displayName
:
'
Stop
docker'
-
task
:
SSH@0
inputs
:
sshEndpoint
:
$(end_point)
runOptions
:
commands
commands
:
sudo rm -rf /tmp/nnitest/$(Build.BuildId)
displayName
:
'
Clean
the
remote
files'
test/scripts/model_compression.sh
View file @
67287997
...
...
@@ -19,6 +19,15 @@ do
python3 model_prune_torch.py
--pruner_name
$name
--pretrain_epochs
1
--prune_epochs
1
done
echo
'testing level pruner pruning'
python3 model_prune_torch.py
--pruner_name
level
--pretrain_epochs
1
--prune_epochs
1
echo
'testing agp pruning'
python3 model_prune_torch.py
--pruner_name
agp
--pretrain_epochs
1
--prune_epochs
2
echo
'testing mean_activation pruning'
python3 model_prune_torch.py
--pruner_name
mean_activation
--pretrain_epochs
1
--prune_epochs
1
#echo "testing lottery ticket pruning..."
#python3 lottery_torch_mnist_fc.py
...
...
test/scripts/nas.sh
View file @
67287997
...
...
@@ -28,9 +28,10 @@ cd $EXAMPLE_DIR/enas
python3 search.py
--search-for
macro
--epochs
1
python3 search.py
--search-for
micro
--epochs
1
echo
"testing naive..."
cd
$EXAMPLE_DIR
/naive
python3 train.py
#disabled for now
#echo "testing naive..."
#cd $EXAMPLE_DIR/naive
#python3 train.py
echo
"testing pdarts..."
cd
$EXAMPLE_DIR
/pdarts
...
...
Prev
1
2
3
4
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment