Commit ea665155 authored by quzha's avatar quzha
Browse files

Merge branch 'master' of github.com:Microsoft/nni into dev-nas-refactor

parents 73b2221b ae36373c
...@@ -126,9 +126,10 @@ def report_intermediate_result(metric): ...@@ -126,9 +126,10 @@ def report_intermediate_result(metric):
serializable object. serializable object.
""" """
global _intermediate_seq global _intermediate_seq
assert _params is not None, 'nni.get_next_parameter() needs to be called before report_intermediate_result' assert _params or trial_env_vars.NNI_PLATFORM is None, \
'nni.get_next_parameter() needs to be called before report_intermediate_result'
metric = json_tricks.dumps({ metric = json_tricks.dumps({
'parameter_id': _params['parameter_id'], 'parameter_id': _params['parameter_id'] if _params else None,
'trial_job_id': trial_env_vars.NNI_TRIAL_JOB_ID, 'trial_job_id': trial_env_vars.NNI_TRIAL_JOB_ID,
'type': 'PERIODICAL', 'type': 'PERIODICAL',
'sequence': _intermediate_seq, 'sequence': _intermediate_seq,
...@@ -147,9 +148,10 @@ def report_final_result(metric): ...@@ -147,9 +148,10 @@ def report_final_result(metric):
metric: metric:
serializable object. serializable object.
""" """
assert _params is not None, 'nni.get_next_parameter() needs to be called before report_final_result' assert _params or trial_env_vars.NNI_PLATFORM is None, \
'nni.get_next_parameter() needs to be called before report_final_result'
metric = json_tricks.dumps({ metric = json_tricks.dumps({
'parameter_id': _params['parameter_id'], 'parameter_id': _params['parameter_id'] if _params else None,
'trial_job_id': trial_env_vars.NNI_TRIAL_JOB_ID, 'trial_job_id': trial_env_vars.NNI_TRIAL_JOB_ID,
'type': 'FINAL', 'type': 'FINAL',
'sequence': 0, 'sequence': 0,
......
...@@ -41,7 +41,7 @@ logging.basicConfig(level=logging.INFO) ...@@ -41,7 +41,7 @@ logging.basicConfig(level=logging.INFO)
logger = logging.getLogger('test_tuner') logger = logging.getLogger('test_tuner')
class TunerTestCase(TestCase): class BuiltinTunersTestCase(TestCase):
""" """
Targeted at testing functions of built-in tuners, including Targeted at testing functions of built-in tuners, including
- [ ] load_checkpoint - [ ] load_checkpoint
......
from unittest import TestCase, main from unittest import TestCase, main
import numpy as np
import tensorflow as tf import tensorflow as tf
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
...@@ -7,11 +8,11 @@ import nni.compression.torch as torch_compressor ...@@ -7,11 +8,11 @@ import nni.compression.torch as torch_compressor
if tf.__version__ >= '2.0': if tf.__version__ >= '2.0':
import nni.compression.tensorflow as tf_compressor import nni.compression.tensorflow as tf_compressor
def get_tf_mnist_model(): def get_tf_model():
model = tf.keras.models.Sequential([ model = tf.keras.models.Sequential([
tf.keras.layers.Conv2D(filters=32, kernel_size=7, input_shape=[28, 28, 1], activation='relu', padding="SAME"), tf.keras.layers.Conv2D(filters=5, kernel_size=7, input_shape=[28, 28, 1], activation='relu', padding="SAME"),
tf.keras.layers.MaxPooling2D(pool_size=2), tf.keras.layers.MaxPooling2D(pool_size=2),
tf.keras.layers.Conv2D(filters=64, kernel_size=3, activation='relu', padding="SAME"), tf.keras.layers.Conv2D(filters=10, kernel_size=3, activation='relu', padding="SAME"),
tf.keras.layers.MaxPooling2D(pool_size=2), tf.keras.layers.MaxPooling2D(pool_size=2),
tf.keras.layers.Flatten(), tf.keras.layers.Flatten(),
tf.keras.layers.Dense(units=128, activation='relu'), tf.keras.layers.Dense(units=128, activation='relu'),
...@@ -23,43 +24,51 @@ def get_tf_mnist_model(): ...@@ -23,43 +24,51 @@ def get_tf_mnist_model():
metrics=["accuracy"]) metrics=["accuracy"])
return model return model
class TorchMnist(torch.nn.Module): class TorchModel(torch.nn.Module):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
self.conv1 = torch.nn.Conv2d(1, 20, 5, 1) self.conv1 = torch.nn.Conv2d(1, 5, 5, 1)
self.conv2 = torch.nn.Conv2d(20, 50, 5, 1) self.conv2 = torch.nn.Conv2d(5, 10, 5, 1)
self.fc1 = torch.nn.Linear(4 * 4 * 50, 500) self.fc1 = torch.nn.Linear(4 * 4 * 10, 100)
self.fc2 = torch.nn.Linear(500, 10) self.fc2 = torch.nn.Linear(100, 10)
def forward(self, x): def forward(self, x):
x = F.relu(self.conv1(x)) x = F.relu(self.conv1(x))
x = F.max_pool2d(x, 2, 2) x = F.max_pool2d(x, 2, 2)
x = F.relu(self.conv2(x)) x = F.relu(self.conv2(x))
x = F.max_pool2d(x, 2, 2) x = F.max_pool2d(x, 2, 2)
x = x.view(-1, 4 * 4 * 50) x = x.view(-1, 4 * 4 * 10)
x = F.relu(self.fc1(x)) x = F.relu(self.fc1(x))
x = self.fc2(x) x = self.fc2(x)
return F.log_softmax(x, dim=1) return F.log_softmax(x, dim=1)
def tf2(func): def tf2(func):
def test_tf2_func(self): def test_tf2_func(*args):
if tf.__version__ >= '2.0': if tf.__version__ >= '2.0':
func() func(*args)
return test_tf2_func return test_tf2_func
k1 = [[1]*3]*3
k2 = [[2]*3]*3
k3 = [[3]*3]*3
k4 = [[4]*3]*3
k5 = [[5]*3]*3
w = [[k1, k2, k3, k4, k5]] * 10
class CompressorTestCase(TestCase): class CompressorTestCase(TestCase):
def test_torch_pruner(self): def test_torch_level_pruner(self):
model = TorchMnist() model = TorchModel()
configure_list = [{'sparsity': 0.8, 'op_types': ['default']}] configure_list = [{'sparsity': 0.8, 'op_types': ['default']}]
torch_compressor.LevelPruner(model, configure_list).compress() torch_compressor.LevelPruner(model, configure_list).compress()
def test_torch_fpgm_pruner(self): @tf2
model = TorchMnist() def test_tf_level_pruner(self):
configure_list = [{'sparsity': 0.5, 'op_types': ['Conv2d']}] configure_list = [{'sparsity': 0.8, 'op_types': ['default']}]
torch_compressor.FPGMPruner(model, configure_list).compress() tf_compressor.LevelPruner(get_tf_model(), configure_list).compress()
def test_torch_quantizer(self): def test_torch_naive_quantizer(self):
model = TorchMnist() model = TorchModel()
configure_list = [{ configure_list = [{
'quant_types': ['weight'], 'quant_types': ['weight'],
'quant_bits': { 'quant_bits': {
...@@ -70,18 +79,59 @@ class CompressorTestCase(TestCase): ...@@ -70,18 +79,59 @@ class CompressorTestCase(TestCase):
torch_compressor.NaiveQuantizer(model, configure_list).compress() torch_compressor.NaiveQuantizer(model, configure_list).compress()
@tf2 @tf2
def test_tf_pruner(self): def test_tf_naive_quantizer(self):
configure_list = [{'sparsity': 0.8, 'op_types': ['default']}] tf_compressor.NaiveQuantizer(get_tf_model(), [{'op_types': ['default']}]).compress()
tf_compressor.LevelPruner(get_tf_mnist_model(), configure_list).compress()
@tf2 def test_torch_fpgm_pruner(self):
def test_tf_quantizer(self): """
tf_compressor.NaiveQuantizer(get_tf_mnist_model(), [{'op_types': ['default']}]).compress() With filters(kernels) defined as above (k1 - k5), it is obvious that k3 is the Geometric Median
which minimize the total geometric distance by defination of Geometric Median in this paper:
Filter Pruning via Geometric Median for Deep Convolutional Neural Networks Acceleration,
https://arxiv.org/pdf/1811.00250.pdf
So if sparsity is 0.2, the expected masks should mask out all k3, this can be verified through:
`all(torch.sum(masks, (0, 2, 3)).numpy() == np.array([90., 90., 0., 90., 90.]))`
If sparsity is 0.6, the expected masks should mask out all k2, k3, k4, this can be verified through:
`all(torch.sum(masks, (0, 2, 3)).numpy() == np.array([90., 0., 0., 0., 90.]))`
"""
model = TorchModel()
config_list = [{'sparsity': 0.2, 'op_types': ['Conv2d']}, {'sparsity': 0.6, 'op_types': ['Conv2d']}]
pruner = torch_compressor.FPGMPruner(model, config_list)
model.conv2.weight.data = torch.tensor(w).float()
layer = torch_compressor.compressor.LayerInfo('conv2', model.conv2)
masks = pruner.calc_mask(layer, config_list[0])
assert all(torch.sum(masks, (0, 2, 3)).numpy() == np.array([90., 90., 0., 90., 90.]))
pruner.update_epoch(1)
model.conv2.weight.data = torch.tensor(w).float()
masks = pruner.calc_mask(layer, config_list[1])
assert all(torch.sum(masks, (0, 2, 3)).numpy() == np.array([90., 0., 0., 0., 90.]))
@tf2 @tf2
def test_tf_fpgm_pruner(self): def test_tf_fpgm_pruner(self):
configure_list = [{'sparsity': 0.5, 'op_types': ['Conv2D']}] model = get_tf_model()
tf_compressor.FPGMPruner(get_tf_mnist_model(), configure_list).compress() config_list = [{'sparsity': 0.2, 'op_types': ['Conv2D']}, {'sparsity': 0.6, 'op_types': ['Conv2D']}]
pruner = tf_compressor.FPGMPruner(model, config_list)
weights = model.layers[2].weights
weights[0] = np.array(w).astype(np.float32).transpose([2, 3, 0, 1]).transpose([0, 1, 3, 2])
model.layers[2].set_weights([weights[0], weights[1].numpy()])
layer = tf_compressor.compressor.LayerInfo(model.layers[2])
masks = pruner.calc_mask(layer, config_list[0]).numpy()
masks = masks.transpose([2, 3, 0, 1]).transpose([1, 0, 2, 3])
assert all(masks.sum((0, 2, 3)) == np.array([90., 90., 0., 90., 90.]))
pruner.update_epoch(1)
model.layers[2].set_weights([weights[0], weights[1].numpy()])
masks = pruner.calc_mask(layer, config_list[1]).numpy()
masks = masks.transpose([2, 3, 0, 1]).transpose([1, 0, 2, 3])
assert all(masks.sum((0, 2, 3)) == np.array([90., 0., 0., 0., 90.]))
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -80,8 +80,6 @@ class MsgDispatcherTestCase(TestCase): ...@@ -80,8 +80,6 @@ class MsgDispatcherTestCase(TestCase):
send(CommandType.ReportMetricData, '{"parameter_id":0,"type":"PERIODICAL","value":10}') send(CommandType.ReportMetricData, '{"parameter_id":0,"type":"PERIODICAL","value":10}')
send(CommandType.ReportMetricData, '{"parameter_id":1,"type":"FINAL","value":11}') send(CommandType.ReportMetricData, '{"parameter_id":1,"type":"FINAL","value":11}')
send(CommandType.UpdateSearchSpace, '{"name":"SS0"}') send(CommandType.UpdateSearchSpace, '{"name":"SS0"}')
send(CommandType.AddCustomizedTrialJob, '{"param":-1}')
send(CommandType.ReportMetricData, '{"parameter_id":2,"type":"FINAL","value":22}')
send(CommandType.RequestTrialJobs, '1') send(CommandType.RequestTrialJobs, '1')
send(CommandType.KillTrialJob, 'null') send(CommandType.KillTrialJob, 'null')
_restore_io() _restore_io()
...@@ -99,14 +97,7 @@ class MsgDispatcherTestCase(TestCase): ...@@ -99,14 +97,7 @@ class MsgDispatcherTestCase(TestCase):
self._assert_params(0, 2, [], None) self._assert_params(0, 2, [], None)
self._assert_params(1, 4, [], None) self._assert_params(1, 4, [], None)
command, data = receive() # this one is customized self._assert_params(2, 6, [[1, 4, 11, False]], {'name': 'SS0'})
data = json.loads(data)
self.assertIs(command, CommandType.NewTrialJob)
self.assertEqual(data['parameter_id'], 2)
self.assertEqual(data['parameter_source'], 'customized')
self.assertEqual(data['parameters'], {'param': -1})
self._assert_params(3, 6, [[1, 4, 11, False], [2, -1, 22, True]], {'name': 'SS0'})
self.assertEqual(len(_out_buf.read()), 0) # no more commands self.assertEqual(len(_out_buf.read()), 0) # no more commands
......
import * as React from 'react';
import axios from 'axios';
import { Row, Col, Input, Modal, Form, Button, Icon } from 'antd';
import { MANAGER_IP } from '../../static/const';
import { EXPERIMENT, TRIALS } from '../../static/datamodel';
import { FormComponentProps } from 'antd/lib/form';
const FormItem = Form.Item;
import './customized.scss';
interface CustomizeProps extends FormComponentProps {
visible: boolean;
copyTrialId: string;
closeCustomizeModal: () => void;
}
interface CustomizeState {
isShowSubmitSucceed: boolean;
isShowSubmitFailed: boolean;
isShowWarning: boolean;
searchSpace: object;
copyTrialParameter: object; // user click the trial's parameters
customParameters: object; // customized trial, maybe user change trial's parameters
customID: number; // submit customized trial succeed, return the new customized trial id
}
class Customize extends React.Component<CustomizeProps, CustomizeState> {
constructor(props: CustomizeProps) {
super(props);
this.state = {
isShowSubmitSucceed: false,
isShowSubmitFailed: false,
isShowWarning: false,
searchSpace: EXPERIMENT.searchSpace,
copyTrialParameter: {},
customParameters: {},
customID: NaN
};
}
// [submit click] user add a new trial [submit a trial]
addNewTrial = () => {
const { searchSpace, copyTrialParameter } = this.state;
// get user edited hyperParameter, ps: will change data type if you modify the input val
const customized = this.props.form.getFieldsValue();
// true: parameters are wrong
let flag = false;
Object.keys(customized).map(item => {
if (item !== 'tag') {
// unified data type
if (typeof copyTrialParameter[item] === 'number' && typeof customized[item] === 'string') {
customized[item] = JSON.parse(customized[item]);
}
if (searchSpace[item]._type === 'choice') {
if (searchSpace[item]._value.find((val: string | number) =>
val === customized[item]) === undefined) {
flag = true;
return;
}
} else {
if (customized[item] < searchSpace[item]._value[0]
|| customized[item] > searchSpace[item]._value[1]) {
flag = true;
return;
}
}
}
});
if (flag !== false) {
// open the warning modal
this.setState(() => ({ isShowWarning: true, customParameters: customized }));
} else {
// submit a customized job
this.submitCustomize(customized);
}
}
warningConfirm = () => {
this.setState(() => ({ isShowWarning: false }));
const { customParameters } = this.state;
this.submitCustomize(customParameters);
}
warningCancel = () => {
this.setState(() => ({ isShowWarning: false }));
}
submitCustomize = (customized: Object) => {
// delete `tag` key
for (let i in customized) {
if (i === 'tag') {
delete customized[i];
}
}
axios(`${MANAGER_IP}/trial-jobs`, {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
data: customized
})
.then(res => {
if (res.status === 200) {
this.setState(() => ({ isShowSubmitSucceed: true, customID: res.data.sequenceId }));
this.props.closeCustomizeModal();
} else {
this.setState(() => ({ isShowSubmitFailed: true }));
}
})
.catch(error => {
this.setState(() => ({ isShowSubmitFailed: true }));
});
}
closeSucceedHint = () => {
// also close customized trial modal
this.setState(() => ({ isShowSubmitSucceed: false }));
this.props.closeCustomizeModal();
}
closeFailedHint = () => {
// also close customized trial modal
this.setState(() => ({ isShowSubmitFailed: false }));
this.props.closeCustomizeModal();
}
componentDidMount() {
const { copyTrialId } = this.props;
if (copyTrialId !== undefined && TRIALS.getTrial(copyTrialId) !== undefined) {
const originCopyTrialPara = TRIALS.getTrial(copyTrialId).description.parameters;
this.setState(() => ({ copyTrialParameter: originCopyTrialPara }));
}
}
componentWillReceiveProps(nextProps: CustomizeProps) {
const { copyTrialId } = nextProps;
if (copyTrialId !== undefined && TRIALS.getTrial(copyTrialId) !== undefined) {
const originCopyTrialPara = TRIALS.getTrial(copyTrialId).description.parameters;
this.setState(() => ({ copyTrialParameter: originCopyTrialPara }));
}
}
render() {
const { closeCustomizeModal, visible } = this.props;
const { isShowSubmitSucceed, isShowSubmitFailed, isShowWarning, customID, copyTrialParameter } = this.state;
const {
form: { getFieldDecorator },
// form: { getFieldDecorator, getFieldValue },
} = this.props;
const warning = 'The parameters you set are not in our search space, this may cause the tuner to crash, Are'
+ ' you sure you want to continue submitting?';
return (
<Row>
{/* form: search space */}
<Modal
title="Customized trial setting"
visible={visible}
onCancel={closeCustomizeModal}
footer={null}
destroyOnClose={true}
maskClosable={false}
centered={true}
>
{/* search space form */}
<Row className="hyper-box">
<Form>
{
Object.keys(copyTrialParameter).map(item => (
<Row key={item} className="hyper-form">
<Col span={9} className="title">{item}</Col>
<Col span={15} className="inputs">
<FormItem key={item} style={{ marginBottom: 0 }}>
{getFieldDecorator(item, {
initialValue: copyTrialParameter[item],
})(
<Input />
)}
</FormItem>
</Col>
</Row>
)
)
}
<Row key="tag" className="hyper-form tag-input">
<Col span={9} className="title">Tag</Col>
<Col span={15} className="inputs">
<FormItem key="tag" style={{ marginBottom: 0 }}>
{getFieldDecorator('tag', {
initialValue: 'Customized',
})(
<Input />
)}
</FormItem>
</Col>
</Row>
</Form>
</Row>
<Row className="modal-button">
<Button
type="primary"
className="tableButton distance"
onClick={this.addNewTrial}
>
Submit
</Button>
<Button
className="tableButton cancelSty"
onClick={this.props.closeCustomizeModal}
>
Cancel
</Button>
</Row>
{/* control button */}
</Modal>
{/* clone: prompt succeed or failed */}
<Modal
visible={isShowSubmitSucceed}
footer={null}
destroyOnClose={true}
maskClosable={false}
closable={false}
centered={true}
>
<Row className="resubmit">
<Row>
<h2 className="title">
<span>
<Icon type="check-circle" className="color-succ" />
<b>Submit successfully</b>
</span>
</h2>
<div className="hint">
<span>You can find your customized trial by Trial No.{customID}</span>
</div>
</Row>
<Row className="modal-button">
<Button
className="tableButton cancelSty"
onClick={this.closeSucceedHint}
>
OK
</Button>
</Row>
</Row>
</Modal>
<Modal
visible={isShowSubmitFailed}
footer={null}
destroyOnClose={true}
maskClosable={false}
closable={false}
centered={true}
>
<Row className="resubmit">
<Row>
<h2 className="title">
<span>
<Icon type="check-circle" className="color-error" />Submit Failed
</span>
</h2>
<div className="hint">
<span>Unknown error.</span>
</div>
</Row>
<Row className="modal-button">
<Button
className="tableButton cancelSty"
onClick={this.closeFailedHint}
>
OK
</Button>
</Row>
</Row>
</Modal>
{/* hyperParameter not match search space, warning modal */}
<Modal
visible={isShowWarning}
footer={null}
destroyOnClose={true}
maskClosable={false}
closable={false}
centered={true}
>
<Row className="resubmit">
<Row>
<h2 className="title">
<span>
<Icon className="color-warn" type="warning" />Warning
</span>
</h2>
<div className="hint">
<span>{warning}</span>
</div>
</Row>
<Row className="modal-button center">
<Button
className="tableButton cancelSty distance"
onClick={this.warningConfirm}
>
Confirm
</Button>
<Button
className="tableButton cancelSty"
onClick={this.warningCancel}
>
Cancel
</Button>
</Row>
</Row>
</Modal>
</Row>
);
}
}
export default Form.create<FormComponentProps>()(Customize);
\ No newline at end of file
.ant-modal-body{
border-radius: none;
}
.ant-modal-title {
font-size: 18px;
}
/* resubmit confirm modal style */
.resubmit{
.title{
font-size: 16px;
color: #000;
.color-warn, .color-error{
color: red;
}
i{
margin-right: 10px;
}
}
.hint{
padding: 15px 0;
color: #333;
margin-left: 30px;
}
.color-succ{
color: green;
}
}
.hyper-box{
padding: 16px 18px 16px 16px;
}
.hyper-form{
height: 32px;
margin-bottom: 8px;
.title{
font-size: 14px;
font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
line-height: 32px;
}
.inputs{
height: 32px;
}
input{
height: 32px;
}
}
.tag-input{
margin-top: 25px;
}
/* submit & cancel buttons style*/
.modal-button{
text-align: right;
height: 28px;
/* cancel button style*/
.cancelSty{
width: 80px;
background-color: #dadada;
border: none;
color: #333;
}
.cancelSty:hover, .cancelSty:active, .cancelSty:focus{
background-color: #dadada;
}
.distance{
margin-right: 8px;
}
}
.center{
text-align: center;
}
...@@ -7,10 +7,11 @@ const Option = Select.Option; ...@@ -7,10 +7,11 @@ const Option = Select.Option;
const CheckboxGroup = Checkbox.Group; const CheckboxGroup = Checkbox.Group;
import { MANAGER_IP, trialJobStatus, COLUMN_INDEX, COLUMNPro } from '../../static/const'; import { MANAGER_IP, trialJobStatus, COLUMN_INDEX, COLUMNPro } from '../../static/const';
import { convertDuration, formatTimestamp, intermediateGraphOption, killJob } from '../../static/function'; import { convertDuration, formatTimestamp, intermediateGraphOption, killJob } from '../../static/function';
import { TRIALS } from '../../static/datamodel'; import { EXPERIMENT, TRIALS } from '../../static/datamodel';
import { TableRecord } from '../../static/interface'; import { TableRecord } from '../../static/interface';
import OpenRow from '../public-child/OpenRow'; import OpenRow from '../public-child/OpenRow';
import Compare from '../Modal/Compare'; import Compare from '../Modal/Compare';
import Customize from '../Modal/CustomizedTrial';
import '../../static/style/search.scss'; import '../../static/style/search.scss';
require('../../static/style/tableStatus.css'); require('../../static/style/tableStatus.css');
require('../../static/style/logPath.scss'); require('../../static/style/logPath.scss');
...@@ -45,6 +46,8 @@ interface TableListState { ...@@ -45,6 +46,8 @@ interface TableListState {
intermediateData: Array<object>; // a trial's intermediate results (include dict) intermediateData: Array<object>; // a trial's intermediate results (include dict)
intermediateId: string; intermediateId: string;
intermediateOtherKeys: Array<string>; intermediateOtherKeys: Array<string>;
isShowCustomizedModal: boolean;
copyTrialId: string; // user copy trial to submit a new customized trial
} }
interface ColumnIndex { interface ColumnIndex {
...@@ -71,7 +74,9 @@ class TableList extends React.Component<TableListProps, TableListState> { ...@@ -71,7 +74,9 @@ class TableList extends React.Component<TableListProps, TableListState> {
selectedRowKeys: [], // close selected trial message after modal closed selectedRowKeys: [], // close selected trial message after modal closed
intermediateData: [], intermediateData: [],
intermediateId: '', intermediateId: '',
intermediateOtherKeys: [] intermediateOtherKeys: [],
isShowCustomizedModal: false,
copyTrialId: ''
}; };
} }
...@@ -236,17 +241,36 @@ class TableList extends React.Component<TableListProps, TableListState> { ...@@ -236,17 +241,36 @@ class TableList extends React.Component<TableListProps, TableListState> {
this.setState({ isShowCompareModal: false, selectedRowKeys: [], selectRows: [] }); this.setState({ isShowCompareModal: false, selectedRowKeys: [], selectRows: [] });
} }
// open customized trial modal
setCustomizedTrial = (trialId: string) => {
this.setState({
isShowCustomizedModal: true,
copyTrialId: trialId
});
}
closeCustomizedTrial = () => {
this.setState({
isShowCustomizedModal: false,
copyTrialId: ''
});
}
render() { render() {
const { pageSize, columnList } = this.props; const { pageSize, columnList } = this.props;
const tableSource: Array<TableRecord> = JSON.parse(JSON.stringify(this.props.tableSource)); const tableSource: Array<TableRecord> = JSON.parse(JSON.stringify(this.props.tableSource));
const { intermediateOption, modalVisible, isShowColumn, const { intermediateOption, modalVisible, isShowColumn,
selectRows, isShowCompareModal, selectedRowKeys, intermediateOtherKeys } = this.state; selectRows, isShowCompareModal, selectedRowKeys, intermediateOtherKeys,
isShowCustomizedModal, copyTrialId
} = this.state;
const rowSelection = { const rowSelection = {
selectedRowKeys: selectedRowKeys, selectedRowKeys: selectedRowKeys,
onChange: (selected: string[] | number[], selectedRows: Array<TableRecord>) => { onChange: (selected: string[] | number[], selectedRows: Array<TableRecord>) => {
this.fillSelectedRowsTostate(selected, selectedRows); this.fillSelectedRowsTostate(selected, selectedRows);
} }
}; };
// [supportCustomizedTrial: true]
const supportCustomizedTrial = (EXPERIMENT.multiPhase === true) ? false : true;
const disabledAddCustomizedTrial = ['DONE', 'ERROR', 'STOPPED'].includes(EXPERIMENT.status);
let showTitle = COLUMNPro; let showTitle = COLUMNPro;
const showColumn: Array<object> = []; const showColumn: Array<object> = [];
...@@ -361,6 +385,22 @@ class TableList extends React.Component<TableListProps, TableListState> { ...@@ -361,6 +385,22 @@ class TableList extends React.Component<TableListProps, TableListState> {
</Button> </Button>
</Popconfirm> </Popconfirm>
} }
{/* Add a new trial-customized trial */}
{
supportCustomizedTrial
?
<Button
type="primary"
className="common-style"
disabled={disabledAddCustomizedTrial}
onClick={this.setCustomizedTrial.bind(this, record.id)}
title="Customized trial"
>
<Icon type="copy" />
</Button>
:
null
}
</Row> </Row>
); );
}, },
...@@ -398,7 +438,7 @@ class TableList extends React.Component<TableListProps, TableListState> { ...@@ -398,7 +438,7 @@ class TableList extends React.Component<TableListProps, TableListState> {
expandedRowRender={this.openRow} expandedRowRender={this.openRow}
dataSource={tableSource} dataSource={tableSource}
className="commonTableStyle" className="commonTableStyle"
scroll={{x: 'max-content'}} scroll={{ x: 'max-content' }}
pagination={pageSize > 0 ? { pageSize } : false} pagination={pageSize > 0 ? { pageSize } : false}
/> />
{/* Intermediate Result Modal */} {/* Intermediate Result Modal */}
...@@ -458,7 +498,14 @@ class TableList extends React.Component<TableListProps, TableListState> { ...@@ -458,7 +498,14 @@ class TableList extends React.Component<TableListProps, TableListState> {
className="titleColumn" className="titleColumn"
/> />
</Modal> </Modal>
{/* compare trials based message */}
<Compare compareRows={selectRows} visible={isShowCompareModal} cancelFunc={this.hideCompareModal} /> <Compare compareRows={selectRows} visible={isShowCompareModal} cancelFunc={this.hideCompareModal} />
{/* clone trial parameters and could submit a customized trial */}
<Customize
visible={isShowCustomizedModal}
copyTrialId={copyTrialId}
closeCustomizeModal={this.closeCustomizedTrial}
/>
</Row> </Row>
); );
} }
......
...@@ -35,7 +35,7 @@ def check_ready_to_run(): ...@@ -35,7 +35,7 @@ def check_ready_to_run():
pidList.remove(os.getpid()) pidList.remove(os.getpid())
return not pidList return not pidList
else: else:
pgrep_output = subprocess.check_output('pgrep -fx \'python3 -m nni_gpu_tool.gpu_metrics_collector\'', shell=True) pgrep_output = subprocess.check_output('pgrep -fxu "$(whoami)" \'python3 -m nni_gpu_tool.gpu_metrics_collector\'', shell=True)
pidList = [] pidList = []
for pid in pgrep_output.splitlines(): for pid in pgrep_output.splitlines():
pidList.append(int(pid)) pidList.append(int(pid))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment