Unverified commit f002fcd5, authored by chicm-ms, committed by GitHub

Compressor updates (#2136)

parent c3cd9fe7
@@ -142,10 +142,10 @@ def create_model(model_name='naive'):
     else:
         return VGG(19)

-def create_pruner(model, pruner_name):
+def create_pruner(model, pruner_name, optimizer=None):
     pruner_class = prune_config[pruner_name]['pruner_class']
     config_list = prune_config[pruner_name]['config_list']
-    return pruner_class(model, config_list)
+    return pruner_class(model, config_list, optimizer)

 def train(model, device, train_loader, optimizer):
     model.train()
@@ -179,6 +179,7 @@ def test(model, device, test_loader):
 def main(args):
     device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
+    os.makedirs(args.checkpoints_dir, exist_ok=True)
     model_name = prune_config[args.pruner_name]['model_name']
     dataset_name = prune_config[args.pruner_name]['dataset_name']
@@ -203,8 +204,6 @@ def main(args):
     print('start model pruning...')
-    if not os.path.exists(args.checkpoints_dir):
-        os.makedirs(args.checkpoints_dir)
     model_path = os.path.join(args.checkpoints_dir, 'pruned_{}_{}_{}.pth'.format(model_name, dataset_name, args.pruner_name))
     mask_path = os.path.join(args.checkpoints_dir, 'mask_{}_{}_{}.pth'.format(model_name, dataset_name, args.pruner_name))
...
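With this change, the example's create_pruner only forwards an optimizer when the caller supplies one, so one-shot pruners can be constructed without it. A minimal sketch of how the updated helper might be called; the 'level' and 'agp' config keys are illustrative assumptions, not taken from this diff:

# Hedged sketch of calling the updated create_pruner from the example above.
# 'level' and 'agp' are assumed keys in prune_config, not confirmed by this diff.
import torch.optim as optim

model = create_model('naive')

# One-shot pruning: no optimizer required any more.
pruner = create_pruner(model, 'level')

# Iterative pruning: pass the optimizer so the pruner can hook its step().
optimizer = optim.SGD(model.parameters(), lr=0.01)
agp_pruner = create_pruner(model, 'agp', optimizer)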
@@ -15,6 +15,13 @@ import torch.nn.functional as F
 import torch.optim as optim
 from torchvision import datasets, transforms

+# Temporarily patch this example until the MNIST dataset download issue is resolved:
+# https://github.com/pytorch/vision/issues/1938
+import urllib.request
+opener = urllib.request.build_opener()
+opener.addheaders = [('User-agent', 'Mozilla/5.0')]
+urllib.request.install_opener(opener)
+
 logger = logging.getLogger('mnist_AutoML')
...
@@ -16,7 +16,7 @@ class ActivationRankFilterPruner(Pruner):
     to achieve a preset level of network sparsity.
     """

-    def __init__(self, model, config_list, optimizer, activation='relu', statistics_batch_num=1):
+    def __init__(self, model, config_list, optimizer=None, activation='relu', statistics_batch_num=1):
         """
         Parameters
         ----------
@@ -25,6 +25,8 @@ class ActivationRankFilterPruner(Pruner):
         config_list : list
             support key for each list item:
             - sparsity: percentage of convolutional filters to be pruned.
+        optimizer: torch.optim.Optimizer
+            Optimizer used to train model
         activation : str
             Activation function
         statistics_batch_num : int
@@ -105,7 +107,7 @@ class ActivationAPoZRankFilterPruner(ActivationRankFilterPruner):
     https://arxiv.org/abs/1607.03250
     """

-    def __init__(self, model, config_list, optimizer, activation='relu', statistics_batch_num=1):
+    def __init__(self, model, config_list, optimizer=None, activation='relu', statistics_batch_num=1):
         """
         Parameters
         ----------
@@ -114,6 +116,8 @@ class ActivationAPoZRankFilterPruner(ActivationRankFilterPruner):
         config_list : list
             support key for each list item:
             - sparsity: percentage of convolutional filters to be pruned.
+        optimizer: torch.optim.Optimizer
+            Optimizer used to train model
         activation : str
             Activation function
         statistics_batch_num : int
@@ -177,7 +181,7 @@ class ActivationMeanRankFilterPruner(ActivationRankFilterPruner):
     https://arxiv.org/abs/1611.06440
     """

-    def __init__(self, model, config_list, optimizer, activation='relu', statistics_batch_num=1):
+    def __init__(self, model, config_list, optimizer=None, activation='relu', statistics_batch_num=1):
         """
         Parameters
         ----------
@@ -186,6 +190,8 @@ class ActivationMeanRankFilterPruner(ActivationRankFilterPruner):
         config_list : list
             support key for each list item:
             - sparsity: percentage of convolutional filters to be pruned.
+        optimizer: torch.optim.Optimizer
+            Optimizer used to train model
         activation : str
             Activation function
         statistics_batch_num : int
...
@@ -27,7 +27,7 @@ class Compressor:
     Abstract base PyTorch compressor
     """

-    def __init__(self, model, config_list, optimizer):
+    def __init__(self, model, config_list, optimizer=None):
         """
         Record necessary info in class members
@@ -235,7 +235,8 @@ class Compressor:
                     task()
                 return output
             return new_step
-        self.optimizer.step = types.MethodType(patch_step(self.optimizer.step), self.optimizer)
+        if self.optimizer is not None:
+            self.optimizer.step = types.MethodType(patch_step(self.optimizer.step), self.optimizer)

 class PrunerModuleWrapper(torch.nn.Module):
     def __init__(self, module, module_name, module_type, config, pruner):
@@ -290,9 +291,10 @@ class Pruner(Compressor):
     """

-    def __init__(self, model, config_list, optimizer):
+    def __init__(self, model, config_list, optimizer=None):
         super().__init__(model, config_list, optimizer)
-        self.patch_optimizer(self.update_mask)
+        if optimizer is not None:
+            self.patch_optimizer(self.update_mask)

     def compress(self):
         self.update_mask()
...
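Taken together, these compressor changes give a Pruner two modes: built without an optimizer, it computes its masks once when compress() is called; built with one, it additionally patches optimizer.step() so update_mask() runs during training. A rough sketch of both modes, assuming the NNI 1.x import path nni.compression.torch and a toy model (neither is part of this diff):

# Sketch of the two usage modes enabled by the optional optimizer argument.
# Assumptions: NNI 1.x import path and an illustrative model/config.
import torch
import torch.nn as nn
from nni.compression.torch import LevelPruner

config_list = [{'sparsity': 0.5, 'op_types': ['Conv2d']}]

# One-shot: no optimizer; Pruner.compress() calls update_mask() once.
model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU(), nn.Conv2d(8, 8, 3))
pruner = LevelPruner(model, config_list)
pruner.compress()

# Training-aware: with an optimizer, Pruner.__init__ patches optimizer.step()
# so update_mask() also runs after every training step.
model2 = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU(), nn.Conv2d(8, 8, 3))
optimizer = torch.optim.SGD(model2.parameters(), lr=0.01)
pruner2 = LevelPruner(model2, config_list, optimizer)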
@@ -16,7 +16,7 @@ class LevelPruner(Pruner):
     Prune to an exact pruning level specification
     """

-    def __init__(self, model, config_list, optimizer):
+    def __init__(self, model, config_list, optimizer=None):
         """
         Parameters
         ----------
@@ -24,6 +24,8 @@ class LevelPruner(Pruner):
             Model to be pruned
         config_list : list
             List on pruning configs
+        optimizer: torch.optim.Optimizer
+            Optimizer used to train model
         """
         super().__init__(model, config_list, optimizer)
@@ -78,9 +80,13 @@ class AGP_Pruner(Pruner):
             Model to be pruned
         config_list : list
             List on pruning configs
+        optimizer: torch.optim.Optimizer
+            Optimizer used to train model
         """
         super().__init__(model, config_list, optimizer)
+        assert isinstance(optimizer, torch.optim.Optimizer), "AGP pruner is an iterative pruner, please pass optimizer of the model to it"
         self.now_epoch = 0
         self.set_wrappers_attribute("if_calculated", False)
@@ -176,13 +182,17 @@ class SlimPruner(Pruner):
     https://arxiv.org/pdf/1708.06519.pdf
     """

-    def __init__(self, model, config_list, optimizer):
+    def __init__(self, model, config_list, optimizer=None):
         """
         Parameters
         ----------
+        model : torch.nn.module
+            Model to be pruned
         config_list : list
             support key for each list item:
             - sparsity: percentage of convolutional filters to be pruned.
+        optimizer: torch.optim.Optimizer
+            Optimizer used to train model
         """
         super().__init__(model, config_list, optimizer)
@@ -244,7 +254,7 @@ class LotteryTicketPruner(Pruner):
     5. Repeat step 2, 3, and 4.
     """

-    def __init__(self, model, config_list, optimizer, lr_scheduler=None, reset_weights=True):
+    def __init__(self, model, config_list, optimizer=None, lr_scheduler=None, reset_weights=True):
         """
         Parameters
         ----------
...
@@ -15,7 +15,7 @@ class WeightRankFilterPruner(Pruner):
     importance criterion in convolution layers to achieve a preset level of network sparsity.
     """

-    def __init__(self, model, config_list, optimizer):
+    def __init__(self, model, config_list, optimizer=None):
         """
         Parameters
         ----------
@@ -24,6 +24,8 @@ class WeightRankFilterPruner(Pruner):
         config_list : list
             support key for each list item:
             - sparsity: percentage of convolutional filters to be pruned.
+        optimizer: torch.optim.Optimizer
+            Optimizer used to train model
         """
         super().__init__(model, config_list, optimizer)
@@ -83,7 +85,7 @@ class L1FilterPruner(WeightRankFilterPruner):
     https://arxiv.org/abs/1608.08710
     """

-    def __init__(self, model, config_list, optimizer):
+    def __init__(self, model, config_list, optimizer=None):
         """
         Parameters
         ----------
@@ -92,6 +94,8 @@ class L1FilterPruner(WeightRankFilterPruner):
         config_list : list
             support key for each list item:
             - sparsity: percentage of convolutional filters to be pruned.
+        optimizer: torch.optim.Optimizer
+            Optimizer used to train model
         """
         super().__init__(model, config_list, optimizer)
@@ -131,7 +135,7 @@ class L2FilterPruner(WeightRankFilterPruner):
     smallest L2 norm of the weights.
     """

-    def __init__(self, model, config_list, optimizer):
+    def __init__(self, model, config_list, optimizer=None):
         """
         Parameters
         ----------
@@ -140,6 +144,8 @@ class L2FilterPruner(WeightRankFilterPruner):
         config_list : list
             support key for each list item:
             - sparsity: percentage of convolutional filters to be pruned.
+        optimizer: torch.optim.Optimizer
+            Optimizer used to train model
         """
         super().__init__(model, config_list, optimizer)
@@ -187,8 +193,11 @@ class FPGMPruner(WeightRankFilterPruner):
         config_list: list
             support key for each list item:
             - sparsity: percentage of convolutional filters to be pruned.
+        optimizer: torch.optim.Optimizer
+            Optimizer used to train model
         """
         super().__init__(model, config_list, optimizer)
+        assert isinstance(optimizer, torch.optim.Optimizer), "FPGM pruner is an iterative pruner, please pass optimizer of the model to it"

     def get_mask(self, base_mask, weight, num_prune):
         """
...
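Note the split this introduces: one-shot pruners (Level, L1/L2 filter, Slim, activation-based) now accept optimizer=None, while iterative pruners (AGP above, FPGM here) assert that a real torch.optim.Optimizer is passed, since their masks must be updated as training progresses. A hedged illustration of the new contract, with a toy model and the assumed NNI 1.x import path:

# Illustrative check of the optimizer requirement for iterative pruners.
# Assumptions: NNI 1.x import path, toy model; not part of this diff.
import torch
import torch.nn as nn
from nni.compression.torch import FPGMPruner, L1FilterPruner

def make_model():
    return nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU())

config_list = [{'sparsity': 0.5, 'op_types': ['Conv2d']}]

# Fine: one-shot filter pruners no longer require an optimizer.
L1FilterPruner(make_model(), config_list)

# AssertionError: FPGM is iterative and now insists on a real optimizer.
try:
    FPGMPruner(make_model(), config_list)
except AssertionError as err:
    print(err)

# Correct usage: pass the optimizer that trains the model.
model = make_model()
FPGMPruner(model, config_list, torch.optim.SGD(model.parameters(), lr=0.01))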
@@ -4,6 +4,7 @@
 import logging
 import datetime
 from nni.assessor import Assessor, AssessResult
+from nni.utils import extract_scalar_history
 from .model_factory import CurveModel

 logger = logging.getLogger('curvefitting_Assessor')
@@ -91,10 +92,11 @@ class CurvefittingAssessor(Assessor):
         Exception
             unrecognize exception in curvefitting_assessor
         """
-        self.trial_history = trial_history
+        scalar_trial_history = extract_scalar_history(trial_history)
+        self.trial_history = scalar_trial_history
         if not self.set_best_performance:
             return AssessResult.Good
-        curr_step = len(trial_history)
+        curr_step = len(scalar_trial_history)
         if curr_step < self.start_step:
             return AssessResult.Good
@@ -106,7 +108,7 @@ class CurvefittingAssessor(Assessor):
         start_time = datetime.datetime.now()
         # Predict the final result
         curvemodel = CurveModel(self.target_pos)
-        predict_y = curvemodel.predict(trial_history)
+        predict_y = curvemodel.predict(scalar_trial_history)
         logger.info('Prediction done. Trial job id = %s. Predict value = %s', trial_job_id, predict_y)
         if predict_y is None:
             logger.info('wait for more information to predict precisely')
...
@@ -3,6 +3,7 @@
 import logging
 from nni.assessor import Assessor, AssessResult
+from nni.utils import extract_scalar_history

 logger = logging.getLogger('medianstop_Assessor')
@@ -91,20 +92,12 @@ class MedianstopAssessor(Assessor):
         if curr_step < self._start_step:
             return AssessResult.Good

-        try:
-            num_trial_history = [float(ele) for ele in trial_history]
-        except (TypeError, ValueError) as error:
-            logger.warning('incorrect data type or value:')
-            logger.exception(error)
-        except Exception as error:
-            logger.warning('unrecognized exception in medianstop_assessor:')
-            logger.exception(error)
-        self._update_data(trial_job_id, num_trial_history)
+        scalar_trial_history = extract_scalar_history(trial_history)
+        self._update_data(trial_job_id, scalar_trial_history)
         if self._high_better:
-            best_history = max(trial_history)
+            best_history = max(scalar_trial_history)
         else:
-            best_history = min(trial_history)
+            best_history = min(scalar_trial_history)

         avg_array = []
         for id_ in self._completed_avg_history:
...
@@ -234,4 +234,5 @@ class MsgDispatcher(MsgDispatcherBase):
         if multi_thread_enabled():
             self._handle_final_metric_data(data)
         else:
+            data['value'] = to_json(data['value'])
             self.enqueue_command(CommandType.ReportMetricData, data)
@@ -62,6 +62,13 @@ def extract_scalar_reward(value, scalar_key='default'):
     """
     Extract scalar reward from trial result.

+    Parameters
+    ----------
+    value : int, float, dict
+        the reported final metric data
+    scalar_key : str
+        the key name that indicates the numeric number
+
     Raises
     ------
     RuntimeError
@@ -78,6 +85,26 @@ def extract_scalar_reward(value, scalar_key='default'):
     return reward

+
+def extract_scalar_history(trial_history, scalar_key='default'):
+    """
+    Extract scalar value from a list of intermediate results.
+
+    Parameters
+    ----------
+    trial_history : list
+        accumulated intermediate results of a trial
+    scalar_key : str
+        the key name that indicates the numeric number
+
+    Raises
+    ------
+    RuntimeError
+        Incorrect final result: the final result should be float/int,
+        or a dict which has a key named "default" whose value is float/int.
+    """
+    return [extract_scalar_reward(ele, scalar_key) for ele in trial_history]
+

 def convert_dict2tuple(value):
     """
     convert dict type to tuple to solve unhashable problem.
@@ -90,7 +117,9 @@ def convert_dict2tuple(value):

 def init_dispatcher_logger():
-    """ Initialize dispatcher logging configuration"""
+    """
+    Initialize dispatcher logging configuration
+    """
     logger_file_path = 'dispatcher.log'
     if dispatcher_env_vars.NNI_LOG_DIRECTORY is not None:
         logger_file_path = os.path.join(dispatcher_env_vars.NNI_LOG_DIRECTORY, logger_file_path)
...
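The new extract_scalar_history helper is what both assessor changes above lean on: it maps each intermediate result, whether a bare number or a dict carrying a 'default' (or custom) key, through extract_scalar_reward into a plain list of numbers. A small usage sketch with made-up values:

# Made-up intermediate results: bare numbers and dicts with a 'default' key,
# both of which extract_scalar_reward already accepts.
from nni.utils import extract_scalar_history

trial_history = [0.61, {'default': 0.68, 'loss': 1.2}, {'default': 0.74, 'loss': 0.9}]

print(extract_scalar_history(trial_history))                         # [0.61, 0.68, 0.74]
print(extract_scalar_history(trial_history[1:], scalar_key='loss'))  # [1.2, 0.9]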
@@ -131,9 +131,8 @@ class CompressorTestCase(TestCase):
         """
         model = TorchModel()
-        optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
         config_list = [{'sparsity': 0.6, 'op_types': ['Conv2d']}, {'sparsity': 0.2, 'op_types': ['Conv2d']}]
-        pruner = torch_compressor.FPGMPruner(model, config_list, optimizer)
+        pruner = torch_compressor.FPGMPruner(model, config_list, torch.optim.SGD(model.parameters(), lr=0.01))

         model.conv2.module.weight.data = torch.tensor(w).float()
         masks = pruner.calc_mask(model.conv2)
@@ -176,10 +175,9 @@ class CompressorTestCase(TestCase):
         w = np.array([np.zeros((3, 3, 3)), np.ones((3, 3, 3)), np.ones((3, 3, 3)) * 2,
                       np.ones((3, 3, 3)) * 3, np.ones((3, 3, 3)) * 4])
         model = TorchModel()
-        optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
         config_list = [{'sparsity': 0.2, 'op_types': ['Conv2d'], 'op_names': ['conv1']},
                        {'sparsity': 0.6, 'op_types': ['Conv2d'], 'op_names': ['conv2']}]
-        pruner = torch_compressor.L1FilterPruner(model, config_list, optimizer)
+        pruner = torch_compressor.L1FilterPruner(model, config_list)

         model.conv1.module.weight.data = torch.tensor(w).float()
         model.conv2.module.weight.data = torch.tensor(w).float()
@@ -204,11 +202,10 @@ class CompressorTestCase(TestCase):
         """
         w = np.array([0, 1, 2, 3, 4])
         model = TorchModel()
-        optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
         config_list = [{'sparsity': 0.2, 'op_types': ['BatchNorm2d']}]
         model.bn1.weight.data = torch.tensor(w).float()
         model.bn2.weight.data = torch.tensor(-w).float()
-        pruner = torch_compressor.SlimPruner(model, config_list, optimizer)
+        pruner = torch_compressor.SlimPruner(model, config_list)

         mask1 = pruner.calc_mask(model.bn1)
         mask2 = pruner.calc_mask(model.bn2)
@@ -218,11 +215,10 @@ class CompressorTestCase(TestCase):
         assert all(mask2['bias_mask'].numpy() == np.array([0., 1., 1., 1., 1.]))

         model = TorchModel()
-        optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
         config_list = [{'sparsity': 0.6, 'op_types': ['BatchNorm2d']}]
         model.bn1.weight.data = torch.tensor(w).float()
         model.bn2.weight.data = torch.tensor(w).float()
-        pruner = torch_compressor.SlimPruner(model, config_list, optimizer)
+        pruner = torch_compressor.SlimPruner(model, config_list)

         mask1 = pruner.calc_mask(model.bn1)
         mask2 = pruner.calc_mask(model.bn2)
...
@@ -247,7 +247,7 @@ class TableList extends React.Component<TableListProps, TableListState> {
                 }
                 // intermediateArr just store default val
                 Object.keys(res.data).map(item => {
-                    if(res.data[item].type === 'PERIODICAL'){
+                    if (res.data[item].type === 'PERIODICAL') {
                         const temp = parseMetrics(res.data[item].data);
                         if (typeof temp === 'object') {
                             intermediateArr.push(temp[intermediateKey]);
@@ -278,11 +278,13 @@ class TableList extends React.Component<TableListProps, TableListState> {
         // just watch default key-val
         if (isShowDefault === true) {
             Object.keys(intermediateData).map(item => {
-                const temp = parseMetrics(intermediateData[item].data);
-                if (typeof temp === 'object') {
-                    intermediateArr.push(temp[value]);
-                } else {
-                    intermediateArr.push(temp);
+                if (intermediateData[item].type === 'PERIODICAL') {
+                    const temp = parseMetrics(intermediateData[item].data);
+                    if (typeof temp === 'object') {
+                        intermediateArr.push(temp[value]);
+                    } else {
+                        intermediateArr.push(temp);
+                    }
                 }
             });
         } else {
...
@@ -36,11 +36,11 @@ def update_training_service_config(args):
         config[args.ts]['paiConfig']['token'] = args.pai_token
     if args.nni_docker_image is not None:
         config[args.ts]['trial']['image'] = args.nni_docker_image
-    if args.nniManagerNFSMountPath is not None:
+    if args.nni_manager_nfs_mount_path is not None:
         config[args.ts]['trial']['nniManagerNFSMountPath'] = args.nni_manager_nfs_mount_path
-    if args.containerNFSMountPath is not None:
+    if args.container_nfs_mount_path is not None:
         config[args.ts]['trial']['containerNFSMountPath'] = args.container_nfs_mount_path
-    if args.paiStoragePlugin is not None:
+    if args.pai_storage_plugin is not None:
         config[args.ts]['trial']['paiStoragePlugin'] = args.pai_storage_plugin
     if args.vc is not None:
         config[args.ts]['trial']['virtualCluster'] = args.vc
...