Unverified Commit 9987014e authored by Yuge Zhang's avatar Yuge Zhang Committed by GitHub
Browse files

NAS visualization (#2085)

parent ede59380
......@@ -70,6 +70,8 @@ build:
cp -rf src/nni_manager/config src/nni_manager/dist/
#$(_INFO) Building WebUI $(_END)
cd src/webui && $(NNI_YARN) && $(NNI_YARN) build
#$(_INFO) Building NAS UI $(_END)
cd src/nasui && $(NNI_YARN) && $(NNI_YARN) build
# All-in-one target for non-expert users
# Installs NNI as well as its dependencies, and update bashrc to set PATH
......
# See https://help.github.com/articles/ignoring-files/ for more about ignoring files.
# dependencies
/node_modules
/.pnp
.pnp.js
# testing
/coverage
# production
/build
# misc
.DS_Store
.env.local
.env.development.local
.env.test.local
.env.production.local
npm-debug.log*
yarn-debug.log*
yarn-error.log*
{
"name": "nasnni-vis-ts",
"version": "0.1.0",
"private": true,
"dependencies": {
"@material-ui/core": "^4.9.3",
"@material-ui/icons": "^4.9.1",
"cytoscape": "^3.14.0",
"cytoscape-dagre": "^2.2.2",
"cytoscape-panzoom": "^2.5.3",
"express": "^4.17.1",
"lodash": "^4.17.15",
"react": "^16.12.0",
"react-dom": "^16.12.0",
"react-scripts": "3.4.0",
"typeface-roboto": "^0.0.75",
"typescript": "~3.7.2"
},
"scripts": {
"start": "react-scripts start",
"build": "react-scripts build",
"eject": "react-scripts eject",
"backend": "node server.js"
},
"eslintConfig": {
"extends": "react-app"
},
"browserslist": {
"production": [
">0.2%",
"not dead",
"not op_mini all"
],
"development": [
"last 1 chrome version",
"last 1 firefox version",
"last 1 safari version"
]
},
"devDependencies": {
"@testing-library/jest-dom": "^4.2.4",
"@testing-library/react": "^9.3.2",
"@testing-library/user-event": "^7.1.2",
"@types/cytoscape": "^3.14.0",
"@types/jest": "^24.0.0",
"@types/lodash": "^4.14.149",
"@types/node": "^12.0.0",
"@types/react": "^16.9.0",
"@types/react-dom": "^16.9.0"
},
"proxy": "http://localhost:8080"
}
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="utf-8" />
<link rel="icon" href="%PUBLIC_URL%/icon.png" />
<meta name="viewport" content="width=device-width, initial-scale=1" />
<meta name="theme-color" content="#000000" />
<title>NNI NAS Board</title>
</head>
<body>
<noscript>You need to enable JavaScript to run this app.</noscript>
<div id="root"></div>
<!--
This HTML file is a template.
If you open it directly in the browser, you will see an empty page.
You can add webfonts, meta tags, or analytics to this file.
The build step will place the bundled scripts into the <body> tag.
To begin the development, run `npm start` or `yarn start`.
To create a production bundle, use `npm run build` or `yarn build`.
-->
</body>
</html>
const express = require('express');
const path = require('path');
const fs = require('fs');
const app = express();
const argv = require('minimist')(process.argv.slice(2));
const port = argv.port || 8080;
const logdir = argv.logdir || './mockdata';
app.use(express.static(path.join(__dirname, 'build')));
app.get('/', (req, res) => {
res.sendFile(path.join(__dirname, 'build', 'index.html'));
});
app.get('/refresh', (req, res) => {
const graph = fs.readFileSync(path.join(logdir, 'graph.json'), 'utf8');
const log = fs.readFileSync(path.join(logdir, 'log'), 'utf-8')
.split('\n')
.filter(Boolean)
.map(JSON.parse);
res.send({
'graph': JSON.parse(graph),
'log': log,
});
});
app.listen(port, '0.0.0.0', () => {
console.log(`NNI NAS board is running on port ${port}, logdir is ${logdir}.`);
});
svg {
overflow: hidden;
}
.node {
white-space: nowrap;
}
.node.input rect, .node.output rect {
fill: #9d0f0f;
}
.node.hub rect {
fill: #0f309d;
}
.node.blob rect {
fill: #0F9D58;
}
.node text {
fill: #fff;
font-family: "Roboto", "Helvetica", "Arial", sans-serif;
font-weight: 500;
}
.cluster rect {
stroke: #333;
fill: #000;
fill-opacity: 0.1;
stroke-width: 1.5px;
}
.edgePath path.path {
stroke: #333;
stroke-width: 1.5px;
fill: none;
}
import React, { ChangeEvent } from 'react';
import './App.css';
import 'typeface-roboto';
import { createStyles, withStyles } from '@material-ui/core/styles';
import AppBar from '@material-ui/core/AppBar';
import Button from '@material-ui/core/Button';
import Grid from '@material-ui/core/Grid';
import IconButton from '@material-ui/core/IconButton';
import Slider from '@material-ui/core/Slider';
import Toolbar from '@material-ui/core/Toolbar';
import Typography from '@material-ui/core/Typography';
import RefreshIcon from '@material-ui/icons/Refresh';
import SettingsIcon from '@material-ui/icons/Settings';
import CloseIcon from '@material-ui/icons/Close';
import ShuffleIcon from '@material-ui/icons/Shuffle';
import Snackbar from '@material-ui/core/Snackbar';
import FormControl from '@material-ui/core/FormControl';
import FormControlLabel from '@material-ui/core/FormControlLabel';
import FormGroup from '@material-ui/core/FormGroup';
import Checkbox from '@material-ui/core/Checkbox';
import Dialog from '@material-ui/core/Dialog';
import DialogTitle from '@material-ui/core/DialogTitle';
import DialogContent from '@material-ui/core/DialogContent';
import DialogActions from '@material-ui/core/DialogActions';
import MuiExpansionPanel from '@material-ui/core/ExpansionPanel';
import MuiExpansionPanelSummary from '@material-ui/core/ExpansionPanelSummary';
import MuiExpansionPanelDetails from '@material-ui/core/ExpansionPanelDetails';
import ExpandMoreIcon from '@material-ui/icons/ExpandMore';
import List from '@material-ui/core/List';
import ListItem from '@material-ui/core/ListItem';
import Backdrop from '@material-ui/core/Backdrop';
import Chart from './Chart';
import { Graph } from './graphUtils';
const styles = createStyles({
bottomAppBar: {
top: 'auto',
bottom: 0,
zIndex: 'auto',
},
title: {
flexGrow: 1,
textAlign: 'left'
},
panel: {
position: 'absolute',
top: 76,
right: 16,
width: 400,
},
listItem: {
paddingLeft: 0,
paddingRight: 0,
paddingTop: 2,
paddingBottom: 2,
fontSize: '0.8em',
wordBreak: 'break-all',
},
listSubtitle: {
fontWeight: 600,
paddingLeft: 0,
paddingRight: 0,
fontSize: '0.9em',
},
listTitle: {
lineHeight: 1.1,
wordBreak: 'break-all'
},
backdrop: {
color: '#fff',
zIndex: 100,
},
snackbar: {
bottom: 76
}
});
const ExpansionPanel = withStyles({
root: {
'&$expanded': {
margin: 'auto',
},
},
expanded: {},
})(MuiExpansionPanel);
const ExpansionPanelSummary = withStyles({
root: {},
content: {
'&$expanded': {
margin: '12px 0',
},
},
expanded: {},
})(MuiExpansionPanelSummary);
const ExpansionPanelDetails = withStyles(theme => ({
root: {
paddingTop: 0,
paddingBottom: theme.spacing(1),
},
}))(MuiExpansionPanelDetails);
type AppState = {
graph: Graph | undefined,
graphData: any,
logData: any[],
sliderValue: number,
maxSliderValue: number,
sliderStep: number,
settingsOpen: boolean,
hideSidechainNodes: boolean,
hidePrimitiveNodes: boolean,
snackbarOpen: boolean,
selectedNode: string,
loading: boolean,
layout: boolean,
}
type AppProps = {
classes: any
}
class App extends React.Component<AppProps, AppState> {
constructor(props: any) {
super(props);
this.state = {
graph: undefined,
graphData: undefined,
logData: [],
sliderValue: 0,
maxSliderValue: 0,
sliderStep: 1,
settingsOpen: false,
hideSidechainNodes: true,
hidePrimitiveNodes: true,
selectedNode: '',
loading: false,
snackbarOpen: false,
layout: false,
};
this.refresh = this.refresh.bind(this);
}
componentDidMount() {
this.refresh();
}
refresh() {
this.setState({ loading: true });
fetch('/refresh')
.then((response) => { return response.json() })
.then((data) => {
const graph = new Graph(data.graph, this.state.hideSidechainNodes);
this.setState({
graphData: data.graph,
graph: graph,
logData: data.log,
maxSliderValue: data.log.length - 1,
sliderStep: Math.max(1, Math.floor(data.log.length / 20)),
sliderValue: Math.min(data.log.length, this.state.sliderValue),
loading: false,
snackbarOpen: graph.nodes.length > 100
});
});
}
private renderExpansionPanel() {
const { classes } = this.props;
const { selectedNode, graph } = this.state;
if (graph === undefined)
return null;
const info = graph.nodeSummary(selectedNode);
if (info === undefined)
return null;
const subtitle = info.op ?
(info.op === 'IO Node' ? info.op : `Operation: ${info.op}`) :
`Subgraph: ${info.nodeCount} nodes, ${info.edgeCount} edges`;
return (
<ExpansionPanel className={classes.panel}>
<ExpansionPanelSummary
expandIcon={<ExpandMoreIcon />}
>
<Typography variant='subtitle1' className={classes.listTitle}><b>{info.name}</b><br />{subtitle}</Typography>
</ExpansionPanelSummary>
<ExpansionPanelDetails>
<List dense={true} style={{
maxHeight: window.innerHeight * .5,
overflowY: 'auto',
paddingTop: 0,
width: '100%'
}}>
{
info.attributes &&
<React.Fragment>
<ListItem className={classes.listSubtitle}>Attributes</ListItem>
<ListItem className={classes.listItem}>{info.attributes}</ListItem>
</React.Fragment>
}
{
info.inputs.length > 0 &&
<React.Fragment>
<ListItem className={classes.listSubtitle}>Inputs ({info.inputs.length})</ListItem>
{
info.inputs.map((item, i) => <ListItem className={classes.listItem} key={`input${i}`}>{item}</ListItem>)
}
</React.Fragment>
}
{
info.outputs.length > 0 &&
<React.Fragment>
<ListItem className={classes.listSubtitle}>Outputs ({info.outputs.length})</ListItem>
{
info.outputs.map((item, i) => <ListItem className={classes.listItem} key={`output${i}`}>{item}</ListItem>)
}
</React.Fragment>
}
</List>
</ExpansionPanelDetails>
</ExpansionPanel>
);
}
render() {
const { classes } = this.props;
const { sliderValue, maxSliderValue, sliderStep, settingsOpen, loading, snackbarOpen } = this.state;
const handleSliderChange = (event: ChangeEvent<{}>, value: number | number[]) => {
this.setState({ sliderValue: value as number });
};
const handleSettingsDialogToggle = (value: boolean) => () => {
this.setState({ settingsOpen: value });
};
const handleSettingsChange = (name: string) => (event: React.ChangeEvent<HTMLInputElement>) => {
this.setState({
...this.state,
[name]: event.target.checked
}, () => {
this.setState({
graph: new Graph(this.state.graphData, this.state.hideSidechainNodes),
})
});
};
const handleSelectionChange = (node: string) => {
this.setState({
selectedNode: node
});
};
const handleLoadingState = (state: boolean) => () => {
this.setState({ loading: state });
};
const handleSnackbarClose = () => {
this.setState({ snackbarOpen: false });
};
const handleLayoutStateChanged = (state: boolean) => () => {
this.setState({ layout: state });
};
return (
<div className='App'>
<Chart
width={window.innerWidth}
height={window.innerHeight}
graph={this.state.graph}
activation={sliderValue < this.state.logData.length ? this.state.logData[sliderValue] : undefined}
handleSelectionChange={handleSelectionChange}
onRefresh={handleLoadingState(true)}
onRefreshComplete={handleLoadingState(false)}
layout={this.state.layout}
onLayoutComplete={handleLayoutStateChanged(false)}
/>
<AppBar position='fixed' color='primary'>
<Toolbar>
<Typography variant='h6' className={classes.title}>
NNI NAS Board
</Typography>
<IconButton color='inherit' onClick={handleLayoutStateChanged(true)}>
<ShuffleIcon />
</IconButton>
<IconButton color='inherit' onClick={this.refresh}>
<RefreshIcon />
</IconButton>
<IconButton color='inherit' onClick={handleSettingsDialogToggle(true)}>
<SettingsIcon />
</IconButton>
</Toolbar>
</AppBar>
<AppBar position='fixed' color='default' className={classes.bottomAppBar}>
<Toolbar variant='dense'>
<Grid container spacing={2} alignItems='center'>
<Grid item xs>
<Slider
value={sliderValue}
max={maxSliderValue}
min={0}
step={sliderStep}
onChange={handleSliderChange}
/>
</Grid>
<Grid item>
<Typography variant='body1'>
{sliderValue}/{maxSliderValue}
</Typography>
</Grid>
</Grid>
</Toolbar>
</AppBar>
<Dialog onClose={handleSettingsDialogToggle(false)} open={settingsOpen}>
<DialogTitle>Settings</DialogTitle>
<DialogContent>
<FormControl component='fieldset'>
<FormGroup>
<FormControlLabel
control={<Checkbox checked={this.state.hideSidechainNodes}
onChange={handleSettingsChange('hideSidechainNodes')}
value='hideSidechainNodes' />}
label='Hide sidechain nodes'
/>
{ // TODO: hide primitive nodes
/* <FormControlLabel
control={<Checkbox checked={this.state.hidePrimitiveNodes}
onChange={handleSettingsChange('hidePrimitiveNodes')}
value='hidePrimitiveNodes' />}
label='Hide primitive nodes'
/> */}
</FormGroup>
</FormControl>
</DialogContent>
<DialogActions>
<Button onClick={handleSettingsDialogToggle(false)} color='primary'>
Close
</Button>
</DialogActions>
</Dialog>
{this.renderExpansionPanel()}
<Snackbar
className={classes.snackbar}
anchorOrigin={{
vertical: 'bottom',
horizontal: 'left',
}}
open={snackbarOpen}
message='Graph is too large. Might induce performance issue.'
onClose={handleSnackbarClose}
action={
<IconButton size='small' color='inherit' onClick={handleSnackbarClose}>
<CloseIcon fontSize='small' />
</IconButton>
}
/>
{
loading && <Backdrop className={classes.backdrop} open={true}>
<Typography>Loading...</Typography>
</Backdrop>
}
</div>
);
}
}
export default withStyles(styles)(App);
import React, { createRef } from 'react';
import cytoscape from 'cytoscape';
import dagre from 'cytoscape-dagre';
import lodash from 'lodash';
import { Graph, NodeTs } from './graphUtils';
cytoscape.use(dagre);
type ChartProps = {
width: number,
height: number,
graph: Graph | undefined,
activation: any,
handleSelectionChange: (_: string) => void,
onRefresh: () => void,
onRefreshComplete: () => void,
layout: boolean,
onLayoutComplete: () => void,
}
const styles = [
{
selector: ':active',
style: {
'overlay-opacity': 0.1,
}
},
{
selector: 'node',
style: {
'shape': 'round-rectangle',
'width': 'label',
'height': 'label',
'content': 'data(label)',
'text-wrap': 'wrap',
'text-valign': 'center',
'text-halign': 'center',
'font-family': '"Roboto", "Helvetica", "Arial", sans-serif',
'font-size': 20,
'padding-left': '8px',
'padding-right': '8px',
'padding-top': '8px',
'padding-bottom': '8px',
'background-color': '#000',
'background-opacity': 0.02,
'border-color': '#555',
'border-width': '2px',
}
},
{
selector: 'node.op',
style: {
'background-color': '#fafafa',
'background-opacity': 1,
}
},
{
selector: 'node:selected',
style: {
'border-width': '4px',
'border-color': '#101010',
}
},
{
selector: 'edge',
style: {
'curve-style': 'bezier',
'target-arrow-shape': 'triangle',
'line-color': '#555',
'target-arrow-color': '#555'
}
},
{
selector: 'node.compound',
style: {
'shape': 'roundrectangle',
'text-valign': 'top',
'padding-top': '30px',
'padding-bottom': '10px',
'padding-right': '10px',
'padding-left': '10px',
}
}
];
export default class Chart extends React.Component<ChartProps> {
container = createRef<HTMLDivElement>();
zoom: any;
collapseInstance: any;
cyInstance: cytoscape.Core | null = null;
expandSet: Set<string> = new Set<string>();
elemWeight: Map<string, number> = new Map<string, number>();
graphEl: any[] = [];
componentDidMount() {
this.cyInstance = cytoscape({
container: this.container.current,
elements: [],
style: styles as any,
wheelSensitivity: 0.1,
});
let singleClickedNodes = new Set<string>();
this.cyInstance!.off('click');
this.cyInstance!.on('click', (e: cytoscape.EventObject) => {
if (e.target.id) {
let nodeId = e.target.id();
if (singleClickedNodes.has(nodeId)) {
singleClickedNodes.delete(nodeId);
e.target.trigger('dblclick');
} else {
singleClickedNodes.add(nodeId);
setTimeout(() => { singleClickedNodes.delete(nodeId); }, 300);
}
}
});
this.cyInstance!.on('dblclick', (e: cytoscape.EventObject) => {
const nodeId = e.target.id();
const node = this.props.graph!.getNodeById(nodeId);
if (node !== undefined && node.isParent()) {
if (this.expandSet.has(nodeId)) {
this.expandSet.delete(nodeId);
} else {
this.expandSet.add(nodeId);
}
this.renderGraph(false);
}
});
this.cyInstance!.on('select', (e: cytoscape.EventObject) => {
this.props.handleSelectionChange(e.target.id());
});
this.renderGraph(true);
}
componentDidUpdate(prevProps: ChartProps) {
if (prevProps.graph === this.props.graph) {
if (prevProps.width === this.props.width &&
prevProps.height === this.props.height) {
if (this.props.layout) {
this.props.onLayoutComplete();
this.reLayout();
}
if (prevProps.activation !== this.props.activation) {
// perhaps only display step is changed
this.applyMutableWeight();
}
} else {
// something changed, re-render
this.renderGraph(false);
}
} else {
// re-calculate collapse
this.renderGraph(true);
}
}
private graphElements() {
let graphElements: any[] = [];
const { graph } = this.props;
if (graph === undefined)
return [];
const collapseMap = new Map<string, string>();
const traverse = (node: NodeTs, top: NodeTs | undefined) => {
collapseMap.set(node.id, top === undefined ? node.id : top.id);
if (node.id && (top === undefined || node === top)) {
// not root and will display
const isCompound = node.isParent() && this.expandSet.has(node.id);
let data: any = {
id: node.id,
label: node.op ? node.op : node.tail,
};
if (node.parent !== undefined)
data.parent = node.parent.id;
const classes = [];
if (isCompound) classes.push('compound');
if (node.op) classes.push('op');
graphElements.push({
data: data,
classes: classes
});
}
for (const child of node.children)
traverse(child, node.id && top === undefined && !this.expandSet.has(node.id) ? node : top);
}
traverse(graph.root, undefined);
graph.edges.forEach(edge => {
const [srcCollapse, trgCollapse] = [edge.source, edge.target].map((node) => collapseMap.get(node.id)!);
if (edge.source !== edge.target && srcCollapse === trgCollapse) {
return;
}
graphElements.push({
data: {
id: edge.id,
source: srcCollapse,
target: trgCollapse
}
});
});
return graphElements;
}
private applyMutableWeight() {
const { graph, activation } = this.props;
if (graph === undefined || activation === undefined)
return;
const weights = graph.weightFromMutables(activation);
weights.forEach((weight, elem) => {
if (this.elemWeight.get(elem) !== weight) {
this.cyInstance!.getElementById(elem).style({
opacity: 0.2 + 0.8 * weight
});
}
});
this.elemWeight = weights;
}
private graphElDifference(prev: any[], next: any[]): [Set<string>, any] {
const tracedElements = new Set(prev.map(ele => ele.data.id));
const prevMap = new Map(prev.map(ele => [ele.data.id, ele]));
const nextMap = new Map(next.map(ele => [ele.data.id, ele]));
const addedEles: any = [];
nextMap.forEach((val, k) => {
const prevEle = prevMap.get(k);
if (prevEle === undefined) {
addedEles.push(val);
} else if (!lodash.isEqual(val, prevEle)) {
tracedElements.delete(k);
addedEles.push(val);
} else {
tracedElements.delete(k);
}
});
return [tracedElements, addedEles];
}
private reLayout() {
this.props.onRefresh();
const _render = () => {
const layout: any = {
name: 'dagre'
};
this.cyInstance!.layout(layout).run();
this.props.onRefreshComplete();
};
setTimeout(_render, 100);
}
private renderGraph(graphChanged: boolean) {
const { graph } = this.props;
if (graph === undefined)
return;
this.props.onRefresh();
const _render = () => {
if (graphChanged)
this.expandSet = lodash.cloneDeep(graph.defaultExpandSet);
const graphEl = this.graphElements();
const [remove, add] = this.graphElDifference(this.graphEl, graphEl);
const layout: any = {
name: 'dagre'
};
if (graphEl.length > 100) {
if (remove.size > 0) {
const removedEles = this.cyInstance!.elements().filter(ele => remove.has(ele.id()));
this.cyInstance!.remove(removedEles);
}
if (add.length > 0) {
const eles = this.cyInstance!.add(add);
this.cyInstance!.json({
elements: graphEl
});
layout.fit = false;
eles.layout(layout).run();
}
} else {
this.cyInstance!.json({
elements: graphEl
});
this.cyInstance!.layout(layout).run();
}
this.applyMutableWeight();
this.graphEl = graphEl;
this.props.onRefreshComplete();
};
if (graph.nodes.length > 100)
setTimeout(_render, 100);
else
_render();
}
render() {
return (
<div className='container' ref={this.container}
style={{
left: 0,
top: 0,
position: 'absolute',
width: this.props.width - 15,
height: this.props.height,
overflow: 'hidden'
}}>
</div>
);
}
}
function path2module(path: any[]): string {
return path.map(p => p.name ? `${p.type}[${p.name}]` : p.type).join('/');
}
function opName(rawName: string): string {
if (rawName.includes('::')) {
return rawName.split('::')[1];
} else {
return rawName;
}
}
export class NodeTs {
readonly id: string;
readonly tail: string;
parent: NodeTs | undefined;
children: NodeTs[];
op: string;
attributes: string;
constructor(id: string, tail: string, op: string, attributes: string) {
this.children = [];
this.id = id;
this.tail = tail;
this.parent = undefined;
this.op = op;
this.attributes = attributes;
}
descendants(leafOnly: boolean): NodeTs[] {
// return all descendants includinng itself
const result: NodeTs[] = [];
if (!leafOnly || this.isChildless())
result.push(this);
for (const child of this.children) {
const childDesc = child.descendants(leafOnly);
if (childDesc.length > 0)
result.push(...childDesc);
}
return result;
}
isChildless(): boolean {
return this.children.length === 0;
}
isParent(): boolean {
return this.children.length > 0;
}
};
export class Edge {
readonly source: NodeTs;
readonly target: NodeTs;
readonly id: string;
constructor(source: NodeTs, target: NodeTs) {
this.source = source;
this.target = target;
this.id = JSON.stringify([this.source.id, this.target.id]);
}
};
interface NodeSummary {
name: string,
nodeCount: number,
edgeCount: number,
inputs: string[],
outputs: string[],
attributes: string,
op: string
};
export class Graph {
root: NodeTs;
nodes: NodeTs[];
edges: Edge[];
defaultExpandSet: Set<string>;
private id2idx: Map<string, number>;
private edgeId2idx: Map<string, number>;
private forwardGraph: Map<string, string[]>;
private backwardGraph: Map<string, string[]>;
private node2edge: Map<string, Edge[]>;
private mutableEdges: Map<string, Edge[][]>;
private build() {
this.id2idx.clear();
this.nodes.forEach((node, i) => {
this.id2idx.set(node.id, i);
});
this.edgeId2idx.clear();
this.edges.forEach((edge, i) => {
this.edgeId2idx.set(edge.id, i);
});
this.forwardGraph.clear();
this.backwardGraph.clear();
this.node2edge.clear();
this.edges.forEach(edge => {
if (!this.forwardGraph.has(edge.source.id))
this.forwardGraph.set(edge.source.id, []);
this.forwardGraph.get(edge.source.id)!.push(edge.target.id);
if (!this.backwardGraph.has(edge.target.id))
this.backwardGraph.set(edge.target.id, []);
this.backwardGraph.get(edge.target.id)!.push(edge.source.id);
if (!this.node2edge.has(edge.source.id))
this.node2edge.set(edge.source.id, []);
if (!this.node2edge.has(edge.target.id))
this.node2edge.set(edge.target.id, []);
this.node2edge.get(edge.source.id)!.push(edge);
this.node2edge.get(edge.target.id)!.push(edge);
});
this.root.children = this.nodes.filter(node => node.parent === undefined);
// won't set parent for these nodes, leave them as undefined
this.nodes.forEach(node => {
node.children = node.children.filter(child => this.getNodeById(child.id) !== undefined);
});
}
getNodeById(id: string): NodeTs | undefined {
const idx = this.id2idx.get(id);
if (idx === undefined) return undefined;
return this.nodes[idx];
}
getEdgeById(source: string, target: string): Edge | undefined {
const idx = this.edgeId2idx.get(JSON.stringify([source, target]));
if (idx === undefined) return undefined;
return this.edges[idx];
}
constructor(graphData: any, eliminateSidechains: boolean) {
this.id2idx = new Map<string, number>();
this.edgeId2idx = new Map<string, number>();
this.forwardGraph = new Map<string, string[]>();
this.backwardGraph = new Map<string, string[]>();
this.node2edge = new Map<string, Edge[]>();
this.root = new NodeTs('', '', '', '');
const cluster = new Map<string, NodeTs>();
const parentMap = new Map<string, string>();
this.nodes = graphData.node.map((node: any): NodeTs => {
const split = node.name.split('/');
const attr = node.hasOwnProperty('attr') ? atob(node.attr.attr.s) : '';
if (split.length === 1) {
return new NodeTs(node.name, node.name, opName(node.op), attr);
} else {
parentMap.set(node.name, split.slice(0, -1).join('/'));
// create clusters
for (let i = 1; i < split.length; ++i) {
const name = split.slice(0, i).join('/');
if (!cluster.has(name)) {
const parent = i > 1 ? split.slice(0, i - 1).join('/') : '';
const tail = split[i - 1];
cluster.set(name, new NodeTs(name, tail, '', ''));
parentMap.set(name, parent);
}
}
return new NodeTs(node.name, split.slice(-1)[0], opName(node.op), attr);
}
});
cluster.forEach(node => this.nodes.push(node));
this.nodes.forEach((node, i) => {
this.id2idx.set(node.id, i);
});
parentMap.forEach((parent, child) => {
const [childNode, parentNode] = [child, parent].map(this.getNodeById.bind(this));
if (childNode !== undefined && parentNode !== undefined) {
childNode.parent = parentNode;
parentNode.children.push(childNode);
}
});
// build edges
this.edges = [];
graphData.node.forEach((node: any) => {
if (!node.hasOwnProperty('input')) return;
const target = this.getNodeById(node.name);
if (target === undefined) return;
node.input.forEach((input: string) => {
const source = this.getNodeById(input);
if (source !== undefined) {
this.edges.push(new Edge(source, target));
}
})
})
this.build();
if (eliminateSidechains) {
this.eliminateSidechains();
}
this.defaultExpandSet = this.getDefaultExpandSet(graphData.mutable);
this.mutableEdges = this.inferMutableEdges(graphData.mutable);
}
private eliminateSidechains(): void {
const sources = this.nodes
.map(node => node.id)
.filter(id => id.startsWith('input'));
const visitedNodes = new Set(sources);
const dfsStack = sources;
while (dfsStack.length > 0) {
const u = dfsStack.pop()!;
if (this.forwardGraph.has(u)) {
this.forwardGraph.get(u)!.forEach((v: string) => {
if (!visitedNodes.has(v)) {
visitedNodes.add(v);
dfsStack.push(v);
}
});
}
}
const compoundCheck = (node: NodeTs) => {
if (node.isChildless())
return visitedNodes.has(node.id);
for (const child of node.children)
if (compoundCheck(child))
visitedNodes.add(node.id);
return visitedNodes.has(node.id);
}
compoundCheck(this.root);
this.nodes = this.nodes.filter(node => visitedNodes.has(node.id));
this.edges = this.edges.filter(edge =>
visitedNodes.has(edge.source.id) && visitedNodes.has(edge.target.id));
this.build();
}
private getDefaultExpandSet(graphDataMutable: any): Set<string> {
// if multiple, only expand first
const whitelistModuleList = Object.values(graphDataMutable)
.filter(Boolean)
.map((paths: any) => path2module(paths[0]));
const whitelistModule = new Set(whitelistModuleList);
const result = new Set<string>();
const dfs = (node: NodeTs): number => {
// node with mutableCount greater than 0 won't be collapsed
let mutableCount = 0;
if (node.id === '') {
// root node
mutableCount++;
} else if (whitelistModule.has(node.id)) {
mutableCount++;
} else if (node.parent !== undefined && whitelistModule.has(node.parent.id)) {
mutableCount++;
}
mutableCount += node.children.map(child => dfs(child)).reduce((a, b) => a + b, 0);
if (mutableCount > 0 && node.isParent())
result.add(node.id);
return mutableCount;
};
dfs(this.root);
return result;
}
private inferMutableModule(moduleName: string): Edge[][] {
let inputs: string[] | undefined = undefined;
let listConstructNode: string | undefined = undefined;
const moduleNode = this.getNodeById(moduleName);
if (moduleNode === undefined) return [];
for (const node of moduleNode.children)
if (node.op === 'ListConstruct') {
inputs = this.backwardGraph.get(node.id);
listConstructNode = node.id;
break;
}
if (inputs === undefined || listConstructNode === undefined)
return [];
return inputs.map((input: string): Edge[] => {
const visitedNodes = new Set<string>();
const edgeSet: Edge[] = [];
const dfs = (node: string, backward: boolean) => {
const nodeData = this.getNodeById(node)!;
if (visitedNodes.has(node)) return;
visitedNodes.add(node);
if (nodeData.parent === undefined || !nodeData.parent.id.startsWith(moduleName)) {
// in another module now
return;
}
const g = backward ? this.backwardGraph : this.forwardGraph;
const glist = g.get(node);
if (glist !== undefined) {
glist
.forEach((to: string) => {
edgeSet.push(backward ?
this.getEdgeById(to, node)! :
this.getEdgeById(node, to)!);
dfs(to, backward);
});
}
};
edgeSet.push(this.getEdgeById(input, listConstructNode!)!);
dfs(input, true);
visitedNodes.clear();
dfs(listConstructNode!, false);
return edgeSet;
});
}
private inferMutableEdges(graphDataMutable: any): Map<string, Edge[][]> {
const result = new Map<string, Edge[][]>();
Object.entries(graphDataMutable).forEach(obj => {
const [key, paths] = obj;
const modules = (paths as any[]).map(path2module);
const moduleEdge = modules.map(this.inferMutableModule.bind(this));
const edges: Edge[][] = [];
for (let i = 0; ; ++i) {
if (moduleEdge.filter(me => i < me.length).length === 0)
break;
edges.push([]);
moduleEdge
.filter(me => i < me.length)
.forEach(me => edges[i].push(...me[i]));
}
result.set(key, edges);
});
return result;
}
private connectedEdges(node: string): Edge[] {
const result = this.node2edge.get(node);
if (result === undefined)
return [];
return result;
}
nodeSummary(node: NodeTs | string): NodeSummary | undefined {
if (typeof node === 'string') {
const nodeData = this.getNodeById(node);
if (nodeData === undefined) return undefined;
return this.nodeSummary(nodeData);
}
const descendants = node.descendants(false);
const descendantSet = new Set(descendants.map(node => node.id));
const inputs = new Set<string>();
const outputs = new Set<string>();
let domesticEdges = 0;
for (const edge of this.edges) {
const [source, target] = [edge.source.id, edge.target.id];
if (descendantSet.has(target) && !descendantSet.has(source))
inputs.add(source);
if (descendantSet.has(source) && !descendantSet.has(target))
outputs.add(target);
if (descendantSet.has(source) && descendantSet.has(target))
domesticEdges++;
}
return {
name: node.id,
nodeCount: descendants.length,
edgeCount: domesticEdges,
inputs: Array.from(inputs),
outputs: Array.from(outputs),
attributes: node.attributes,
op: node.op
}
}
weightFromMutables(mutable: any): Map<string, number> {
const elemWeight = new Map<string, number>();
Object.entries(mutable).forEach(entry => {
const key = entry[0];
const weights = entry[1] as number[];
this.mutableEdges.get(key)!.forEach((edges: any, i: number) => {
edges.forEach((edge: any) => {
if (elemWeight.has(edge.id)) {
elemWeight.set(edge.id, elemWeight.get(edge.id)! + weights[i]);
} else {
elemWeight.set(edge.id, weights[i]);
}
})
});
});
this.nodes.forEach(node => {
const edges = this.connectedEdges(node.id);
const relatedEdges = edges.filter(edge => elemWeight.has(edge.id));
if (relatedEdges.length > 0) {
if (relatedEdges.length < edges.length) {
elemWeight.set(node.id, 1.);
} else {
// all related edge
const nw = edges.map(edge => elemWeight.get(edge.id)!)
.reduce((a, b) => Math.max(a, b));
elemWeight.set(node.id, nw);
}
}
});
elemWeight.forEach((v, k) => elemWeight.set(k, Math.min(v, 1.)));
// set compound weight
const gatherWeightsFromChildren = (node: NodeTs): number | undefined => {
if (node.isParent()) {
const childrenWeights =
node.children.map(gatherWeightsFromChildren)
.filter(val => val !== undefined);
if (childrenWeights.length > 0) {
const nw = childrenWeights.reduce((a, b) => Math.max(a!, b!));
elemWeight.set(node.id, nw!);
return nw;
} else {
return undefined;
}
} else {
return elemWeight.get(node.id);
}
};
gatherWeightsFromChildren(this.root);
return elemWeight;
}
};
body {
margin: 0;
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', 'Roboto', 'Oxygen',
'Ubuntu', 'Cantarell', 'Fira Sans', 'Droid Sans', 'Helvetica Neue',
sans-serif;
-webkit-font-smoothing: antialiased;
-moz-osx-font-smoothing: grayscale;
}
code {
font-family: source-code-pro, Menlo, Monaco, Consolas, 'Courier New',
monospace;
}
import React from 'react';
import ReactDOM from 'react-dom';
import './index.css';
import App from './App';
import * as serviceWorker from './serviceWorker';
ReactDOM.render(<App />, document.getElementById('root'));
// If you want your app to work offline and load faster, you can change
// unregister() to register() below. Note this comes with some pitfalls.
// Learn more about service workers: https://bit.ly/CRA-PWA
serviceWorker.unregister();
/// <reference types="react-scripts" />
// This optional code is used to register a service worker.
// register() is not called by default.
// This lets the app load faster on subsequent visits in production, and gives
// it offline capabilities. However, it also means that developers (and users)
// will only see deployed updates on subsequent visits to a page, after all the
// existing tabs open on the page have been closed, since previously cached
// resources are updated in the background.
// To learn more about the benefits of this model and instructions on how to
// opt-in, read https://bit.ly/CRA-PWA
const isLocalhost = Boolean(
window.location.hostname === 'localhost' ||
// [::1] is the IPv6 localhost address.
window.location.hostname === '[::1]' ||
// 127.0.0.0/8 are considered localhost for IPv4.
window.location.hostname.match(
/^127(?:\.(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)){3}$/
)
);
type Config = {
onSuccess?: (registration: ServiceWorkerRegistration) => void;
onUpdate?: (registration: ServiceWorkerRegistration) => void;
};
export function register(config?: Config) {
if (process.env.NODE_ENV === 'production' && 'serviceWorker' in navigator) {
// The URL constructor is available in all browsers that support SW.
const publicUrl = new URL(
process.env.PUBLIC_URL,
window.location.href
);
if (publicUrl.origin !== window.location.origin) {
// Our service worker won't work if PUBLIC_URL is on a different origin
// from what our page is served on. This might happen if a CDN is used to
// serve assets; see https://github.com/facebook/create-react-app/issues/2374
return;
}
window.addEventListener('load', () => {
const swUrl = `${process.env.PUBLIC_URL}/service-worker.js`;
if (isLocalhost) {
// This is running on localhost. Let's check if a service worker still exists or not.
checkValidServiceWorker(swUrl, config);
// Add some additional logging to localhost, pointing developers to the
// service worker/PWA documentation.
navigator.serviceWorker.ready.then(() => {
console.log(
'This web app is being served cache-first by a service ' +
'worker. To learn more, visit https://bit.ly/CRA-PWA'
);
});
} else {
// Is not localhost. Just register service worker
registerValidSW(swUrl, config);
}
});
}
}
function registerValidSW(swUrl: string, config?: Config) {
navigator.serviceWorker
.register(swUrl)
.then(registration => {
registration.onupdatefound = () => {
const installingWorker = registration.installing;
if (installingWorker == null) {
return;
}
installingWorker.onstatechange = () => {
if (installingWorker.state === 'installed') {
if (navigator.serviceWorker.controller) {
// At this point, the updated precached content has been fetched,
// but the previous service worker will still serve the older
// content until all client tabs are closed.
console.log(
'New content is available and will be used when all ' +
'tabs for this page are closed. See https://bit.ly/CRA-PWA.'
);
// Execute callback
if (config && config.onUpdate) {
config.onUpdate(registration);
}
} else {
// At this point, everything has been precached.
// It's the perfect time to display a
// "Content is cached for offline use." message.
console.log('Content is cached for offline use.');
// Execute callback
if (config && config.onSuccess) {
config.onSuccess(registration);
}
}
}
};
};
})
.catch(error => {
console.error('Error during service worker registration:', error);
});
}
function checkValidServiceWorker(swUrl: string, config?: Config) {
// Check if the service worker can be found. If it can't reload the page.
fetch(swUrl, {
headers: { 'Service-Worker': 'script' }
})
.then(response => {
// Ensure service worker exists, and that we really are getting a JS file.
const contentType = response.headers.get('content-type');
if (
response.status === 404 ||
(contentType != null && contentType.indexOf('javascript') === -1)
) {
// No service worker found. Probably a different app. Reload the page.
navigator.serviceWorker.ready.then(registration => {
registration.unregister().then(() => {
window.location.reload();
});
});
} else {
// Service worker found. Proceed as normal.
registerValidSW(swUrl, config);
}
})
.catch(() => {
console.log(
'No internet connection found. App is running in offline mode.'
);
});
}
export function unregister() {
if ('serviceWorker' in navigator) {
navigator.serviceWorker.ready
.then(registration => {
registration.unregister();
})
.catch(error => {
console.error(error.message);
});
}
}
{
"compilerOptions": {
"target": "es5",
"lib": [
"dom",
"dom.iterable",
"esnext"
],
"typeRoots": [
"./types",
"./node_modules/@types"
],
"allowJs": true,
"skipLibCheck": true,
"esModuleInterop": true,
"allowSyntheticDefaultImports": true,
"strict": true,
"forceConsistentCasingInFileNames": true,
"module": "esnext",
"moduleResolution": "node",
"resolveJsonModule": true,
"isolatedModules": true,
"noEmit": true,
"jsx": "react"
},
"include": [
"src"
]
}
declare module 'cytoscape-dagre';
This diff is collapsed.
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# pylint: skip-file
# This file is copied from PyTorch 1.4, with bug fixes.
# Likely to be removed in future.
import torch
from tensorboard.compat.proto.config_pb2 import RunMetadata
from tensorboard.compat.proto.graph_pb2 import GraphDef
from tensorboard.compat.proto.step_stats_pb2 import StepStats, DeviceStepStats
from tensorboard.compat.proto.versions_pb2 import VersionDef
from torch.utils.tensorboard._pytorch_graph import GraphPy, CLASSTYPE_KIND, GETATTR_KIND, NodePyIO, NodePyOP
def parse(graph, trace, args=None, omit_useless_nodes=True):
"""This method parses an optimized PyTorch model graph and produces
a list of nodes and node stats for eventual conversion to TensorBoard
protobuf format.
Args:
graph (PyTorch module): The model graph to be parsed.
trace (PyTorch JIT TracedModule): The model trace to be parsed.
args (tuple): input tensor[s] for the model.
omit_useless_nodes (boolean): Whether to remove nodes from the graph.
"""
n_inputs = len(args)
scope = {}
nodes_py = GraphPy()
for node in graph.inputs():
if omit_useless_nodes:
if len(node.uses()) == 0: # number of user of the node (= number of outputs/ fanout)
continue
if node.type().kind() != CLASSTYPE_KIND:
nodes_py.append(NodePyIO(node, 'input'))
attr_to_scope = dict()
node_to_name = lambda d: str(d).split(":")[0].strip()
for node in graph.nodes():
if node.kind() == GETATTR_KIND:
attr_name = node.s('name')
node_name = node_to_name(node)
parent = node.input().node()
if parent.kind() == GETATTR_KIND: # If the parent node is not the top-level "self" node
parent_attr_name = parent.s('name')
parent_scope = attr_to_scope[node_to_name(parent)]
attr_scope = parent_scope.split('/')[-1]
attr_to_scope[node_name] = '{}/{}.{}'.format(parent_scope, attr_scope, attr_name)
else:
attr_to_scope[node_name] = '__module.{}'.format(attr_name)
# We don't need classtype nodes; scope will provide this information
if node.output().type().kind() != CLASSTYPE_KIND:
node_py = NodePyOP(node)
node_py.scopeName = attr_to_scope[node_name]
nodes_py.append(node_py)
else:
nodes_py.append(NodePyOP(node))
for i, node in enumerate(graph.outputs()): # Create sink nodes for output ops
node_py = NodePyIO(node, 'output')
node_py.debugName = "output.{}".format(i + 1)
node_py.inputs = [node.debugName()]
nodes_py.append(node_py)
def parse_traced_name(module_name):
prefix = 'TracedModule['
suffix = ']'
if module_name.startswith(prefix) and module_name.endswith(suffix):
module_name = module_name[len(prefix):-len(suffix)]
return module_name
alias_to_name = dict()
base_name = parse_traced_name(trace._name)
for name, module in trace.named_modules(prefix='__module'):
mod_name = parse_traced_name(module._name)
attr_name = name.split('.')[-1]
alias_to_name[name] = '{}[{}]'.format(mod_name, attr_name)
for node in nodes_py.nodes_op:
module_aliases = node.scopeName.split('/')[-1].split('.')
module_name = ''
for i, alias in enumerate(module_aliases):
if i == 0:
module_name = alias
node.scopeName = base_name
else:
module_name += '.' + alias
node.scopeName += '/' + (alias_to_name[module_name] if module_name in alias_to_name else alias)
nodes_py.populate_namespace_from_OP_to_IO()
return nodes_py.to_proto()
def graph(model, args, verbose=False):
"""
This method processes a PyTorch model and produces a `GraphDef` proto
that can be logged to TensorBoard.
Args:
model (PyTorch module): The model to be parsed.
args (tuple): input tensor[s] for the model.
verbose (bool): Whether to print out verbose information while
processing.
"""
with torch.onnx.set_training(model, False): # TODO: move outside of torch.onnx?
try:
trace = torch.jit.trace(model, args)
graph = trace.graph
torch._C._jit_pass_inline(graph)
except RuntimeError as e:
print(e)
print('Error occurs, No graph saved')
raise e
if verbose:
print(graph)
list_of_nodes = parse(graph, trace, args)
# We are hardcoding that this was run on CPU even though it might have actually
# run on GPU. Note this is what is shown in TensorBoard and has no bearing
# on actual execution.
# TODO: See if we can extract GPU vs CPU information from the PyTorch model
# and pass it correctly to TensorBoard.
#
# Definition of StepStats and DeviceStepStats can be found at
# https://github.com/tensorflow/tensorboard/blob/master/tensorboard/plugins/graph/tf_graph_common/test/graph-test.ts
# and
# https://github.com/tensorflow/tensorboard/blob/master/tensorboard/compat/proto/step_stats.proto
stepstats = RunMetadata(step_stats=StepStats(dev_stats=[DeviceStepStats(device="/device:CPU:0")]))
return GraphDef(node=list_of_nodes, versions=VersionDef(producer=22)), stepstats
# The producer version has been reverse engineered from standard
# TensorBoard logged data.
......@@ -2,7 +2,9 @@
# Licensed under the MIT license.
import logging
from collections import defaultdict
import numpy as np
import torch
from nni.nas.pytorch.base_mutator import BaseMutator
......@@ -15,6 +17,7 @@ class Mutator(BaseMutator):
def __init__(self, model):
super().__init__(model)
self._cache = dict()
self._connect_all = False
def sample_search(self):
"""
......@@ -57,6 +60,74 @@ class Mutator(BaseMutator):
"""
return self.sample_final()
def status(self):
"""
Return current selection status of mutator.
Returns
-------
dict
A mapping from key of mutables to decisions. All weights (boolean type and float type)
are converted into real number values. Numpy arrays and tensors are converted into list.
"""
data = dict()
for k, v in self._cache.items():
if torch.is_tensor(v):
v = v.detach().cpu().numpy()
if isinstance(v, np.ndarray):
v = v.astype(np.float32).tolist()
data[k] = v
return data
def graph(self, inputs):
"""
Return model supernet graph.
Parameters
----------
inputs: tuple of tensor
Inputs that will be feeded into the network.
Returns
-------
dict
Containing ``node``, in Tensorboard GraphDef format.
Additional key ``mutable`` is a map from key to list of modules.
"""
if not torch.__version__.startswith("1.4"):
logger.warning("Graph is only tested with PyTorch 1.4. Other versions might not work.")
from ._graph_utils import graph
from google.protobuf import json_format
# protobuf should be installed as long as tensorboard is installed
try:
self._connect_all = True
graph_def, _ = graph(self.model, inputs, verbose=False)
result = json_format.MessageToDict(graph_def)
finally:
self._connect_all = False
# `mutable` is to map the keys to a list of corresponding modules.
# A key can be linked to multiple modules, use `dedup=False` to find them all.
result["mutable"] = defaultdict(list)
for mutable in self.mutables.traverse(deduplicate=False):
# A module will be represent in the format of
# [{"type": "Net", "name": ""}, {"type": "Cell", "name": "cell1"}, {"type": "Conv2d": "name": "conv"}]
# which will be concatenated into Net/Cell[cell1]/Conv2d[conv] in frontend.
# This format is aligned with the scope name jit gives.
modules = mutable.name.split(".")
path = [
{"type": self.model.__class__.__name__, "name": ""}
]
m = self.model
for module in modules:
m = getattr(m, module)
path.append({
"type": m.__class__.__name__,
"name": module
})
result["mutable"][mutable.key].append(path)
return result
def on_forward_layer_choice(self, mutable, *inputs):
"""
On default, this method retrieves the decision obtained previously, and select certain operations.
......@@ -75,6 +146,11 @@ class Mutator(BaseMutator):
tuple of torch.Tensor and torch.Tensor
Output and mask.
"""
if self._connect_all:
return self._all_connect_tensor_reduction(mutable.reduction,
[op(*inputs) for op in mutable.choices]), \
torch.ones(mutable.length)
def _map_fn(op, *inputs):
return op(*inputs)
......@@ -101,6 +177,9 @@ class Mutator(BaseMutator):
tuple of torch.Tensor and torch.Tensor
Output and mask.
"""
if self._connect_all:
return self._all_connect_tensor_reduction(mutable.reduction, tensor_list), \
torch.ones(mutable.n_candidates)
mask = self._get_decision(mutable)
assert len(mask) == mutable.n_candidates, \
"Invalid mask, expected {} to be of length {}.".format(mask, mutable.n_candidates)
......@@ -131,6 +210,13 @@ class Mutator(BaseMutator):
return torch.cat(tensor_list, dim=1)
raise ValueError("Unrecognized reduction policy: \"{}\"".format(reduction_type))
def _all_connect_tensor_reduction(self, reduction_type, tensor_list):
if reduction_type == "none":
return tensor_list
if reduction_type == "concat":
return torch.cat(tensor_list, dim=1)
return torch.stack(tensor_list).sum(0)
def _get_decision(self, mutable):
"""
By default, this method checks whether `mutable.key` is already in the decision cache,
......
......@@ -11,7 +11,7 @@ from .updater import update_searchspace, update_concurrency, update_duration, up
from .nnictl_utils import stop_experiment, trial_ls, trial_kill, list_experiment, experiment_status,\
log_trial, experiment_clean, platform_clean, experiment_list, \
monitor_experiment, export_trials_data, trial_codegen, webui_url, \
get_config, log_stdout, log_stderr, search_space_auto_gen
get_config, log_stdout, log_stderr, search_space_auto_gen, webui_nas
from .package_management import package_install, package_show
from .constants import DEFAULT_REST_PORT
from .tensorboard_utils import start_tensorboard, stop_tensorboard
......@@ -158,6 +158,10 @@ def parse_args():
parser_webui_url = parser_webui_subparsers.add_parser('url', help='show the url of web ui')
parser_webui_url.add_argument('id', nargs='?', help='the id of experiment')
parser_webui_url.set_defaults(func=webui_url)
parser_webui_nas = parser_webui_subparsers.add_parser('nas', help='show nas ui')
parser_webui_nas.add_argument('--port', default=6060, type=int, help='port of nas ui')
parser_webui_nas.add_argument('--logdir', default='.', type=str, help='the logdir where nas ui will read data')
parser_webui_nas.set_defaults(func=webui_nas)
#parse config command
parser_config = subparsers.add_parser('config', help='get config information')
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment