OpenDAS / nni / Commits / 32efaa36

Unverified commit 32efaa36, authored Dec 10, 2019 by SparkSnail, committed by GitHub on Dec 10, 2019.

Merge pull request #219 from microsoft/master

merge master
Parents: cd3a912a, 97b258b0

Showing 17 changed files with 2250 additions and 158 deletions (+2250 -158).
src/nni_manager/training_service/kubernetes/kubeflow/kubeflowTrainingService.ts  +3 -3
src/nni_manager/tslint.json  +0 -25
src/nni_manager/yarn.lock  +2134 -45
src/sdk/pynni/nni/metis_tuner/requirments.txt  +1 -1
src/sdk/pynni/nni/nas/pytorch/darts/trainer.py  +2 -1
src/sdk/pynni/nni/nas/pytorch/mutator.py  +1 -1
src/sdk/pynni/nni/nas/pytorch/pdarts/mutator.py  +55 -44
src/sdk/pynni/nni/nas/pytorch/pdarts/trainer.py  +18 -9
src/sdk/pynni/requirements.txt  +1 -1
src/webui/package.json  +2 -1
src/webui/src/components/trial-detail/TableList.tsx  +12 -11
src/webui/src/static/style/table.scss  +6 -8
src/webui/yarn.lock  +3 -3
test/metrics_test/trial.py  +1 -0
test/naive_test/naive_trial.py  +1 -0
test/utils.py  +1 -1
tools/bash-completion  +9 -4
src/nni_manager/training_service/kubernetes/kubeflow/kubeflowTrainingService.ts

@@ -21,7 +21,7 @@ import { AzureStorageClientUtility } from '../azureStorageClientUtils';
 import { NFSConfig } from '../kubernetesConfig';
 import { KubernetesTrialJobDetail } from '../kubernetesData';
 import { KubernetesTrainingService } from '../kubernetesTrainingService';
-import { KubeflowOperatorClient } from './kubeflowApiClient';
+import { KubeflowOperatorClientFactory } from './kubeflowApiClient';
 import { KubeflowClusterConfig, KubeflowClusterConfigAzure, KubeflowClusterConfigFactory, KubeflowClusterConfigNFS,
     KubeflowTrialConfig, KubeflowTrialConfigFactory, KubeflowTrialConfigPytorch, KubeflowTrialConfigTensorflow } from './kubeflowConfig';

@@ -136,8 +136,8 @@ class KubeflowTrainingService extends KubernetesTrainingService implements Kuber
                     nfsKubeflowClusterConfig.nfs.path);
             }
-            this.kubernetesCRDClient = KubeflowOperatorClient.generateOperatorClient(this.kubeflowClusterConfig.operator,
-                this.kubeflowClusterConfig.apiVersion);
+            this.kubernetesCRDClient = KubeflowOperatorClientFactory.createClient(this.kubeflowClusterConfig.operator,
+                this.kubeflowClusterConfig.apiVersion);
             break;
         case TrialConfigMetadataKey.TRIAL_CONFIG:
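The whole change in this file is the construction API: the static KubeflowOperatorClient.generateOperatorClient(...) call gives way to a dedicated factory, KubeflowOperatorClientFactory.createClient(...), taking the same operator and apiVersion arguments. Only the import and the call site move; the resulting kubernetesCRDClient is used exactly as before.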
src/nni_manager/tslint.json (deleted, 100644 → 0)

{
    "defaultSeverity": "error",
    "extends": "tslint-microsoft-contrib",
    "jsRules": {},
    "rules": {
        "no-relative-imports": false,
        "export-name": false,
        "interface-name": [true, "never-prefix"],
        "no-increment-decrement": false,
        "promise-function-async": false,
        "no-console": [true, "log"],
        "no-multiline-string": false,
        "no-suspicious-comment": false,
        "no-backbone-get-set-outside-model": false,
        "max-classes-per-file": false
    },
    "rulesDirectory": [],
    "linterOptions": {
        "exclude": [
            "training_service/test/*",
            "rest_server/test/*",
            "core/test/*"
        ]
    }
}
\ No newline at end of file
src/nni_manager/yarn.lock

(Diff collapsed in the original view: +2134 -45 lines of lockfile churn.)
src/sdk/pynni/nni/metis_tuner/requirments.txt

-sklearn
\ No newline at end of file
+scikit-learn==0.20
\ No newline at end of file
src/sdk/pynni/nni/nas/pytorch/darts/trainer.py

@@ -18,10 +18,11 @@ class DartsTrainer(Trainer):
     def __init__(self, model, loss, metrics,
                  optimizer, num_epochs, dataset_train, dataset_valid,
                  mutator=None, batch_size=64, workers=4, device=None, log_frequency=None,
-                 callbacks=None, arc_learning_rate=3.0E-4, unrolled=True):
+                 callbacks=None, arc_learning_rate=3.0E-4, unrolled=False):
         super().__init__(model, mutator if mutator is not None else DartsMutator(model),
                          loss, metrics, optimizer, num_epochs, dataset_train, dataset_valid,
                          batch_size, workers, device, log_frequency, callbacks)
         self.ctrl_optim = torch.optim.Adam(self.mutator.parameters(), arc_learning_rate, betas=(0.5, 0.999),
                                            weight_decay=1.0E-3)
         self.unrolled = unrolled
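The substantive change is the default for unrolled, flipped from True to False. In DARTS terms, unrolled=True selects the second-order approximation of the architecture gradient (a virtual gradient step on the model weights before differentiating with respect to the architecture parameters), which is considerably more expensive per step; the flip makes the cheaper first-order variant the default. The arc_learning_rate default and the Adam controller optimizer (betas=(0.5, 0.999), weight_decay=1.0E-3) are unchanged.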
src/sdk/pynni/nni/nas/pytorch/mutator.py

@@ -111,7 +111,7 @@ class Mutator(BaseMutator):
         if "BoolTensor" in mask.type():
             out = [map_fn(*cand) for cand, m in zip(candidates, mask) if m]
         elif "FloatTensor" in mask.type():
-            out = [map_fn(*cand) * m for cand, m in zip(candidates, mask)]
+            out = [map_fn(*cand) * m for cand, m in zip(candidates, mask) if m]
         else:
             raise ValueError("Unrecognized mask")
         return out
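To make the effect of the added "if m" concrete, here is a minimal, self-contained sketch of the masking logic; the function name and toy tensors are invented for illustration (NNI's real code dispatches on the string returned by mask.type()). With a float mask, candidates whose weight is exactly zero are now skipped entirely rather than contributing a zero-valued term, presumably so they need not be evaluated at all.

    import torch

    def apply_mask(candidates, mask):
        # Hard selection: keep only the chosen candidates.
        if mask.dtype == torch.bool:
            return [c for c, m in zip(candidates, mask) if m]
        # Soft selection: weight each candidate, now skipping zero weights
        # (the `if m` filter this commit adds).
        if mask.dtype == torch.float32:
            return [c * m for c, m in zip(candidates, mask) if m]
        raise ValueError("Unrecognized mask")

    outs = [torch.ones(2), 2 * torch.ones(2), 3 * torch.ones(2)]
    print(apply_mask(outs, torch.tensor([True, False, True])))  # candidate 1 dropped
    print(apply_mask(outs, torch.tensor([0.5, 0.0, 0.5])))      # candidate 1 now dropped too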
src/sdk/pynni/nni/nas/pytorch/pdarts/mutator.py

@@ -4,13 +4,18 @@
 import copy

 import numpy as np
-import torch.nn.functional as F
 import torch
 from torch import nn
+import torch.nn.functional as F

 from nni.nas.pytorch.darts import DartsMutator
 from nni.nas.pytorch.mutables import LayerChoice


 class PdartsMutator(DartsMutator):
+    """
+    It works with PdartsTrainer to calculate ops weights,
+    and drop weights in different PDARTS epochs.
+    """

     def __init__(self, model, pdarts_epoch_index, pdarts_num_to_drop, switches={}):
         self.pdarts_epoch_index = pdarts_epoch_index

@@ -22,60 +27,66 @@ class PdartsMutator(DartsMutator):
         super(PdartsMutator, self).__init__(model)

+        # this loop go through mutables with different keys,
+        # it's mainly to update length of choices.
         for mutable in self.mutables:
             if isinstance(mutable, LayerChoice):

                 switches = self.switches.get(mutable.key, [True for j in range(mutable.length)])
+                choices = self.choices[mutable.key]

-                for index in range(len(switches)-1, -1, -1):
-                    if switches[index] == False:
-                        del(mutable.choices[index])
-                        mutable.length -= 1
-
                 operations_count = np.sum(switches)
                 # +1 and -1 are caused by zero operation in darts network
                 # the zero operation is not in choices list in network, but its weight are in,
                 # so it needs one more weights and switch for zero.
                 self.choices[mutable.key] = nn.Parameter(1.0E-3 * torch.randn(operations_count + 1))

                 self.switches[mutable.key] = switches

-    def drop_paths(self):
-        for key in self.switches:
-            prob = F.softmax(self.choices[key], dim=-1).data.cpu().numpy()
+        # update LayerChoice instances in model,
+        # it's physically remove dropped choices operations.
+        for module in self.model.modules():
+            if isinstance(module, LayerChoice):
+                switches = self.switches.get(module.key)
+                choices = self.choices[module.key]
+                if len(module.choices) > len(choices):
+                    # from last to first, so that it won't effect previous indexes after removed one.
+                    for index in range(len(switches)-1, -1, -1):
+                        if switches[index] == False:
+                            del(module.choices[index])
+                            module.length -= 1
+
+    def sample_final(self):
+        results = super().sample_final()
+        for mutable in self.mutables:
+            if isinstance(mutable, LayerChoice):
+                # As some operations are dropped physically,
+                # so it needs to fill back false to track dropped operations.
+                trained_result = results[mutable.key]
+                trained_index = 0
+                switches = self.switches[mutable.key]
+                result = torch.Tensor(switches).bool()
+                for index in range(len(result)):
+                    if result[index]:
+                        result[index] = trained_result[trained_index]
+                        trained_index += 1
+                results[mutable.key] = result
+        return results

-            switches = self.switches[key]
+    def drop_paths(self):
+        """
+        This method is called when a PDARTS epoch is finished.
+        It prepares switches for next epoch.
+        candidate operations with False switch will be doppped in next epoch.
+        """
+        all_switches = copy.deepcopy(self.switches)
+        for key in all_switches:
+            switches = all_switches[key]
             idxs = []
             for j in range(len(switches)):
                 if switches[j]:
                     idxs.append(j)
-            if self.pdarts_epoch_index == len(self.pdarts_num_to_drop) - 1:
-                # for the last stage, drop all Zero operations
-                drop = self.get_min_k_no_zero(prob, idxs, self.pdarts_num_to_drop[self.pdarts_epoch_index])
-            else:
-                drop = self.get_min_k(prob, self.pdarts_num_to_drop[self.pdarts_epoch_index])
+            sorted_weights = self.choices[key].data.cpu().numpy()[:-1]
+            drop = np.argsort(sorted_weights)[:self.pdarts_num_to_drop[self.pdarts_epoch_index]]
             for idx in drop:
                 switches[idxs[idx]] = False
-        return self.switches
-
-    def get_min_k(self, input_in, k):
-        index = []
-        for _ in range(k):
-            idx = np.argmin(input)
-            index.append(idx)
-        return index
-
-    def get_min_k_no_zero(self, w_in, idxs, k):
-        w = copy.deepcopy(w_in)
-        index = []
-        if 0 in idxs:
-            zf = True
-        else:
-            zf = False
-        if zf:
-            w = w[1:]
-            index.append(0)
-            k = k - 1
-        for _ in range(k):
-            idx = np.argmin(w)
-            w[idx] = 1
-            if zf:
-                idx = idx + 1
-            index.append(idx)
-        return index
+        return all_switches
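Two details are worth flagging in this rewrite. First, the deleted get_min_k helper called np.argmin(input), the Python builtin, instead of its input_in parameter, so the old softmax-and-pick-minimum path was broken; the new drop_paths ranks the raw architecture weights directly with np.argsort, slicing off the trailing zero-op weight with [:-1]. Second, drop_paths now operates on a deepcopy and returns all_switches rather than mutating self.switches in place. A self-contained toy run of the new bookkeeping (all numbers invented for illustration):

    import numpy as np

    # One LayerChoice with 5 candidate ops, all still enabled.
    switches = [True] * 5
    # Architecture weights for the 5 ops plus the trailing "zero" op,
    # which drop_paths slices off with [:-1] before ranking.
    arch_weights = np.array([0.9, 0.1, 0.4, 0.05, 0.7, 0.3])
    pdarts_num_to_drop = 2

    idxs = [j for j, s in enumerate(switches) if s]         # positions still alive
    sorted_weights = arch_weights[:-1]                      # ignore the zero-op weight
    drop = np.argsort(sorted_weights)[:pdarts_num_to_drop]  # weakest ops first
    for idx in drop:
        switches[idxs[idx]] = False

    print(switches)  # [True, False, True, False, True]: ops 1 and 3 are dropped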
src/sdk/pynni/nni/nas/pytorch/pdarts/trainer.py

@@ -14,14 +14,22 @@ logger = logging.getLogger(__name__)
 class PdartsTrainer(BaseTrainer):

-    def __init__(self, model_creator, layers, metrics,
+    """
+    This trainer implements the PDARTS algorithm.
+    PDARTS bases on DARTS algorithm, and provides a network growth approach to find deeper and better network.
+    This class relies on pdarts_num_layers and pdarts_num_to_drop parameters to control how network grows.
+    pdarts_num_layers means how many layers more than first epoch.
+    pdarts_num_to_drop means how many candidate operations should be dropped in each epoch.
+    So that the grew network can in similar size.
+    """

+    def __init__(self, model_creator, init_layers, metrics,
                  num_epochs, dataset_train, dataset_valid,
-                 pdarts_num_layers=[0, 6, 12], pdarts_num_to_drop=[3, 2, 2],
-                 mutator=None, batch_size=64, workers=4, device=None, log_frequency=None, callbacks=None):
+                 pdarts_num_layers=[0, 6, 12], pdarts_num_to_drop=[3, 2, 1],
+                 mutator=None, batch_size=64, workers=4, device=None, log_frequency=None, callbacks=None, unrolled=False):
         super(PdartsTrainer, self).__init__()
         self.model_creator = model_creator
-        self.layers = layers
+        self.init_layers = init_layers
         self.pdarts_num_layers = pdarts_num_layers
         self.pdarts_num_to_drop = pdarts_num_to_drop
         self.pdarts_epoch = len(pdarts_num_to_drop)

@@ -33,16 +41,17 @@ class PdartsTrainer(BaseTrainer):
                 "batch_size": batch_size,
                 "workers": workers,
                 "device": device,
-                "log_frequency": log_frequency
+                "log_frequency": log_frequency,
+                "unrolled": unrolled
             }
         self.callbacks = callbacks if callbacks is not None else []

     def train(self):
-        layers = self.layers
         switches = None
         for epoch in range(self.pdarts_epoch):

-            layers = self.layers + self.pdarts_num_layers[epoch]
+            layers = self.init_layers + self.pdarts_num_layers[epoch]
             model, criterion, optim, lr_scheduler = self.model_creator(layers)
             self.mutator = PdartsMutator(model, epoch, self.pdarts_num_to_drop, switches)

@@ -66,7 +75,7 @@ class PdartsTrainer(BaseTrainer):
                 callback.on_epoch_end(epoch)

     def validate(self):
-        self.model.validate()
+        self.trainer.validate()

     def export(self, file):
         mutator_export = self.mutator.export()
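Since the new docstring is terse, here is how the growth schedule in train() plays out with the new defaults. The init_layers value of 4 is an arbitrary illustration, not from the commit; each PDARTS epoch rebuilds the model via model_creator at the grown depth, then the mutator drops the weakest candidate ops:

    init_layers = 4                  # illustrative; supplied by the caller
    pdarts_num_layers = [0, 6, 12]   # defaults from the diff above
    pdarts_num_to_drop = [3, 2, 1]

    for epoch in range(len(pdarts_num_to_drop)):
        layers = init_layers + pdarts_num_layers[epoch]
        print(f"PDARTS epoch {epoch}: rebuild with {layers} layers, "
              f"then drop {pdarts_num_to_drop[epoch]} candidate op(s) per LayerChoice")
    # PDARTS epoch 0: rebuild with 4 layers, then drop 3 candidate op(s) per LayerChoice
    # PDARTS epoch 1: rebuild with 10 layers, then drop 2 candidate op(s) per LayerChoice
    # PDARTS epoch 2: rebuild with 16 layers, then drop 1 candidate op(s) per LayerChoice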
src/sdk/pynni/requirements.txt

@@ -7,4 +7,4 @@ scipy
 hyperopt==0.1.2

 # metis tuner
-sklearn
+scikit-learn==0.20
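Same fix as in metis_tuner above: on PyPI, "sklearn" is only a deprecated alias package that pulls in scikit-learn, so depending on scikit-learn==0.20 directly removes the indirection and pins a known-good version.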
src/webui/package.json

@@ -66,7 +66,8 @@
   },
   "resolutions": {
     "@types/react": "16.4.17",
-    "js-yaml": "^3.13.1"
+    "js-yaml": "^3.13.1",
+    "serialize-javascript": "^2.1.1"
   },
   "babel": {
     "presets": [
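Yarn's "resolutions" field forces every matching transitive dependency onto the given range, so this one-line addition is what drives the src/webui/yarn.lock change below, where both the ^1.7.0 and ^2.1.1 requests now resolve to serialize-javascript 2.1.2. The likely motivation, though the commit does not state it, is the XSS advisory against serialize-javascript versions before 2.1.1.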
src/webui/src/components/trial-detail/TableList.tsx

@@ -340,7 +340,6 @@ class TableList extends React.Component<TableListProps, TableListState> {
             title: 'Operation',
             dataIndex: 'operation',
             key: 'operation',
-            width: 120,
             render: (text: string, record: TableRecord) => {
                 let trialStatus = record.status;
                 const flag: boolean = (trialStatus === 'RUNNING') ? false : true;

@@ -413,7 +412,6 @@ class TableList extends React.Component<TableListProps, TableListState> {
                     title: realItem,
                     dataIndex: item,
                     key: item,
-                    width: '6%',
                     render: (text: string, record: TableRecord) => {
                         const eachTrial = TRIALS.getTrial(record.id);
                         return (

@@ -514,7 +512,6 @@ class TableList extends React.Component<TableListProps, TableListState> {
 const SequenceIdColumnConfig: ColumnProps<TableRecord> = {
     title: 'Trial No.',
     dataIndex: 'sequenceId',
-    width: 120,
     className: 'tableHead',
     sorter: (a, b) => a.sequenceId - b.sequenceId
 };

@@ -522,7 +519,6 @@ const SequenceIdColumnConfig: ColumnProps<TableRecord> = {
 const IdColumnConfig: ColumnProps<TableRecord> = {
     title: 'ID',
     dataIndex: 'id',
-    width: 60,
     className: 'tableHead leftTitle',
     sorter: (a, b) => a.id.localeCompare(b.id),
     render: (text, record) => (

@@ -533,7 +529,7 @@ const IdColumnConfig: ColumnProps<TableRecord> = {
 const StartTimeColumnConfig: ColumnProps<TableRecord> = {
     title: 'Start Time',
     dataIndex: 'startTime',
-    width: 160,
+    sorter: (a, b) => a.startTime - b.startTime,
     render: (text, record) => (
         <span>{formatTimestamp(record.startTime)}</span>
     )

@@ -542,7 +538,15 @@ const StartTimeColumnConfig: ColumnProps<TableRecord> = {
 const EndTimeColumnConfig: ColumnProps<TableRecord> = {
     title: 'End Time',
     dataIndex: 'endTime',
-    width: 160,
+    sorter: (a, b, sortOrder) => {
+        if (a.endTime === undefined) {
+            return sortOrder === 'ascend' ? 1 : -1;
+        } else if (b.endTime === undefined) {
+            return sortOrder === 'ascend' ? -1 : 1;
+        } else {
+            return a.endTime - b.endTime;
+        }
+    },
     render: (text, record) => (
         <span>{formatTimestamp(record.endTime, '--')}</span>
     )

@@ -551,17 +555,15 @@ const EndTimeColumnConfig: ColumnProps<TableRecord> = {
 const DurationColumnConfig: ColumnProps<TableRecord> = {
     title: 'Duration',
     dataIndex: 'duration',
     width: 100,
     sorter: (a, b) => a.duration - b.duration,
     render: (text, record) => (
-        <div className="durationsty"><div>{convertDuration(record.duration)}</div></div>
+        <span className="durationsty">{convertDuration(record.duration)}</span>
     )
 };

 const StatusColumnConfig: ColumnProps<TableRecord> = {
     title: 'Status',
     dataIndex: 'status',
     width: 150,
     className: 'tableStatus',
     render: (text, record) => (
         <span className={`${record.status} commonStyle`}>{record.status}</span>

@@ -574,7 +576,7 @@ const StatusColumnConfig: ColumnProps<TableRecord> = {
 const IntermediateCountColumnConfig: ColumnProps<TableRecord> = {
     title: 'Intermediate result',
     dataIndex: 'intermediateCount',
-    width: 86,
+    sorter: (a, b) => a.intermediateCount - b.intermediateCount,
     render: (text, record) => (
         <span>{`#${record.intermediateCount}`}</span>
     )

@@ -584,7 +586,6 @@ const AccuracyColumnConfig: ColumnProps<TableRecord> = {
     title: 'Default metric',
     className: 'leftTitle',
     dataIndex: 'accuracy',
-    width: 120,
     sorter: (a, b, sortOrder) => {
         if (a.latestAccuracy === undefined) {
             return sortOrder === 'ascend' ? 1 : -1;
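The pattern across these hunks: fixed width values are dropped from most column configs (Operation, the per-metric columns, Trial No., ID, Default metric), sorter callbacks are added for start time, end time, and intermediate-result count, and the Duration cell swaps a nested <div> pair for a single <span>. The endTime sorter takes the table's third sortOrder argument so that trials without an end time (typically still running) sink to the bottom whether the column is sorted ascending or descending.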
src/webui/src/static/style/table.scss

@@ -57,26 +57,24 @@
     }
     td {
-        padding: 0px;
+        padding: 0 15px;
         line-height: 24px;
     }
     /* + button */
-    .ant-table-row-expand-icon {
-        background: none;
-    }
     .ant-table-row-expand-icon-cell {
-        background: #ccc;
         width: 50px;
+        .ant-table-row-expand-icon {
+            background: none;
+            border: none;
+            width: 100%;
+            height: 100%;
+        }
     }
+    .ant-table-row-expand-icon-cell:hover {
+        background: #ccc;
+    }
     .ant-table-selection-column {
         width: 50px;
     }
 }
 /* let openrow content left*/
src/webui/yarn.lock

@@ -5975,9 +5975,9 @@ send@0.17.1:
     range-parser "~1.2.1"
     statuses "~1.5.0"

-serialize-javascript@^1.7.0:
-  version "1.7.0"
-  resolved "https://registry.yarnpkg.com/serialize-javascript/-/serialize-javascript-1.7.0.tgz#d6e0dfb2a3832a8c94468e6eb1db97e55a192a65"
+serialize-javascript@^1.7.0, serialize-javascript@^2.1.1:
+  version "2.1.2"
+  resolved "https://registry.yarnpkg.com/serialize-javascript/-/serialize-javascript-2.1.2.tgz#ecec53b0e0317bdc95ef76ab7074b7384785fa61"

 serve-index@^1.9.1:
   version "1.9.1"
test/metrics_test/trial.py

@@ -6,6 +6,7 @@ import nni

 if __name__ == '__main__':
     nni.get_next_parameter()
+    time.sleep(1)
     for i in range(10):
         if i % 2 == 0:
             print('report intermediate result without end of line.', end='')
test/naive_test/naive_trial.py

@@ -9,6 +9,7 @@ params = nni.get_next_parameter()
 print('params:', params)
 x = params['x']
+time.sleep(1)

 for i in range(1, 10):
     nni.report_intermediate_result(x ** i)
     time.sleep(0.5)
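Both test trials gain a one-second pause right after nni.get_next_parameter(). The commit does not state the motivation; presumably it gives the NNI manager time to finish registering the trial before the first metrics arrive, making these integration tests less racy.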
test/utils.py

@@ -83,7 +83,7 @@ def is_experiment_done(nnimanager_log_path):
     with open(nnimanager_log_path, 'r') as f:
         log_content = f.read()
     return EXPERIMENT_DONE_SIGNAL in log_content

 def get_experiment_status(status_url):
tools/bash-completion

 # list of commands/arguments
-__nnictl_cmds="create resume update stop trial experiment platform import export webui config log package tensorboard top"
+__nnictl_cmds="create resume view update stop trial experiment platform import export webui config log package tensorboard top"
 __nnictl_create_cmds="--config --port --debug"
 __nnictl_resume_cmds="--port --debug"
+__nnictl_view_cmds="--port"
 __nnictl_update_cmds="searchspace concurrency duration trialnum"
 __nnictl_update_searchspace_cmds="--filename"
 __nnictl_update_concurrency_cmds="--value"

@@ -31,7 +32,7 @@ __nnictl_tensorboard_start_cmds="--trial_id --port"
 __nnictl_top_cmds="--time"

 # list of commands that accept an experiment ID as second argument
-__nnictl_2st_expid_cmds=" resume stop import export "
+__nnictl_2nd_expid_cmds=" resume view stop import export "

 # list of commands that accept an experiment ID as third argument
 __nnictl_3rd_expid_cmds=" update trial experiment webui config log tensorboard "

@@ -73,7 +74,7 @@ _nnictl()
             COMPREPLY=($(compgen -W "${!args}" -- "${COMP_WORDS[2]}"))
             # add experiment IDs to candidates if desired
-            if [[ " resume stop import export " =~ " ${COMP_WORDS[1]} " ]]; then
+            if [[ $__nnictl_2nd_expid_cmds =~ " ${COMP_WORDS[1]} " ]]; then
                 local experiments=$(ls ~/nni/experiments 2>/dev/null)
                 COMPREPLY+=($(compgen -W "$experiments" -- $cur))
             fi

@@ -138,4 +139,8 @@ _nnictl()
     fi
 }

-complete -o nosort -F _nnictl nnictl
+if [[ ${BASH_VERSINFO[0]} -le 4 && ${BASH_VERSINFO[1]} -le 4 ]]; then
+    complete -F _nnictl nnictl
+else
+    complete -o nosort -F _nnictl nnictl
+fi
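The new guard reflects that the nosort completion option only appeared in Bash 4.4 (macOS notably still ships 3.2): on anything the test classes as old (up to and including 4.4) the completion is registered without -o nosort, at the cost of candidates being re-sorted alphabetically. The earlier hunks add the new "view" command to the completion lists and fix the variable name __nnictl_2st_expid_cmds to __nnictl_2nd_expid_cmds, reusing it in the [[ ... =~ ... ]] test instead of a hard-coded list.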