Unverified Commit 32efaa36 authored by SparkSnail, committed by GitHub

Merge pull request #219 from microsoft/master

merge master
parents cd3a912a 97b258b0
@@ -21,7 +21,7 @@ import { AzureStorageClientUtility } from '../azureStorageClientUtils';
 import { NFSConfig } from '../kubernetesConfig';
 import { KubernetesTrialJobDetail } from '../kubernetesData';
 import { KubernetesTrainingService } from '../kubernetesTrainingService';
-import { KubeflowOperatorClient } from './kubeflowApiClient';
+import { KubeflowOperatorClientFactory } from './kubeflowApiClient';
 import { KubeflowClusterConfig, KubeflowClusterConfigAzure, KubeflowClusterConfigFactory, KubeflowClusterConfigNFS,
     KubeflowTrialConfig, KubeflowTrialConfigFactory, KubeflowTrialConfigPytorch, KubeflowTrialConfigTensorflow
 } from './kubeflowConfig';
@@ -136,8 +136,8 @@ class KubeflowTrainingService extends KubernetesTrainingService implements Kuber
                     nfsKubeflowClusterConfig.nfs.path
                 );
             }
-            this.kubernetesCRDClient = KubeflowOperatorClient.generateOperatorClient(this.kubeflowClusterConfig.operator,
-                this.kubeflowClusterConfig.apiVersion);
+            this.kubernetesCRDClient = KubeflowOperatorClientFactory.createClient(
+                this.kubeflowClusterConfig.operator, this.kubeflowClusterConfig.apiVersion);
             break;
         case TrialConfigMetadataKey.TRIAL_CONFIG:
...
{
"defaultSeverity": "error",
"extends": "tslint-microsoft-contrib",
"jsRules": {},
"rules": {
"no-relative-imports": false,
"export-name": false,
"interface-name": [true, "never-prefix"],
"no-increment-decrement": false,
"promise-function-async": false,
"no-console": [true, "log"],
"no-multiline-string": false,
"no-suspicious-comment": false,
"no-backbone-get-set-outside-model": false,
"max-classes-per-file": false
},
"rulesDirectory": [],
"linterOptions": {
"exclude": [
"training_service/test/*",
"rest_server/test/*",
"core/test/*"
]
}
}
\ No newline at end of file
This diff is collapsed.
-sklearn
\ No newline at end of file
+scikit-learn==0.20
\ No newline at end of file
@@ -18,10 +18,11 @@ class DartsTrainer(Trainer):
     def __init__(self, model, loss, metrics,
                  optimizer, num_epochs, dataset_train, dataset_valid,
                  mutator=None, batch_size=64, workers=4, device=None, log_frequency=None,
-                 callbacks=None, arc_learning_rate=3.0E-4, unrolled=True):
+                 callbacks=None, arc_learning_rate=3.0E-4, unrolled=False):
         super().__init__(model, mutator if mutator is not None else DartsMutator(model),
                          loss, metrics, optimizer, num_epochs, dataset_train, dataset_valid,
                          batch_size, workers, device, log_frequency, callbacks)
         self.ctrl_optim = torch.optim.Adam(self.mutator.parameters(), arc_learning_rate, betas=(0.5, 0.999),
                                            weight_decay=1.0E-3)
         self.unrolled = unrolled
...
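Note on the change above: in DartsTrainer, `unrolled` toggles the unrolled second-order bi-level approximation from the DARTS paper, so this commit makes the cheaper first-order mode the default. A minimal usage sketch, assuming hypothetical placeholders (model, criterion, optim, accuracy, train_set, valid_set are not part of this diff):

from nni.nas.pytorch.darts import DartsTrainer

# model, criterion, optim, accuracy, train_set and valid_set are placeholders
trainer = DartsTrainer(model, loss=criterion, metrics=accuracy,
                       optimizer=optim, num_epochs=50,
                       dataset_train=train_set, dataset_valid=valid_set)
# unrolled now defaults to False (first-order DARTS);
# pass unrolled=True explicitly to keep the old second-order behaviour
trainer.train()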
@@ -111,7 +111,7 @@ class Mutator(BaseMutator):
         if "BoolTensor" in mask.type():
             out = [map_fn(*cand) for cand, m in zip(candidates, mask) if m]
         elif "FloatTensor" in mask.type():
-            out = [map_fn(*cand) * m for cand, m in zip(candidates, mask)]
+            out = [map_fn(*cand) * m for cand, m in zip(candidates, mask) if m]
         else:
             raise ValueError("Unrecognized mask")
         return out
...
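The added `if m` guard means zero-weight candidates are no longer evaluated and multiplied by zero; they are skipped outright, matching the BoolTensor branch. A standalone sketch of the difference, with hypothetical candidates and mask:

import torch

candidates = [lambda: torch.tensor(1.0), lambda: torch.tensor(2.0)]
mask = torch.tensor([0.0, 0.7])  # FloatTensor mask; the first op has zero weight

# before: every candidate runs, zero-weight results are kept as zeros
before = [fn() * m for fn, m in zip(candidates, mask)]      # [tensor(0.), tensor(1.4000)]
# after: zero-weight candidates are skipped entirely
after = [fn() * m for fn, m in zip(candidates, mask) if m]  # [tensor(1.4000)]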
@@ -4,13 +4,18 @@
 import copy
 import numpy as np
-import torch.nn.functional as F
+import torch
+from torch import nn
 from nni.nas.pytorch.darts import DartsMutator
 from nni.nas.pytorch.mutables import LayerChoice


 class PdartsMutator(DartsMutator):
+    """
+    It works with PdartsTrainer to calculate op weights,
+    and drops weights over the PDARTS epochs.
+    """

     def __init__(self, model, pdarts_epoch_index, pdarts_num_to_drop, switches={}):
         self.pdarts_epoch_index = pdarts_epoch_index
@@ -22,60 +27,66 @@ class PdartsMutator(DartsMutator):
         super(PdartsMutator, self).__init__(model)

+        # this loop goes through mutables with different keys;
+        # it mainly updates the length of choices.
         for mutable in self.mutables:
             if isinstance(mutable, LayerChoice):
                 switches = self.switches.get(mutable.key, [True for j in range(mutable.length)])
+                choices = self.choices[mutable.key]
+
+                operations_count = np.sum(switches)
+                # +1 and -1 are caused by the zero operation in the DARTS network:
+                # the zero operation is not in the network's choices list, but its weight is,
+                # so one extra weight and switch are needed for it.
+                self.choices[mutable.key] = nn.Parameter(1.0E-3 * torch.randn(operations_count + 1))
+
+                self.switches[mutable.key] = switches

+        # update the LayerChoice instances in the model;
+        # this physically removes the dropped choice operations.
+        for module in self.model.modules():
+            if isinstance(module, LayerChoice):
+                switches = self.switches.get(module.key)
+                choices = self.choices[module.key]
+                if len(module.choices) > len(choices):
+                    # iterate from last to first, so that removing one choice
+                    # does not affect the indexes of the remaining ones.
                     for index in range(len(switches)-1, -1, -1):
                         if switches[index] == False:
-                            del(mutable.choices[index])
-                            mutable.length -= 1
+                            del(module.choices[index])
+                            module.length -= 1
-                self.switches[mutable.key] = switches

+    def sample_final(self):
+        results = super().sample_final()
+        for mutable in self.mutables:
+            if isinstance(mutable, LayerChoice):
+                # as some operations were physically dropped,
+                # False entries are filled back in to track the dropped operations.
+                trained_result = results[mutable.key]
+                trained_index = 0
+                switches = self.switches[mutable.key]
+                result = torch.Tensor(switches).bool()
+                for index in range(len(result)):
+                    if result[index]:
+                        result[index] = trained_result[trained_index]
+                        trained_index += 1
+                results[mutable.key] = result
+        return results

     def drop_paths(self):
-        for key in self.switches:
-            prob = F.softmax(self.choices[key], dim=-1).data.cpu().numpy()
-            switches = self.switches[key]
+        """
+        This method is called when a PDARTS epoch finishes.
+        It prepares the switches for the next epoch:
+        candidate operations whose switch is False will be dropped in the next epoch.
+        """
+        all_switches = copy.deepcopy(self.switches)
+        for key in all_switches:
+            switches = all_switches[key]
             idxs = []
             for j in range(len(switches)):
                 if switches[j]:
                     idxs.append(j)
-            if self.pdarts_epoch_index == len(self.pdarts_num_to_drop) - 1:
-                # for the last stage, drop all Zero operations
-                drop = self.get_min_k_no_zero(prob, idxs, self.pdarts_num_to_drop[self.pdarts_epoch_index])
-            else:
-                drop = self.get_min_k(prob, self.pdarts_num_to_drop[self.pdarts_epoch_index])
+            sorted_weights = self.choices[key].data.cpu().numpy()[:-1]
+            drop = np.argsort(sorted_weights)[:self.pdarts_num_to_drop[self.pdarts_epoch_index]]
             for idx in drop:
                 switches[idxs[idx]] = False
-        return self.switches
+        return all_switches
-
-    def get_min_k(self, input_in, k):
-        index = []
-        for _ in range(k):
-            idx = np.argmin(input)
-            index.append(idx)
-        return index
-
-    def get_min_k_no_zero(self, w_in, idxs, k):
-        w = copy.deepcopy(w_in)
-        index = []
-        if 0 in idxs:
-            zf = True
-        else:
-            zf = False
-        if zf:
-            w = w[1:]
-            index.append(0)
-            k = k - 1
-        for _ in range(k):
-            idx = np.argmin(w)
-            w[idx] = 1
-            if zf:
-                idx = idx + 1
-            index.append(idx)
-        return index
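A rough standalone illustration of the new drop_paths selection (the weights below are made up, not from the repository): the extra zero-operation weight at the end is sliced off, and the lowest-weighted active candidates are switched off for the next PDARTS epoch.

import numpy as np

# hypothetical architecture weights for 5 active candidate ops,
# plus the extra zero-operation weight at the end (sliced off via [:-1])
weights = np.array([0.30, 0.05, 0.22, 0.10, 0.41, 0.77])
switches = [True] * 5
num_to_drop = 2

idxs = [j for j, s in enumerate(switches) if s]      # indexes of active candidates
candidate_weights = weights[:-1]                     # discard the zero-op weight
drop = np.argsort(candidate_weights)[:num_to_drop]   # the two lowest-weighted ops
for idx in drop:
    switches[idxs[idx]] = False

print(switches)  # [True, False, True, False, True]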
@@ -14,14 +14,22 @@ logger = logging.getLogger(__name__)

 class PdartsTrainer(BaseTrainer):
+    """
+    This trainer implements the PDARTS algorithm.
+    PDARTS is based on the DARTS algorithm and provides a network-growth approach to find a deeper and better network.
+    This class relies on the pdarts_num_layers and pdarts_num_to_drop parameters to control how the network grows.
+    pdarts_num_layers specifies how many layers to add to init_layers in each PDARTS epoch.
+    pdarts_num_to_drop specifies how many candidate operations to drop in each epoch,
+    so that the grown network stays a similar size.
+    """

-    def __init__(self, model_creator, layers, metrics,
+    def __init__(self, model_creator, init_layers, metrics,
                  num_epochs, dataset_train, dataset_valid,
-                 pdarts_num_layers=[0, 6, 12], pdarts_num_to_drop=[3, 2, 2],
-                 mutator=None, batch_size=64, workers=4, device=None, log_frequency=None, callbacks=None):
+                 pdarts_num_layers=[0, 6, 12], pdarts_num_to_drop=[3, 2, 1],
+                 mutator=None, batch_size=64, workers=4, device=None, log_frequency=None, callbacks=None, unrolled=False):
         super(PdartsTrainer, self).__init__()
         self.model_creator = model_creator
-        self.layers = layers
+        self.init_layers = init_layers
         self.pdarts_num_layers = pdarts_num_layers
         self.pdarts_num_to_drop = pdarts_num_to_drop
         self.pdarts_epoch = len(pdarts_num_to_drop)
@@ -33,16 +41,17 @@ class PdartsTrainer(BaseTrainer):
             "batch_size": batch_size,
             "workers": workers,
             "device": device,
-            "log_frequency": log_frequency
+            "log_frequency": log_frequency,
+            "unrolled": unrolled
         }
         self.callbacks = callbacks if callbacks is not None else []

     def train(self):
-        layers = self.layers
         switches = None
         for epoch in range(self.pdarts_epoch):
-            layers = self.layers+self.pdarts_num_layers[epoch]
+            layers = self.init_layers+self.pdarts_num_layers[epoch]
             model, criterion, optim, lr_scheduler = self.model_creator(layers)
             self.mutator = PdartsMutator(model, epoch, self.pdarts_num_to_drop, switches)
@@ -66,7 +75,7 @@ class PdartsTrainer(BaseTrainer):
                 callback.on_epoch_end(epoch)

     def validate(self):
-        self.model.validate()
+        self.trainer.validate()

     def export(self, file):
         mutator_export = self.mutator.export()
...
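A hedged usage sketch of the renamed constructor, assuming the package layout shown in this repository; build_model_fn, compute_metrics, train_set and valid_set are hypothetical placeholders, not part of this diff:

from nni.nas.pytorch.pdarts import PdartsTrainer

# build_model_fn(layers) is a placeholder returning (model, criterion, optim, lr_scheduler)
trainer = PdartsTrainer(build_model_fn,
                        init_layers=5,                 # renamed from `layers`
                        metrics=compute_metrics,
                        num_epochs=50,
                        dataset_train=train_set,
                        dataset_valid=valid_set,
                        pdarts_num_layers=[0, 6, 12],  # layers added to init_layers per PDARTS epoch
                        pdarts_num_to_drop=[3, 2, 1],  # candidate ops dropped per PDARTS epoch
                        unrolled=False)                # forwarded to the inner DartsTrainer
trainer.train()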
@@ -7,4 +7,4 @@ scipy
 hyperopt==0.1.2
 # metis tuner
-sklearn
+scikit-learn==0.20
@@ -66,7 +66,8 @@
     },
     "resolutions": {
         "@types/react": "16.4.17",
-        "js-yaml": "^3.13.1"
+        "js-yaml": "^3.13.1",
+        "serialize-javascript": "^2.1.1"
     },
     "babel": {
         "presets": [
...
@@ -340,7 +340,6 @@ class TableList extends React.Component<TableListProps, TableListState> {
         title: 'Operation',
         dataIndex: 'operation',
         key: 'operation',
-        width: 120,
         render: (text: string, record: TableRecord) => {
             let trialStatus = record.status;
             const flag: boolean = (trialStatus === 'RUNNING') ? false : true;
@@ -413,7 +412,6 @@ class TableList extends React.Component<TableListProps, TableListState> {
                     title: realItem,
                     dataIndex: item,
                     key: item,
-                    width: '6%',
                     render: (text: string, record: TableRecord) => {
                         const eachTrial = TRIALS.getTrial(record.id);
                         return (
@@ -514,7 +512,6 @@ class TableList extends React.Component<TableListProps, TableListState> {
 const SequenceIdColumnConfig: ColumnProps<TableRecord> = {
     title: 'Trial No.',
     dataIndex: 'sequenceId',
-    width: 120,
     className: 'tableHead',
     sorter: (a, b) => a.sequenceId - b.sequenceId
 };
@@ -522,7 +519,6 @@ const SequenceIdColumnConfig: ColumnProps<TableRecord> = {
 const IdColumnConfig: ColumnProps<TableRecord> = {
     title: 'ID',
     dataIndex: 'id',
-    width: 60,
     className: 'tableHead leftTitle',
     sorter: (a, b) => a.id.localeCompare(b.id),
     render: (text, record) => (
@@ -533,7 +529,7 @@ const IdColumnConfig: ColumnProps<TableRecord> = {
 const StartTimeColumnConfig: ColumnProps<TableRecord> = {
     title: 'Start Time',
     dataIndex: 'startTime',
-    width: 160,
+    sorter: (a, b) => a.startTime - b.startTime,
     render: (text, record) => (
         <span>{formatTimestamp(record.startTime)}</span>
     )
@@ -542,7 +538,15 @@ const StartTimeColumnConfig: ColumnProps<TableRecord> = {
 const EndTimeColumnConfig: ColumnProps<TableRecord> = {
     title: 'End Time',
     dataIndex: 'endTime',
-    width: 160,
+    sorter: (a, b, sortOrder) => {
+        if (a.endTime === undefined) {
+            return sortOrder === 'ascend' ? 1 : -1;
+        } else if (b.endTime === undefined) {
+            return sortOrder === 'ascend' ? -1 : 1;
+        } else {
+            return a.endTime - b.endTime;
+        }
+    },
     render: (text, record) => (
         <span>{formatTimestamp(record.endTime, '--')}</span>
     )
@@ -551,17 +555,15 @@ const EndTimeColumnConfig: ColumnProps<TableRecord> = {
 const DurationColumnConfig: ColumnProps<TableRecord> = {
     title: 'Duration',
     dataIndex: 'duration',
-    width: 100,
     sorter: (a, b) => a.duration - b.duration,
     render: (text, record) => (
-        <div className="durationsty"><div>{convertDuration(record.duration)}</div></div>
+        <span className="durationsty">{convertDuration(record.duration)}</span>
     )
 };

 const StatusColumnConfig: ColumnProps<TableRecord> = {
     title: 'Status',
     dataIndex: 'status',
-    width: 150,
     className: 'tableStatus',
     render: (text, record) => (
         <span className={`${record.status} commonStyle`}>{record.status}</span>
@@ -574,7 +576,7 @@ const StatusColumnConfig: ColumnProps<TableRecord> = {
 const IntermediateCountColumnConfig: ColumnProps<TableRecord> = {
     title: 'Intermediate result',
     dataIndex: 'intermediateCount',
-    width: 86,
+    sorter: (a, b) => a.intermediateCount - b.intermediateCount,
     render: (text, record) => (
         <span>{`#${record.intermediateCount}`}</span>
     )
@@ -584,7 +586,6 @@ const AccuracyColumnConfig: ColumnProps<TableRecord> = {
     title: 'Default metric',
     className: 'leftTitle',
     dataIndex: 'accuracy',
-    width: 120,
     sorter: (a, b, sortOrder) => {
         if (a.latestAccuracy === undefined) {
             return sortOrder === 'ascend' ? 1 : -1;
...
...@@ -57,26 +57,24 @@ ...@@ -57,26 +57,24 @@
} }
td{ td{
padding: 0px; padding: 0 15px;
line-height: 24px; line-height: 24px;
} }
/* + button */
.ant-table-row-expand-icon{
background: none;
}
.ant-table-row-expand-icon-cell{ .ant-table-row-expand-icon-cell{
background: #ccc; background: #ccc;
width: 50px;
.ant-table-row-expand-icon{ .ant-table-row-expand-icon{
background: none;
border: none; border: none;
width: 100%;
height: 100%;
} }
} }
.ant-table-row-expand-icon-cell:hover{ .ant-table-row-expand-icon-cell:hover{
background: #ccc; background: #ccc;
} }
.ant-table-selection-column{
width: 50px;
}
} }
/* let openrow content left*/ /* let openrow content left*/
......
@@ -5975,9 +5975,9 @@ send@0.17.1:
     range-parser "~1.2.1"
     statuses "~1.5.0"

-serialize-javascript@^1.7.0:
-  version "1.7.0"
-  resolved "https://registry.yarnpkg.com/serialize-javascript/-/serialize-javascript-1.7.0.tgz#d6e0dfb2a3832a8c94468e6eb1db97e55a192a65"
+serialize-javascript@^1.7.0, serialize-javascript@^2.1.1:
+  version "2.1.2"
+  resolved "https://registry.yarnpkg.com/serialize-javascript/-/serialize-javascript-2.1.2.tgz#ecec53b0e0317bdc95ef76ab7074b7384785fa61"

 serve-index@^1.9.1:
   version "1.9.1"
...
@@ -6,6 +6,7 @@ import nni

 if __name__ == '__main__':
     nni.get_next_parameter()
+    time.sleep(1)
     for i in range(10):
         if i % 2 == 0:
             print('report intermediate result without end of line.', end='')
...
@@ -9,6 +9,7 @@ params = nni.get_next_parameter()
 print('params:', params)
 x = params['x']
+time.sleep(1)
 for i in range(1, 10):
     nni.report_intermediate_result(x ** i)
     time.sleep(0.5)
...
 # list of commands/arguments
-__nnictl_cmds="create resume update stop trial experiment platform import export webui config log package tensorboard top"
+__nnictl_cmds="create resume view update stop trial experiment platform import export webui config log package tensorboard top"
 __nnictl_create_cmds="--config --port --debug"
 __nnictl_resume_cmds="--port --debug"
+__nnictl_view_cmds="--port"
 __nnictl_update_cmds="searchspace concurrency duration trialnum"
 __nnictl_update_searchspace_cmds="--filename"
 __nnictl_update_concurrency_cmds="--value"
@@ -31,7 +32,7 @@ __nnictl_tensorboard_start_cmds="--trial_id --port"
 __nnictl_top_cmds="--time"

 # list of commands that accept an experiment ID as second argument
-__nnictl_2st_expid_cmds=" resume stop import export "
+__nnictl_2nd_expid_cmds=" resume view stop import export "

 # list of commands that accept an experiment ID as third argument
 __nnictl_3rd_expid_cmds=" update trial experiment webui config log tensorboard "
@@ -73,7 +74,7 @@ _nnictl()
             COMPREPLY=($(compgen -W "${!args}" -- "${COMP_WORDS[2]}"))
             # add experiment IDs to candidates if desired
-            if [[ " resume stop import export " =~ " ${COMP_WORDS[1]} " ]]; then
+            if [[ $__nnictl_2nd_expid_cmds =~ " ${COMP_WORDS[1]} " ]]; then
                 local experiments=$(ls ~/nni/experiments 2>/dev/null)
                 COMPREPLY+=($(compgen -W "$experiments" -- $cur))
             fi
@@ -138,4 +139,8 @@ _nnictl()
     fi
 }

-complete -o nosort -F _nnictl nnictl
+if [[ ${BASH_VERSINFO[0]} -le 4 && ${BASH_VERSINFO[1]} -le 4 ]]; then
+    complete -F _nnictl nnictl
+else
+    complete -o nosort -F _nnictl nnictl
+fi