Unverified Commit 6568eaee authored by SparkSnail's avatar SparkSnail Committed by GitHub
Browse files

Merge pull request #247 from microsoft/master

merge master
parents d90433da 1e2a2e29
######################
教程
######################
.. toctree::
:maxdepth: 2
安装<Tutorial/Installation>
实现 Trial<./TrialExample/Trials>
Tuner<tuners>
Assessor<assessors>
NAS (Beta) <nas>
模型压缩 (Beta) <model_compression>
特征工程 (Beta) <feature_engineering>
Web 界面<Tutorial/WebUI>
训练平台<training_services>
如何使用 Docker<Tutorial/HowToUseDocker>
高级功能<advanced>
如何调试<Tutorial/HowToDebug>
Windows 中使用 NNI<Tutorial/NniOnWindows>
\ No newline at end of file
# 加速掩码的模型
*此功能还处于预览版。*
## 介绍
剪枝算法通常都用权重掩码来模拟实际的剪枝。 掩码可以用来检查某个剪枝(或稀疏)算法的模型性能,但还没有真正加速。 模型加速才是模型剪枝的最终目标。因此提供了此工具,来帮助基于用户提供的掩码(掩码来自于剪枝算法),将已有模型转换成小模型。
有两种剪枝算法。 一种是细粒度的剪枝,不改变权重形状,和输入输出的张量。 稀疏内核会被用来加速细粒度剪枝的层。 另一类是粗粒度的剪枝(例如,通道),通常,权重形状,输入输出张量会有所改变。 要加速这类剪枝算法,不需要使用系数内核,只需要用更小的层来替换。 由于开源社区中对稀疏内核的支持还比较有限,当前仅支持粗粒度剪枝,会在将来再支持细粒度的剪枝算法。
## 设计和实现
为了加速模型,被剪枝的层应该被替换掉,要么为粗粒度掩码使用较小的层,要么用稀疏内核来替换细粒度的掩码。 粗粒度掩码通常会改变权重的形状,或输入输出张量,因此,应该通过形状推断,来检查是否其它未被剪枝的层由于形状变化而需要改变形状。 因此,在设计中,主要有两个步骤:第一,做形状推理,找出所有应该替换的模块;第二,替换模块。 第一步需要模型的拓扑(即连接),我们使用了 `jit.trace` 来获取 PyTorch 的模型图。
对于每个模块,要准备四个函数,三个用于形状推理,一个用于模块替换。 三个形状推理函数是:给定权重形状推断输入/输出形状,给定输入形状推断权重/输出形状,给定输出形状推断权重/输入形状。 模块替换功能返回一个较小的新创建的模块。
## 用法
```python
from nni.compression.speedup.torch import ModelSpeedup
# model: 要加速的模型
# dummy_input: 模型的示输入,传给 `jit.trace`
# masks_file: 剪枝算法创建的掩码文件
m_speedup = ModelSpeedup(model, dummy_input.to(device), masks_file)
m_speedup.speedup_model()
dummy_input = dummy_input.to(device)
start = time.time()
out = model(dummy_input)
print('elapsed time: ', time.time() - start)
```
完整示例参考[这里](https://github.com/microsoft/nni/tree/master/examples/model_compress/model_speedup.py)
注意:当前实现仅用于 torch 1.3.1 和 torchvision 0.4.2
## 局限性
由于每个模块需要 4 个函数用于形状推理和模块替换,因此工作量较大,当前仅实现了示例所需的函数。 如果要加速自己的模型,但当前不支持,欢迎贡献。
对于 PyTorch,仅提供了替换模块,如果是在 `forward` 中的函数,当前不支持。 一种解决方案是将函数变为 PyTorch 模块。
## 示例的加速结果
实验代码可在[这里](https://github.com/microsoft/nni/tree/master/examples/model_compress/model_speedup.py)找到。
### slim Pruner 示例
在一块 V100 GPU 上, 输入张量:`torch.randn(64, 3, 32, 32)`
| 次数 | 掩码时延 | 加速后的时延 |
| -- | ------- | -------- |
| 1 | 0.01197 | 0.005107 |
| 2 | 0.02019 | 0.008769 |
| 4 | 0.02733 | 0.014809 |
| 8 | 0.04310 | 0.027441 |
| 16 | 0.07731 | 0.05008 |
| 32 | 0.14464 | 0.10027 |
### fpgm Pruner 示例
在 CPU 上, 输入张量:`torch.randn(64, 1, 28, 28)`, 方差较大
| 次数 | 掩码时延 | 加速后的时延 |
| --- | ------- | -------- |
| 1 | 0.01383 | 0.01839 |
| 2 | 0.01167 | 0.003558 |
| 4 | 0.01636 | 0.01088 |
| 40 | 0.14412 | 0.08268 |
| 40 | 1.29385 | 0.14408 |
| 40 | 0.41035 | 0.46162 |
| 400 | 6.29020 | 5.82143 |
### l1filter Pruner 示例
在一块 V100 GPU 上, 输入张量:`torch.randn(64, 3, 32, 32)`
| 次数 | 掩码时延 | 加速后的时延 |
| -- | ------- | -------- |
| 1 | 0.01026 | 0.003677 |
| 2 | 0.01657 | 0.008161 |
| 4 | 0.02458 | 0.020018 |
| 8 | 0.03498 | 0.025504 |
| 16 | 0.06757 | 0.047523 |
| 32 | 0.10487 | 0.086442 |
### APoZ Pruner 示例
在一块 V100 GPU 上, 输入张量:`torch.randn(64, 3, 32, 32)`
| 次数 | 掩码时延 | 加速后的时延 |
| -- | ------- | -------- |
| 1 | 0.01389 | 0.004208 |
| 2 | 0.01628 | 0.008310 |
| 4 | 0.02521 | 0.014008 |
| 8 | 0.03386 | 0.023923 |
| 16 | 0.06042 | 0.046183 |
| 32 | 0.12421 | 0.087113 |
\ No newline at end of file
...@@ -334,9 +334,11 @@ class LocalTrainingService implements TrainingService { ...@@ -334,9 +334,11 @@ class LocalTrainingService implements TrainingService {
throw new Error(`Could not find stream in trial ${trialJob.id}`); throw new Error(`Could not find stream in trial ${trialJob.id}`);
} }
//Refer https://github.com/Juul/tail-stream/issues/20 //Refer https://github.com/Juul/tail-stream/issues/20
stream.end(0); setTimeout(() => {
stream.emit('end'); stream.end(0);
this.jobStreamMap.delete(trialJob.id); stream.emit('end');
this.jobStreamMap.delete(trialJob.id);
}, 5000);
} }
} }
if (trialJob.gpuIndices !== undefined && trialJob.gpuIndices.length > 0 && this.gpuScheduler !== undefined) { if (trialJob.gpuIndices !== undefined && trialJob.gpuIndices.length > 0 && this.gpuScheduler !== undefined) {
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import logging
import queue
import re
from collections import defaultdict
import torch
from torch.utils.tensorboard._pytorch_graph import NodePy, NodePyIO, NodePyOP, GraphPy
CLASSTYPE_KIND = 'ClassType'
GETATTR_KIND = 'prim::GetAttr'
_logger = logging.getLogger(__name__)
def build_module_graph(model, dummy_input):
return TorchModuleGraph(model, dummy_input)
def build_graph(model, dummy_input, verbose=False):
g = TorchProtoGraph(model, dummy_input, verbose)
return g.graph_def, g.stepstats
def parse_traced_name(module_name):
prefix = 'TracedModule['
suffix = ']'
if module_name.startswith(prefix) and module_name.endswith(suffix):
module_name = module_name[len(prefix):-len(suffix)]
return module_name
class TorchGraph:
"""
This class is to extract pytorch model topology graph by tracing
"""
def __init__(self, model, dummy_input):
"""
Parameters
----------
model : pytorch model
The model user wants to speed up
dummy_input : pytorch tensor
The dummy input for ```jit.trace```, users should put it on right device before pass in
"""
assert torch.__version__ >= '1.3.1'
self.bound_model = model
self._trace(model, dummy_input)
def _trace(self, model, dummy_input):
with torch.onnx.set_training(model, False):
self.trace = torch.jit.trace(model, dummy_input)
torch._C._jit_pass_inline(self.trace.graph)
class TorchProtoGraph(TorchGraph):
"""
Generates model graph for pytorch models in protobuf, this implementation is borrowed from pytorch v1.4.0,
and fixed following issues:
https://github.com/pytorch/pytorch/issues/33691
https://github.com/pytorch/pytorch/issues/33670
"""
def __init__(self, model, dummy_input, verbose=False):
super().__init__(model, dummy_input)
from tensorboard.compat.proto.config_pb2 import RunMetadata
from tensorboard.compat.proto.graph_pb2 import GraphDef
from tensorboard.compat.proto.step_stats_pb2 import StepStats, DeviceStepStats
from tensorboard.compat.proto.versions_pb2 import VersionDef
list_of_nodes = self.parse(self.trace.graph, self.trace, dummy_input)
if verbose:
print(self.trace.graph)
self.stepstats = RunMetadata(step_stats=StepStats(dev_stats=[DeviceStepStats(device="/device:CPU:0")]))
self.graph_def = GraphDef(node=list_of_nodes, versions=VersionDef(producer=22))
def parse(self, graph, trace, args=None, omit_useless_nodes=True):
"""This method parses an optimized PyTorch model graph and produces
a list of nodes and node stats for eventual conversion to TensorBoard
protobuf format.
Args:
graph (PyTorch module): The model graph to be parsed.
trace (PyTorch JIT TracedModule): The model trace to be parsed.
args (tuple): input tensor[s] for the model.
omit_useless_nodes (boolean): Whether to remove nodes from the graph.
"""
nodes_py = GraphPy()
for node in graph.inputs():
if omit_useless_nodes:
if not node.uses(): # number of user of the node (= number of outputs/ fanout)
continue
if node.type().kind() != CLASSTYPE_KIND:
nodes_py.append(NodePyIO(node, 'input'))
attr_to_scope = dict()
node_to_name = lambda d: str(d).split(":")[0].strip()
for node in graph.nodes():
if node.kind() == GETATTR_KIND:
attr_name = node.s('name')
node_name = node_to_name(node)
parent = node.input().node()
if parent.kind() == GETATTR_KIND: # If the parent node is not the top-level "self" node
parent_scope = attr_to_scope[node_to_name(parent)]
attr_scope = parent_scope.split('/')[-1]
attr_to_scope[node_name] = '{}/{}.{}'.format(parent_scope, attr_scope, attr_name)
else:
attr_to_scope[node_name] = '__module.{}'.format(attr_name)
# We don't need classtype nodes; scope will provide this information
if node.output().type().kind() != CLASSTYPE_KIND:
node_py = NodePyOP(node)
node_py.scopeName = attr_to_scope[node_name]
nodes_py.append(node_py)
else:
nodes_py.append(NodePyOP(node))
for i, node in enumerate(graph.outputs()): # Create sink nodes for output ops
node_py = NodePyIO(node, 'output')
node_py.debugName = "output.{}".format(i + 1)
node_py.inputs = [node.debugName()]
nodes_py.append(node_py)
alias_to_name = dict()
base_name = parse_traced_name(trace._name)
for name, module in trace.named_modules(prefix='__module'):
mod_name = parse_traced_name(module._name)
attr_name = name.split('.')[-1]
alias_to_name[name] = '{}[{}]'.format(mod_name, attr_name)
for node in nodes_py.nodes_op:
module_aliases = node.scopeName.split('/')[-1].split('.')
module_name = ''
for i, alias in enumerate(module_aliases):
if i == 0:
module_name = alias
node.scopeName = base_name
else:
module_name += '.' + alias
node.scopeName += '/' + (alias_to_name[module_name] if module_name in alias_to_name else alias)
nodes_py.populate_namespace_from_OP_to_IO()
return nodes_py.to_proto()
class NodePyGroup(NodePy):
"""
This class is used to represent a graph node which consists of multiple jit traced nodes. In a pytorch trace graph,
there are multiple nodes are traced for one torch.nn.Module object, we group them together to form a single node to
represent the torch.nn.Module object. We also group some functional call trace nodes together to form a new node.
"""
def __init__(self, name, node_type, op_type, node_cpps, inputs=None, outputs=None):
"""
Parameters:
-----------
name: str
node name, such as `conv1`, `backbone.classifier`
node_type: str
`module` or `func`
op_type: str
operation type, such as `Conv2d`, `aten::view`
node_cpps: list of torch._C.Node
jit trace nodes which are included in this new node
inputs: list of str
All the inputs of this node, each element is debugName of one input
outputs: list of str
All the outputs of this node, each element is debugName of one output
"""
super(NodePyGroup, self).__init__(name, [])
self.node_cpps = node_cpps
self.name = name
self.op_type = op_type
self.type = node_type
self.nodes = []
self.auxiliary = None
self.add_nodes(node_cpps)
self.inputs = inputs
self.outputs = outputs
def add_nodes(self, node_cpps):
for node_cpp in node_cpps:
nodepy = NodePyOP(node_cpp)
nodepy.name = str(node_cpp).split(':')[0].strip().replace('%', '')
self.nodes.append(nodepy)
def sub_node_names(self):
return [x.name for x in self.nodes]
def __repr__(self):
return 'name: {}, type: {}, op_type: {}, sub_nodes: {}, inputs: {}, outputs: {}, aux: {}'.format(
self.name, self.type, self.op_type, self.sub_node_names(), self.inputs, self.outputs, self.auxiliary
)
class TorchModuleGraph(TorchGraph):
"""
Generates model graph, each node is created from single or multiple jit trace nodes.
"""
def __init__(self, model, dummy_input):
super().__init__(model, dummy_input)
self.global_count = 0
self.name_to_node, self.input_to_node, self.output_to_node = self._build_graph()
def _expand_non_prim_node(self, node, nodes, input_to_node, output_to_node):
"""
For trace graph nodes, some nodes are not in modules, these nodes are usually generated by
the functions directly called in module ```forward```. For such nodes, some of them are
trivial op which are label by ```prim::```, some of them are not such ops which is call
non-prim ops. This function is to merge neighbor prim ops to a non-prim op, to construct
a node.
Parameters
----------
node : trace graph node
The non-prim node to expand
nodes : list of trace graph node
All the trace graph nodes within the same scope as the non-prim node
input_to_node : dict
key: input name, value: a node that uses this input
output_to_node : dict
key: output name, value: a node that generates this output
Returns
-------
node
the expanded non-prim node
"""
# TODO: scope name could be empty
node_name = '.'.join([self._get_module_name(node.scopeName()), node.kind(), str(self.global_count)])
_logger.debug("expand non-prim node, node name: %s", node_name)
self.global_count += 1
op_type = node.kind()
node_group = [node]
inputs = list()
outputs = list()
node_queue = queue.Queue()
node_queue.put(node)
while not node_queue.empty():
curr_node = node_queue.get()
for _input in curr_node.inputs():
input_name = _input.debugName()
if input_name in output_to_node and output_to_node[input_name] in nodes:
predecessor_node = output_to_node[input_name]
if predecessor_node.kind().startswith('prim::'):
node_group.append(predecessor_node)
node_queue.put(predecessor_node)
else:
inputs.append(input_name)
else:
inputs.append(input_name)
for output in node.outputs():
outputs.append(output.debugName())
nodepy = NodePyGroup(node_name, 'func', op_type, node_group, inputs=inputs, outputs=outputs)
return nodepy
def _build_module_node_group(self, module_name, op_type, node_cpps, input_to_node, output_to_node):
graph = self.trace.graph
inputs, outputs = [], []
for n in node_cpps:
for i in n.inputs():
name = i.debugName()
if not name in output_to_node and i in graph.inputs():
inputs.append(name)
elif output_to_node[name] not in node_cpps:
inputs.append(name)
for o in n.outputs():
name = o.debugName()
if not name in input_to_node and o in graph.outputs():
outputs.append(name)
elif input_to_node[name] not in node_cpps:
outputs.append(name)
return NodePyGroup(module_name, 'module', op_type, node_cpps, inputs, outputs)
def _extract_shape_info(self, node):
"""
Extract the shape information of ```aten::view``` node
Parameters
----------
node : trace graph node
It should be ```aten::view``` node
Returns
-------
dict
Include shape of input tensor and shape of output tensor
"""
t_input = None
for _input in node.inputs():
t_input = _input
break
t_output = node.output()
assert isinstance(t_input.type(), torch._C.TensorType)
assert isinstance(t_output.type(), torch._C.TensorType)
in_shape = t_input.type().sizes()
out_shape = t_output.type().sizes()
return {'in_shape': in_shape, 'out_shape': out_shape}
def _extract_leaf_modules(self):
"""
Extract leaf modules from the given graph. Leaf module means it does not have submodules.
To extract leaf modules because only leaf module can be replaced. And shape inference can
be done in leaf module level. Other shape inference is done in lower level i.e.,
operation level.
Returns
-------
list
a list of scope name of all the leaf modules
"""
def is_parent(name1, name2):
"""
check if name1 is parent node of name2, for example:
name1: aa.bb, name2: aa.bb.cc, return True
name1: aa.b, name2: aa.bb, return False
"""
parts1, parts2 = name1.split('.'), name2.split('.')
if len(parts1) >= len(parts2):
return False
for i in range(len(parts1)):
if parts2[i] != parts1[i]:
return False
return True
module_names = sorted([x[0] for x in self.trace.named_modules() if x[0]])
leaf_nodes = []
for i, name in enumerate(module_names):
if i + 1 >= len(module_names) or not is_parent(name, module_names[i + 1]):
leaf_nodes.append(name)
return leaf_nodes
def _get_module_name(self, scope_name):
"""
Retrieve module name from scope name.
Parameters:
-----------
scope_name: str
scope_name of a graph node, for example:
for pytorch 1.3.1: MyModel/BackboneModel[backbone]/Conv2d[conv2]
for pytorch 1.4.0: __module.backbone/__module.backbone.conv2
Returns:
-------
str
module name, such as backbone.conv2
"""
if torch.__version__ >= '1.4.0':
return scope_name.split('/')[-1].replace('__module.', '')
else:
return '.'.join(re.findall(r'\[(.*?)\]', scope_name))
def _build_index(self, nodes_op):
name_to_node = dict()
input_to_node = defaultdict(list)
output_to_node = dict()
for node in nodes_op:
name_to_node[node.name] = node
for _input in node.inputs:
input_to_node[_input].append(node)
for output in node.outputs:
assert not output in output_to_node, \
"One output cannot be generated by multiple nodes"
output_to_node[output] = node
return name_to_node, input_to_node, output_to_node
def _build_graph(self):
"""
Build graph using our defined format from jit trace.
There are basically three steps: first, construct necessary information (data structures),
second, extract all the modules to convert to node, Third, extract all functions to convert
to node.
Returns
-------
dict
use name to index nodes, key: node name, value: node
dict
use input (its name) to index nodes,
key: input, value: list of nodes that take this input
dict
use output (its name) to index nodes,
key: output, value: node that generates this output
"""
omit_useless_nodes = True
graph = self.trace.graph
_logger.debug(graph)
# build output mapping, from output debugName to its node
output_to_node = {x.debugName(): n for n in graph.nodes() for x in n.outputs()}
# build input mapping, from input debugName to its node
input_to_node = {x.debugName(): n for n in graph.nodes() for x in n.inputs()}
# build module mapping, from module name to all nodes (as list) under this module scope
module_to_nodes = defaultdict(list)
# the mapping of function (non-module in forward) to nodes, key is scope name
func_to_nodes = defaultdict(list)
nodes_py = GraphPy()
for node in graph.inputs():
if omit_useless_nodes:
if not node.uses(): # number of user of the node (= number of outputs/ fanout)
continue
if node.type().kind() != 'ClassType':
nodes_py.append(NodePyIO(node, 'input'))
self.leaf_modules = self._extract_leaf_modules()
module_to_type = {name: parse_traced_name(module._name) for name, module in self.trace.named_modules()}
# associate module name with their trace graph nodes
for node in graph.nodes():
module_name = self._get_module_name(node.scopeName())
if module_name in self.leaf_modules:
module_to_nodes[module_name].append(node)
else:
func_to_nodes[node.scopeName()].append(node)
# build node group for module
for module_name, node_cpps in module_to_nodes.items():
node_group = self._build_module_node_group(
module_name, module_to_type[module_name], node_cpps, input_to_node, output_to_node
)
_logger.debug('node_group: %s', node_group)
nodes_py.nodes_op.append(node_group)
# each scope_name may have multiple funcs, we split them and create node for each of them
# build node group for torch.nn.functional
for _, nodes in func_to_nodes.items():
# extract non prim:: nodes
non_prim_nodes = list()
for node in nodes:
if not node.kind().startswith('prim::'):
non_prim_nodes.append(node)
# for each non prim node, expand it
for node in non_prim_nodes:
node_group = self._expand_non_prim_node(node, nodes, input_to_node, output_to_node)
nodes_py.nodes_op.append(node_group)
# get shape infor for view (aten::view) func
if node_group.op_type in ['aten::view', 'aten::flatten']:
node_group.auxiliary = self._extract_shape_info(node)
for node in graph.outputs(): # Create sink nodes for output ops
node_py = NodePyIO(node, 'output')
nodes_py.append(node_py)
self.nodes_py = nodes_py
# build index
return self._build_index(self.nodes_py.nodes_op)
def find_predecessors(self, module_name):
"""
Find predecessor node of the given node
Parameters
----------
module_name : str
The name of the node
Returns
-------
list
a list of nodes who are the given node's predecessor
"""
predecessors = []
for _input in self.name_to_node[module_name].inputs:
if not _input in self.output_to_node:
_logger.debug("cannot find node with %s as its output", _input)
else:
node_py = self.output_to_node[_input]
predecessors.append(node_py.name)
return predecessors
def find_successors(self, module_name):
"""
Find successor nodes of the given node
Parameters
----------
module_name : str
The name of the node
Returns
-------
list
a list of nodes who are the given node's successor
"""
successors = []
for output in self.name_to_node[module_name].outputs:
assert output in self.input_to_node, "No node with input {}".format(output)
nodes_py = self.input_to_node[output]
for node_py in nodes_py:
successors.append(node_py.name)
return successors
...@@ -12,6 +12,7 @@ replace_module = { ...@@ -12,6 +12,7 @@ replace_module = {
'Conv2d': lambda module, mask: replace_conv2d(module, mask), 'Conv2d': lambda module, mask: replace_conv2d(module, mask),
'MaxPool2d': lambda module, mask: no_replace(module, mask), 'MaxPool2d': lambda module, mask: no_replace(module, mask),
'AvgPool2d': lambda module, mask: no_replace(module, mask), 'AvgPool2d': lambda module, mask: no_replace(module, mask),
'AdaptiveAvgPool2d': lambda module, mask: no_replace(module, mask),
'ReLU': lambda module, mask: no_replace(module, mask), 'ReLU': lambda module, mask: no_replace(module, mask),
'Linear': lambda module, mask: replace_linear(module, mask) 'Linear': lambda module, mask: replace_linear(module, mask)
} }
......
...@@ -2,9 +2,8 @@ ...@@ -2,9 +2,8 @@
# Licensed under the MIT license. # Licensed under the MIT license.
import logging import logging
import queue
import re
import torch import torch
from nni._graph_utils import build_module_graph
from .compress_modules import replace_module from .compress_modules import replace_module
from .infer_shape import ModuleMasks, infer_from_mask, infer_from_inshape, infer_from_outshape from .infer_shape import ModuleMasks, infer_from_mask, infer_from_inshape, infer_from_outshape
...@@ -33,38 +32,6 @@ def get_module_by_name(model, module_name): ...@@ -33,38 +32,6 @@ def get_module_by_name(model, module_name):
leaf_module = getattr(model, name_list[-1]) leaf_module = getattr(model, name_list[-1])
return model, leaf_module return model, leaf_module
class GNode:
"""
It is used to represent a node in model graph, in this graph a module is a node,
a function out of module (in ```forward``` function) could also be a node.
"""
def __init__(self, node_name, node_type, op_type, inputs, outputs, nodes):
"""
Parameters
----------
node_name : str
It is module name if the node is a module, it is ```scope_name.node_kind.seq``` if it is a func
node_type : str
It only has two options: `module` or `func`
op_type : str
The operation type of the module or func
inputs : list of str
All the inputs of this node, each element is debugName of one input
outputs : list of str
All the outputs of this node, each element is debugName of one output
nodes : list of node
All the trace graph nodes included in this module or func
"""
self.name = node_name
self.type = node_type
self.op_type = op_type
self.inputs = inputs
self.outputs = outputs
self.nodes = nodes
# store supplementary information for different op types
# for example, for ```view``` it stores the shape of its input and output
self.auxiliary = None
class ModelSpeedup: class ModelSpeedup:
""" """
This class is to speedup the model with provided weight mask This class is to speedup the model with provided weight mask
...@@ -84,347 +51,9 @@ class ModelSpeedup: ...@@ -84,347 +51,9 @@ class ModelSpeedup:
the device on which masks are placed, same to map_location in ```torch.load``` the device on which masks are placed, same to map_location in ```torch.load```
""" """
self.bound_model = model self.bound_model = model
self.dummy_input = dummy_input
self.masks = torch.load(masks_file, map_location) self.masks = torch.load(masks_file, map_location)
self.is_training = model.training
# to obtain forward graph, model should be in ```eval``` mode
if self.is_training:
model.eval()
self.trace_graph = torch.jit.trace(model, dummy_input)
if self.is_training:
model.train()
self.inferred_masks = dict() # key: module_name, value: ModuleMasks self.inferred_masks = dict() # key: module_name, value: ModuleMasks
self.g_nodes = list() self.torch_graph = build_module_graph(model, dummy_input)
self.global_count = 0
self.name_to_gnode, self.input_to_gnode, self.output_to_gnode = self._build_graph()
def _build_index_for_gnodes(self, g_nodes):
"""
Build indexes for quick search
Parameters
----------
g_nodes : list of GNode
All the g_node in processed model graph
Returns
-------
dict
use name to index g_nodes, key: node name, value: g_node
dict
use input (its name) to index g_nodes,
key: input, value: list of g_nodes that take this input
dict
use output (its name) to index g_nodes,
key: output, value: g_node that generates this output
"""
name_to_gnode = dict()
input_to_gnode = dict()
output_to_gnode = dict()
for node in g_nodes:
name_to_gnode[node.name] = node
for _input in node.inputs:
if _input in input_to_gnode:
input_to_gnode[_input].append(node)
else:
input_to_gnode[_input] = [node]
for output in node.outputs:
assert not output in output_to_gnode, \
"One output cannot be generated by multiple nodes"
output_to_gnode[output] = node
return name_to_gnode, input_to_gnode, output_to_gnode
def _expand_non_prim_node(self, node, nodes, input_to_node, output_to_node):
"""
For trace graph nodes, some nodes are not in modules, these nodes are usually generated by
the functions directly called in module ```forward```. For such nodes, some of them are
trivial op which are label by ```prim::```, some of them are not such ops which is call
non-prim ops. This function is to merge neighbor prim ops to a non-prim op, to construct
a GNode.
Parameters
----------
node : trace graph node
The non-prim node to expand
nodes : list of trace graph node
All the trace graph nodes within the same scope as the non-prim node
input_to_node : dict
key: input name, value: a node that uses this input
output_to_node : dict
key: output name, value: a node that generates this output
Returns
-------
GNode
the expanded non-prim node in GNode format
"""
# TODO: scope name could be empty
node_name = '.'.join([node.scopeName(), node.kind(), str(self.global_count)])
_logger.debug("expand non-prim node, node name: %s", node_name)
self.global_count += 1
op_type = node.kind()
node_group = [node]
inputs = list()
outputs = list()
node_queue = queue.Queue()
node_queue.put(node)
while not node_queue.empty():
curr_node = node_queue.get()
for _input in curr_node.inputs():
input_name = _input.debugName()
if input_name in output_to_node and output_to_node[input_name] in nodes:
predecessor_node = output_to_node[input_name]
if predecessor_node.kind().startswith('prim::'):
node_group.append(predecessor_node)
node_queue.put(predecessor_node)
else:
inputs.append(input_name)
else:
inputs.append(input_name)
for output in node.outputs():
outputs.append(output.debugName())
g_node = GNode(node_name, 'func', op_type, inputs, outputs, node_group)
return g_node
def _extract_shape_info(self, node):
"""
Extract the shape information of ```aten::view``` node
Parameters
----------
node : trace graph node
It should be ```aten::view``` node
Returns
-------
dict
Include shape of input tensor and shape of output tensor
"""
t_input = None
for _input in node.inputs():
t_input = _input
break
t_output = node.output()
assert isinstance(t_input.type(), torch._C.TensorType)
assert isinstance(t_output.type(), torch._C.TensorType)
in_shape = t_input.type().sizes()
out_shape = t_output.type().sizes()
return {'in_shape': in_shape, 'out_shape': out_shape}
def _extract_leaf_modules(self, graph):
"""
Extract leaf modules from the given graph. Leaf module means it does not have submodules.
To extract leaf modules because only leaf module can be replaced. And shape inference can
be done in leaf module level. Other shape inference is done in lower level i.e.,
operation level.
Parameters
----------
graph : jit trace graph
the graph generated from jit trace
Returns
-------
list
a list of scope name of all the leaf modules
"""
class SNode:
def __init__(self, name):
self.sname = name
self.childs = {}
root = None
for node in graph.nodes():
scope_name = node.scopeName()
if scope_name == '':
continue
segs = scope_name.split('/')
if root is None:
root = SNode(segs[0])
curr = root
for seg in segs[1:]:
if not seg in curr.childs:
curr.childs[seg] = SNode(seg)
curr = curr.childs[seg]
leaf_nodes = []
def traverse_tree(node, scope_name):
if scope_name == '':
sn = node.sname
else:
sn = scope_name + '/' + node.sname
if not node.childs:
if node.sname[-1] == ']':
leaf_nodes.append(sn)
else:
for key in node.childs:
traverse_tree(node.childs[key], sn)
traverse_tree(root, '')
return leaf_nodes
def _build_graph(self):
"""
Build graph using our defined format from jit trace.
There are basically three steps: first, construct necessary information (data structures),
second, extract all the modules to convert to GNode, Third, extract all functions to convert
to GNode.
Returns
-------
dict
use name to index g_nodes, key: node name, value: g_node
dict
use input (its name) to index g_nodes,
key: input, value: list of g_nodes that take this input
dict
use output (its name) to index g_nodes,
key: output, value: g_node that generates this output
"""
graph = self.trace_graph.graph
# if torch 1.4.0 is used, consider run torch._C._jit_pass_inline(graph) here
_logger.debug(graph)
# build output mapping, from output debugName to its node
output_to_node = dict()
# build input mapping, from input debugName to its node
input_to_node = dict()
# build module mapping, from module name to all nodes (as list) under this module scope
module_to_nodes = dict()
# module name to its type
module_to_type = dict()
# the mapping of function (non-module in forward) to nodes, key is scope name
func_to_nodes = dict()
graph_inputs = list()
graph_outputs = list()
for _input in graph.inputs():
graph_inputs.append(_input.debugName())
for output in graph.outputs():
graph_outputs.append(output.debugName())
leaf_modules = self._extract_leaf_modules(graph)
_logger.debug(leaf_modules)
for node in graph.nodes():
# populate output_to_node and input_to_node
for output in node.outputs():
output_name = output.debugName()
output_to_node[output_name] = node
for _input in node.inputs():
input_name = _input.debugName()
input_to_node[input_name] = node
scope_name = node.scopeName() # example: scope_name, 'MyCell/Linear[linear]'
# if module_name is empty, it is not a module
if not scope_name in leaf_modules:
if scope_name == '':
continue
else:
if scope_name in func_to_nodes:
func_to_nodes[scope_name].append(node)
else:
func_to_nodes[scope_name] = [node]
else:
module_name_slices = re.findall(r'\[(.*?)\]', scope_name)
module_name = '.'.join(module_name_slices)
scope_slice = scope_name.split('/')[-1]
module_type = scope_slice.split('[')[0]
module_to_type[module_name] = module_type
if module_name in module_to_nodes:
module_to_nodes[module_name].append(node)
else:
module_to_nodes[module_name] = [node]
# construct GNode from module
for module_name, nodes in module_to_nodes.items():
inputs = set()
outputs = set()
for node in nodes:
for output in node.outputs():
outputs.add(output.debugName())
for _input in node.inputs():
inputs.add(_input.debugName())
m_inputs = list()
m_outputs = list()
for output in outputs:
# TODO: one input could be the input of multiple nodes
if not output in input_to_node and output in graph_outputs:
m_outputs.append(output)
elif not input_to_node[output] in nodes:
m_outputs.append(output)
for _input in inputs:
if not _input in output_to_node and _input in graph_inputs:
m_inputs.append(_input)
elif not output_to_node[_input] in nodes:
m_inputs.append(_input)
if module_name == '':
_logger.warning("module_name is empty string")
g_node = GNode(module_name, 'module', module_to_type[module_name], m_inputs, m_outputs, nodes)
self.g_nodes.append(g_node)
# each scope_name may have multiple funcs, we split them and create GNode for each of them
for scope_name, nodes in func_to_nodes.items():
# extract non prim:: nodes
non_prim_nodes = list()
for node in nodes:
if not node.kind().startswith('prim::'):
non_prim_nodes.append(node)
# for each non prim node, expand it has a GNode
for node in non_prim_nodes:
g_node = self._expand_non_prim_node(node, nodes, input_to_node, output_to_node)
self.g_nodes.append(g_node)
# get shape infor for view (aten::view) func
if g_node.op_type == 'aten::view':
g_node.auxiliary = self._extract_shape_info(node)
# build index for g_nodes
name_to_gnode, input_to_gnode, output_to_gnode = self._build_index_for_gnodes(self.g_nodes)
return name_to_gnode, input_to_gnode, output_to_gnode
def _find_predecessors(self, module_name):
"""
Find predecessor GNode of the given GNode
Parameters
----------
module_name : str
The name of the GNode
Returns
-------
list
a list of GNodes who are the given GNode's predecessor
"""
predecessors = []
for _input in self.name_to_gnode[module_name].inputs:
if not _input in self.output_to_gnode:
_logger.debug("cannot find gnode with %s as its output", _input)
else:
g_node = self.output_to_gnode[_input]
predecessors.append(g_node.name)
return predecessors
def _find_successors(self, module_name):
"""
Find successor GNodes of the given GNode
Parameters
----------
module_name : str
The name of the GNode
Returns
-------
list
a list of GNodes who are the given GNode's successor
"""
successors = []
for output in self.name_to_gnode[module_name].outputs:
assert output in self.input_to_gnode, "No gnode with input {}".format(output)
g_nodes = self.input_to_gnode[output]
for g_node in g_nodes:
successors.append(g_node.name)
return successors
def infer_module_mask(self, module_name, mask=None, in_shape=None, out_shape=None): def infer_module_mask(self, module_name, mask=None, in_shape=None, out_shape=None):
""" """
...@@ -441,13 +70,13 @@ class ModelSpeedup: ...@@ -441,13 +70,13 @@ class ModelSpeedup:
Parameters Parameters
---------- ----------
module_name : str module_name : str
The name of the GNode The name of the node
mask : tensor of mask or ModuleMasks mask : tensor of mask or ModuleMasks
Mask of the weights in this GNode (i.e., module) Mask of the weights in this node (i.e., module)
in_shape : ModuleMasks in_shape : ModuleMasks
Input shape of this GNode Input shape of this node
out_shape : ModuleMasks out_shape : ModuleMasks
Output shape of this GNode Output shape of this node
""" """
input_cmask = output_cmask = None input_cmask = output_cmask = None
if module_name in self.inferred_masks: if module_name in self.inferred_masks:
...@@ -456,7 +85,7 @@ class ModelSpeedup: ...@@ -456,7 +85,7 @@ class ModelSpeedup:
module_masks = ModuleMasks(module_name) module_masks = ModuleMasks(module_name)
self.inferred_masks[module_name] = module_masks self.inferred_masks[module_name] = module_masks
m_type = self.name_to_gnode[module_name].op_type m_type = self.torch_graph.name_to_node[module_name].op_type
_logger.debug("infer mask of module %s with op_type %s", module_name, m_type) _logger.debug("infer mask of module %s with op_type %s", module_name, m_type)
if mask is not None: if mask is not None:
_logger.debug("mask is not None") _logger.debug("mask is not None")
...@@ -471,10 +100,10 @@ class ModelSpeedup: ...@@ -471,10 +100,10 @@ class ModelSpeedup:
raise RuntimeError( raise RuntimeError(
"Has not supported infering output shape from input shape for module/function: `{}`, {}" "Has not supported infering output shape from input shape for module/function: `{}`, {}"
.format(m_type, module_name)) .format(m_type, module_name))
if m_type == 'aten::view': if m_type in ['aten::view', 'aten::flatten']:
output_cmask = infer_from_inshape[m_type](module_masks, output_cmask = infer_from_inshape[m_type](module_masks,
in_shape, in_shape,
self.name_to_gnode[module_name].auxiliary) self.torch_graph.name_to_node[module_name].auxiliary)
else: else:
output_cmask = infer_from_inshape[m_type](module_masks, in_shape) output_cmask = infer_from_inshape[m_type](module_masks, in_shape)
if out_shape is not None: if out_shape is not None:
...@@ -486,11 +115,11 @@ class ModelSpeedup: ...@@ -486,11 +115,11 @@ class ModelSpeedup:
input_cmask = infer_from_outshape[m_type](module_masks, out_shape) input_cmask = infer_from_outshape[m_type](module_masks, out_shape)
if input_cmask: if input_cmask:
predecessors = self._find_predecessors(module_name) predecessors = self.torch_graph.find_predecessors(module_name)
for _module_name in predecessors: for _module_name in predecessors:
self.infer_module_mask(_module_name, out_shape=input_cmask) self.infer_module_mask(_module_name, out_shape=input_cmask)
if output_cmask: if output_cmask:
successors = self._find_successors(module_name) successors = self.torch_graph.find_successors(module_name)
for _module_name in successors: for _module_name in successors:
self.infer_module_mask(_module_name, in_shape=output_cmask) self.infer_module_mask(_module_name, in_shape=output_cmask)
...@@ -511,7 +140,7 @@ class ModelSpeedup: ...@@ -511,7 +140,7 @@ class ModelSpeedup:
is that ```func``` should be not required to be replaced. is that ```func``` should be not required to be replaced.
""" """
for module_name in self.inferred_masks: for module_name in self.inferred_masks:
g_node = self.name_to_gnode[module_name] g_node = self.torch_graph.name_to_node[module_name]
_logger.debug("replace %s, in %s type, with op_type %s", _logger.debug("replace %s, in %s type, with op_type %s",
module_name, g_node.type, g_node.op_type) module_name, g_node.type, g_node.op_type)
if g_node.type == 'module': if g_node.type == 'module':
...@@ -526,7 +155,7 @@ class ModelSpeedup: ...@@ -526,7 +155,7 @@ class ModelSpeedup:
_logger.info("Warning: cannot replace (name: %s, op_type: %s) which is func type", _logger.info("Warning: cannot replace (name: %s, op_type: %s) which is func type",
module_name, g_node.op_type) module_name, g_node.op_type)
else: else:
raise RuntimeError("Unsupported GNode type: {}".format(g_node.type)) raise RuntimeError("Unsupported node type: {}".format(g_node.type))
def speedup_model(self): def speedup_model(self):
""" """
...@@ -540,8 +169,3 @@ class ModelSpeedup: ...@@ -540,8 +169,3 @@ class ModelSpeedup:
_logger.info("replace compressed modules...") _logger.info("replace compressed modules...")
self.replace_compressed_modules() self.replace_compressed_modules()
_logger.info("speedup done") _logger.info("speedup done")
# resume the model mode to that before the model is speed up
if self.is_training:
self.bound_model.train()
else:
self.bound_model.eval()
\ No newline at end of file
...@@ -83,6 +83,9 @@ class CoarseMask: ...@@ -83,6 +83,9 @@ class CoarseMask:
cmask.mask_index[i]) cmask.mask_index[i])
return self.mask_index return self.mask_index
def __repr__(self):
return 'mask_index: {}'.format(self.mask_index)
class ModuleMasks: class ModuleMasks:
""" """
The masks of a module, including the masks for weights, inputs, output The masks of a module, including the masks for weights, inputs, output
...@@ -128,6 +131,11 @@ class ModuleMasks: ...@@ -128,6 +131,11 @@ class ModuleMasks:
""" """
self.output_mask = mask self.output_mask = mask
def __repr__(self):
return 'input_mask: {}, output_mask: {}, param_masks: {}'.format(
self.input_mask, self.output_mask, self.param_masks
)
""" """
Infer input and output shape of a module/function from its weight mask Infer input and output shape of a module/function from its weight mask
""" """
...@@ -147,8 +155,10 @@ infer_from_inshape = { ...@@ -147,8 +155,10 @@ infer_from_inshape = {
'aten::max_pool2d': lambda module_masks, mask: maxpool2d_inshape(module_masks, mask), 'aten::max_pool2d': lambda module_masks, mask: maxpool2d_inshape(module_masks, mask),
'aten::avg_pool2d': lambda module_masks, mask: maxpool2d_inshape(module_masks, mask), 'aten::avg_pool2d': lambda module_masks, mask: maxpool2d_inshape(module_masks, mask),
'AvgPool2d': lambda module_masks, mask: maxpool2d_inshape(module_masks, mask), 'AvgPool2d': lambda module_masks, mask: maxpool2d_inshape(module_masks, mask),
'AdaptiveAvgPool2d': lambda module_masks, mask: maxpool2d_inshape(module_masks, mask),
'aten::size': lambda module_masks, mask: size_inshape(module_masks, mask), 'aten::size': lambda module_masks, mask: size_inshape(module_masks, mask),
'aten::view': lambda module_masks, mask, shape: view_inshape(module_masks, mask, shape), 'aten::view': lambda module_masks, mask, shape: view_inshape(module_masks, mask, shape),
'aten::flatten': lambda module_masks, mask, shape: view_inshape(module_masks, mask, shape), # support only start_dim=1
'Linear': lambda module_masks, mask: linear_inshape(module_masks, mask), 'Linear': lambda module_masks, mask: linear_inshape(module_masks, mask),
'BatchNorm2d': lambda module_masks, mask: batchnorm2d_inshape(module_masks, mask) 'BatchNorm2d': lambda module_masks, mask: batchnorm2d_inshape(module_masks, mask)
} }
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# pylint: skip-file
# This file is copied from PyTorch 1.4, with bug fixes.
# Likely to be removed in future.
import torch
from tensorboard.compat.proto.config_pb2 import RunMetadata
from tensorboard.compat.proto.graph_pb2 import GraphDef
from tensorboard.compat.proto.step_stats_pb2 import StepStats, DeviceStepStats
from tensorboard.compat.proto.versions_pb2 import VersionDef
from torch.utils.tensorboard._pytorch_graph import GraphPy, CLASSTYPE_KIND, GETATTR_KIND, NodePyIO, NodePyOP
def parse(graph, trace, args=None, omit_useless_nodes=True):
"""This method parses an optimized PyTorch model graph and produces
a list of nodes and node stats for eventual conversion to TensorBoard
protobuf format.
Args:
graph (PyTorch module): The model graph to be parsed.
trace (PyTorch JIT TracedModule): The model trace to be parsed.
args (tuple): input tensor[s] for the model.
omit_useless_nodes (boolean): Whether to remove nodes from the graph.
"""
n_inputs = len(args)
scope = {}
nodes_py = GraphPy()
for node in graph.inputs():
if omit_useless_nodes:
if len(node.uses()) == 0: # number of user of the node (= number of outputs/ fanout)
continue
if node.type().kind() != CLASSTYPE_KIND:
nodes_py.append(NodePyIO(node, 'input'))
attr_to_scope = dict()
node_to_name = lambda d: str(d).split(":")[0].strip()
for node in graph.nodes():
if node.kind() == GETATTR_KIND:
attr_name = node.s('name')
node_name = node_to_name(node)
parent = node.input().node()
if parent.kind() == GETATTR_KIND: # If the parent node is not the top-level "self" node
parent_attr_name = parent.s('name')
parent_scope = attr_to_scope[node_to_name(parent)]
attr_scope = parent_scope.split('/')[-1]
attr_to_scope[node_name] = '{}/{}.{}'.format(parent_scope, attr_scope, attr_name)
else:
attr_to_scope[node_name] = '__module.{}'.format(attr_name)
# We don't need classtype nodes; scope will provide this information
if node.output().type().kind() != CLASSTYPE_KIND:
node_py = NodePyOP(node)
node_py.scopeName = attr_to_scope[node_name]
nodes_py.append(node_py)
else:
nodes_py.append(NodePyOP(node))
for i, node in enumerate(graph.outputs()): # Create sink nodes for output ops
node_py = NodePyIO(node, 'output')
node_py.debugName = "output.{}".format(i + 1)
node_py.inputs = [node.debugName()]
nodes_py.append(node_py)
def parse_traced_name(module_name):
prefix = 'TracedModule['
suffix = ']'
if module_name.startswith(prefix) and module_name.endswith(suffix):
module_name = module_name[len(prefix):-len(suffix)]
return module_name
alias_to_name = dict()
base_name = parse_traced_name(trace._name)
for name, module in trace.named_modules(prefix='__module'):
mod_name = parse_traced_name(module._name)
attr_name = name.split('.')[-1]
alias_to_name[name] = '{}[{}]'.format(mod_name, attr_name)
for node in nodes_py.nodes_op:
module_aliases = node.scopeName.split('/')[-1].split('.')
module_name = ''
for i, alias in enumerate(module_aliases):
if i == 0:
module_name = alias
node.scopeName = base_name
else:
module_name += '.' + alias
node.scopeName += '/' + (alias_to_name[module_name] if module_name in alias_to_name else alias)
nodes_py.populate_namespace_from_OP_to_IO()
return nodes_py.to_proto()
def graph(model, args, verbose=False):
"""
This method processes a PyTorch model and produces a `GraphDef` proto
that can be logged to TensorBoard.
Args:
model (PyTorch module): The model to be parsed.
args (tuple): input tensor[s] for the model.
verbose (bool): Whether to print out verbose information while
processing.
"""
with torch.onnx.set_training(model, False): # TODO: move outside of torch.onnx?
try:
trace = torch.jit.trace(model, args)
graph = trace.graph
torch._C._jit_pass_inline(graph)
except RuntimeError as e:
print(e)
print('Error occurs, No graph saved')
raise e
if verbose:
print(graph)
list_of_nodes = parse(graph, trace, args)
# We are hardcoding that this was run on CPU even though it might have actually
# run on GPU. Note this is what is shown in TensorBoard and has no bearing
# on actual execution.
# TODO: See if we can extract GPU vs CPU information from the PyTorch model
# and pass it correctly to TensorBoard.
#
# Definition of StepStats and DeviceStepStats can be found at
# https://github.com/tensorflow/tensorboard/blob/master/tensorboard/plugins/graph/tf_graph_common/test/graph-test.ts
# and
# https://github.com/tensorflow/tensorboard/blob/master/tensorboard/compat/proto/step_stats.proto
stepstats = RunMetadata(step_stats=StepStats(dev_stats=[DeviceStepStats(device="/device:CPU:0")]))
return GraphDef(node=list_of_nodes, versions=VersionDef(producer=22)), stepstats
# The producer version has been reverse engineered from standard
# TensorBoard logged data.
...@@ -107,12 +107,12 @@ class Mutator(BaseMutator): ...@@ -107,12 +107,12 @@ class Mutator(BaseMutator):
""" """
if not torch.__version__.startswith("1.4"): if not torch.__version__.startswith("1.4"):
logger.warning("Graph is only tested with PyTorch 1.4. Other versions might not work.") logger.warning("Graph is only tested with PyTorch 1.4. Other versions might not work.")
from ._graph_utils import graph from nni._graph_utils import build_graph
from google.protobuf import json_format from google.protobuf import json_format
# protobuf should be installed as long as tensorboard is installed # protobuf should be installed as long as tensorboard is installed
try: try:
self._connect_all = True self._connect_all = True
graph_def, _ = graph(self.model, inputs, verbose=False) graph_def, _ = build_graph(self.model, inputs, verbose=False)
result = json_format.MessageToDict(graph_def) result = json_format.MessageToDict(graph_def)
finally: finally:
self._connect_all = False self._connect_all = False
......
...@@ -55,7 +55,8 @@ class PdartsMutator(DartsMutator): ...@@ -55,7 +55,8 @@ class PdartsMutator(DartsMutator):
del module[index] del module[index]
assert len(module) <= len(choices), "Failed to remove dropped choices." assert len(module) <= len(choices), "Failed to remove dropped choices."
def sample_final(self): def export(self):
# Cannot rely on super().export() because P-DARTS has deleted some of the choices and has misaligned length.
results = super().sample_final() results = super().sample_final()
for mutable in self.mutables: for mutable in self.mutables:
if isinstance(mutable, LayerChoice): if isinstance(mutable, LayerChoice):
......
node {
name: "input/input"
op: "IO Node"
attr {
key: "_output_shapes"
value {
list {
shape {
dim {
size: 1
}
dim {
size: 3
}
}
}
}
}
attr {
key: "attr"
value {
s: ""
}
}
}
node {
name: "output/output.1"
op: "IO Node"
input: "myLinear/Linear[l]/22"
attr {
key: "_output_shapes"
value {
list {
shape {
dim {
size: 1
}
dim {
size: 5
}
}
}
}
}
attr {
key: "attr"
value {
s: ""
}
}
}
node {
name: "myLinear/Linear[l]/bias/17"
op: "prim::GetAttr"
input: "myLinear/Linear[l]/weight/14"
attr {
key: "attr"
value {
s: "{ name : bias }"
}
}
}
node {
name: "myLinear/Linear[l]/weight/18"
op: "prim::GetAttr"
input: "myLinear/Linear[l]/weight/14"
attr {
key: "attr"
value {
s: "{ name : weight }"
}
}
}
node {
name: "myLinear/Linear[l]/19"
op: "aten::t"
input: "myLinear/Linear[l]/weight/18"
attr {
key: "_output_shapes"
value {
list {
shape {
dim {
size: 3
}
dim {
size: 5
}
}
}
}
}
attr {
key: "attr"
value {
s: "{}"
}
}
}
node {
name: "myLinear/Linear[l]/20"
op: "prim::Constant"
attr {
key: "attr"
value {
s: "{ value : 1}"
}
}
}
node {
name: "myLinear/Linear[l]/21"
op: "prim::Constant"
attr {
key: "attr"
value {
s: "{ value : 1}"
}
}
}
node {
name: "myLinear/Linear[l]/22"
op: "aten::addmm"
input: "myLinear/Linear[l]/bias/17"
input: "input/input"
input: "myLinear/Linear[l]/19"
input: "myLinear/Linear[l]/20"
input: "myLinear/Linear[l]/21"
attr {
key: "_output_shapes"
value {
list {
shape {
dim {
size: 1
}
dim {
size: 5
}
}
}
}
}
attr {
key: "attr"
value {
s: "{}"
}
}
}
versions {
producer: 22
}
node {
name: "input/input.1"
op: "IO Node"
attr {
key: "_output_shapes"
value {
list {
shape {
dim {
size: 4
}
dim {
size: 5
}
}
}
}
}
attr {
key: "attr"
value {
s: ""
}
}
}
node {
name: "output/output.1"
op: "IO Node"
input: "input/input.1"
attr {
key: "_output_shapes"
value {
list {
shape {
dim {
size: 4
}
dim {
size: 5
}
}
}
}
}
attr {
key: "attr"
value {
s: ""
}
}
}
node {
name: "MyModule/Linear[weight]/bias/49"
op: "prim::GetAttr"
input: "MyModule/Linear[weight]/weight/35"
attr {
key: "attr"
value {
s: "{ name : bias }"
}
}
}
node {
name: "MyModule/Linear[weight]/weight/50"
op: "prim::GetAttr"
input: "MyModule/Linear[weight]/weight/35"
attr {
key: "attr"
value {
s: "{ name : weight }"
}
}
}
node {
name: "MyModule/Linear[weight]/51"
op: "aten::t"
input: "MyModule/Linear[weight]/weight/50"
attr {
key: "_output_shapes"
value {
list {
shape {
dim {
size: 5
}
dim {
size: 3
}
}
}
}
}
attr {
key: "attr"
value {
s: "{}"
}
}
}
node {
name: "MyModule/Linear[weight]/52"
op: "prim::Constant"
attr {
key: "attr"
value {
s: "{ value : 1}"
}
}
}
node {
name: "MyModule/Linear[weight]/53"
op: "prim::Constant"
attr {
key: "attr"
value {
s: "{ value : 1}"
}
}
}
node {
name: "MyModule/Linear[weight]/54"
op: "aten::addmm"
input: "MyModule/Linear[weight]/bias/49"
input: "input/input.1"
input: "MyModule/Linear[weight]/51"
input: "MyModule/Linear[weight]/52"
input: "MyModule/Linear[weight]/53"
attr {
key: "_output_shapes"
value {
list {
shape {
dim {
size: 4
}
dim {
size: 3
}
}
}
}
}
attr {
key: "attr"
value {
s: "{}"
}
}
}
node {
name: "MyModule/Linear[bias]/bias/55"
op: "prim::GetAttr"
input: "MyModule/Linear[bias]/weight/38"
attr {
key: "attr"
value {
s: "{ name : bias }"
}
}
}
node {
name: "MyModule/Linear[bias]/weight/56"
op: "prim::GetAttr"
input: "MyModule/Linear[bias]/weight/38"
attr {
key: "attr"
value {
s: "{ name : weight }"
}
}
}
node {
name: "MyModule/Linear[bias]/57"
op: "aten::t"
input: "MyModule/Linear[bias]/weight/56"
attr {
key: "_output_shapes"
value {
list {
shape {
dim {
size: 5
}
dim {
size: 3
}
}
}
}
}
attr {
key: "attr"
value {
s: "{}"
}
}
}
node {
name: "MyModule/Linear[bias]/58"
op: "prim::Constant"
attr {
key: "attr"
value {
s: "{ value : 1}"
}
}
}
node {
name: "MyModule/Linear[bias]/59"
op: "prim::Constant"
attr {
key: "attr"
value {
s: "{ value : 1}"
}
}
}
node {
name: "MyModule/Linear[bias]/60"
op: "aten::addmm"
input: "MyModule/Linear[bias]/bias/55"
input: "input/input.1"
input: "MyModule/Linear[bias]/57"
input: "MyModule/Linear[bias]/58"
input: "MyModule/Linear[bias]/59"
attr {
key: "_output_shapes"
value {
list {
shape {
dim {
size: 4
}
dim {
size: 3
}
}
}
}
}
attr {
key: "attr"
value {
s: "{}"
}
}
}
node {
name: "MyModule/23"
op: "prim::ListConstruct"
input: "MyModule/Linear[weight]/54"
input: "MyModule/Linear[bias]/60"
attr {
key: "attr"
value {
s: "{}"
}
}
}
node {
name: "MyModule/24"
op: "prim::Constant"
attr {
key: "attr"
value {
s: "{ value : 1}"
}
}
}
node {
name: "MyModule/input"
op: "aten::cat"
input: "MyModule/23"
input: "MyModule/24"
attr {
key: "_output_shapes"
value {
list {
shape {
dim {
size: 4
}
dim {
size: 6
}
}
}
}
}
attr {
key: "attr"
value {
s: "{}"
}
}
}
node {
name: "MyModule/61"
op: "prim::Constant"
attr {
key: "attr"
value {
s: "{}"
}
}
}
versions {
producer: 22
}
node {
name: "input/input.1"
op: "IO Node"
attr {
key: "_output_shapes"
value {
list {
shape {
dim {
size: 4
}
dim {
size: 5
}
}
}
}
}
attr {
key: "attr"
value {
s: ""
}
}
}
node {
name: "output/output.1"
op: "IO Node"
input: "MyModule/ModuleList[module]/Linear[1]/46"
attr {
key: "_output_shapes"
value {
list {
shape {
dim {
size: 4
}
dim {
size: 1
}
}
}
}
}
attr {
key: "attr"
value {
s: ""
}
}
}
node {
name: "MyModule/ModuleList[module]/Linear[0]/bias/35"
op: "prim::GetAttr"
input: "MyModule/ModuleList[module]/Linear[0]/weight/26"
attr {
key: "attr"
value {
s: "{ name : bias }"
}
}
}
node {
name: "MyModule/ModuleList[module]/Linear[0]/weight/36"
op: "prim::GetAttr"
input: "MyModule/ModuleList[module]/Linear[0]/weight/26"
attr {
key: "attr"
value {
s: "{ name : weight }"
}
}
}
node {
name: "MyModule/ModuleList[module]/Linear[0]/37"
op: "aten::t"
input: "MyModule/ModuleList[module]/Linear[0]/weight/36"
attr {
key: "_output_shapes"
value {
list {
shape {
dim {
size: 5
}
dim {
size: 3
}
}
}
}
}
attr {
key: "attr"
value {
s: "{}"
}
}
}
node {
name: "MyModule/ModuleList[module]/Linear[0]/38"
op: "prim::Constant"
attr {
key: "attr"
value {
s: "{ value : 1}"
}
}
}
node {
name: "MyModule/ModuleList[module]/Linear[0]/39"
op: "prim::Constant"
attr {
key: "attr"
value {
s: "{ value : 1}"
}
}
}
node {
name: "MyModule/ModuleList[module]/Linear[0]/input"
op: "aten::addmm"
input: "MyModule/ModuleList[module]/Linear[0]/bias/35"
input: "input/input.1"
input: "MyModule/ModuleList[module]/Linear[0]/37"
input: "MyModule/ModuleList[module]/Linear[0]/38"
input: "MyModule/ModuleList[module]/Linear[0]/39"
attr {
key: "_output_shapes"
value {
list {
shape {
dim {
size: 4
}
dim {
size: 3
}
}
}
}
}
attr {
key: "attr"
value {
s: "{}"
}
}
}
node {
name: "MyModule/ModuleList[module]/Linear[1]/bias/41"
op: "prim::GetAttr"
input: "MyModule/ModuleList[module]/Linear[1]/weight/30"
attr {
key: "attr"
value {
s: "{ name : bias }"
}
}
}
node {
name: "MyModule/ModuleList[module]/Linear[1]/weight/42"
op: "prim::GetAttr"
input: "MyModule/ModuleList[module]/Linear[1]/weight/30"
attr {
key: "attr"
value {
s: "{ name : weight }"
}
}
}
node {
name: "MyModule/ModuleList[module]/Linear[1]/43"
op: "aten::t"
input: "MyModule/ModuleList[module]/Linear[1]/weight/42"
attr {
key: "_output_shapes"
value {
list {
shape {
dim {
size: 3
}
dim {
size: 1
}
}
}
}
}
attr {
key: "attr"
value {
s: "{}"
}
}
}
node {
name: "MyModule/ModuleList[module]/Linear[1]/44"
op: "prim::Constant"
attr {
key: "attr"
value {
s: "{ value : 1}"
}
}
}
node {
name: "MyModule/ModuleList[module]/Linear[1]/45"
op: "prim::Constant"
attr {
key: "attr"
value {
s: "{ value : 1}"
}
}
}
node {
name: "MyModule/ModuleList[module]/Linear[1]/46"
op: "aten::addmm"
input: "MyModule/ModuleList[module]/Linear[1]/bias/41"
input: "MyModule/ModuleList[module]/Linear[0]/input"
input: "MyModule/ModuleList[module]/Linear[1]/43"
input: "MyModule/ModuleList[module]/Linear[1]/44"
input: "MyModule/ModuleList[module]/Linear[1]/45"
attr {
key: "_output_shapes"
value {
list {
shape {
dim {
size: 4
}
dim {
size: 1
}
}
}
}
}
attr {
key: "attr"
value {
s: "{}"
}
}
}
versions {
producer: 22
}
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import sys
import os
import math
import uuid
import shutil
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from tensorboard.compat.proto.graph_pb2 import GraphDef
from google.protobuf import text_format
import unittest
from unittest import TestCase, main
from nni._graph_utils import build_module_graph, build_graph
class BackboneModel1(nn.Module):
def __init__(self):
super().__init__()
self.conv1 = nn.Conv2d(1, 1, 1, 1)
def forward(self, x):
return self.conv1(x)
class BackboneModel2(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv1 = nn.Conv2d(1, 20, 5, 1)
self.conv2 = nn.Conv2d(20, 50, 5, 1)
self.bn1 = nn.BatchNorm2d(self.conv1.out_channels)
self.bn2 = nn.BatchNorm2d(self.conv2.out_channels)
self.fc1 = nn.Linear(4 * 4 * 50, 500)
self.fc2 = nn.Linear(500, 10)
def forward(self, x):
x = F.relu(self.bn1(self.conv1(x)))
x = F.max_pool2d(x, 2, 2)
x = F.relu(self.bn2(self.conv2(x)))
x = F.max_pool2d(x, 2, 2)
x = x.view(x.size(0), -1)
x = F.relu(self.fc1(x))
x = self.fc2(x)
return x
class BigModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.backbone1 = BackboneModel1()
self.backbone2 = BackboneModel2()
self.fc3 = nn.Linear(10, 2)
def forward(self, x):
x = self.backbone1(x)
x = self.backbone2(x)
x = self.fc3(x)
return x
class GraphUtilsTestCase(TestCase):
def test_build_module_graph(self):
big_model = BigModel()
g = build_module_graph(big_model, torch.randn(2, 1, 28, 28))
print(g.name_to_node.keys())
leaf_modules = set([
'backbone1.conv1', 'backbone2.bn1', 'backbone2.bn2', 'backbone2.conv1',
'backbone2.conv2', 'backbone2.fc1', 'backbone2.fc2', 'fc3'
])
assert set(g.leaf_modules) == leaf_modules
assert not leaf_modules - set(g.name_to_node.keys())
assert g.find_successors('backbone2.conv1') == ['backbone2.bn1']
assert g.find_successors('backbone2.conv2') == ['backbone2.bn2']
assert g.find_predecessors('backbone2.bn1') == ['backbone2.conv1']
assert g.find_predecessors('backbone2.bn2') == ['backbone2.conv2']
def _test_graph(self, model, dummy_input, expected_file):
actual_proto, _ = build_graph(model, dummy_input)
assert os.path.exists(expected_file), expected_file
with open(expected_file, "r") as f:
expected_str = f.read()
expected_proto = GraphDef()
text_format.Parse(expected_str, expected_proto)
self.assertEquals(len(expected_proto.node), len(actual_proto.node))
for i in range(len(expected_proto.node)):
expected_node = expected_proto.node[i]
actual_node = actual_proto.node[i]
self.assertEquals(expected_node.name, actual_node.name)
self.assertEquals(expected_node.op, actual_node.op)
self.assertEquals(expected_node.input, actual_node.input)
self.assertEquals(expected_node.device, actual_node.device)
self.assertEquals(
sorted(expected_node.attr.keys()), sorted(actual_node.attr.keys()))
@unittest.skipIf(torch.__version__ < "1.4.0", "not supported")
def test_graph_module1(self):
dummy_input = (torch.zeros(1, 3),)
class myLinear(torch.nn.Module):
def __init__(self):
super(myLinear, self).__init__()
self.l = torch.nn.Linear(3, 5)
def forward(self, x):
return self.l(x)
self._test_graph(
myLinear(),
dummy_input,
os.path.join(os.path.dirname(__file__), "expect", "test_graph_module1.expect")
)
@unittest.skipIf(torch.__version__ < "1.4.0", "not supported")
def test_graph_module2(self):
class MyModule(nn.Module):
def __init__(self):
super().__init__()
self.weight = nn.Linear(5, 3)
self.bias = nn.Linear(5, 3)
self.module = nn.Linear(6, 1)
def forward(self, x):
tensors = [self.weight(x), self.bias(x)]
self.module(torch.cat(tensors, dim=1))
return x
self._test_graph(
MyModule(),
torch.randn(4, 5),
os.path.join(os.path.dirname(__file__), "expect", "test_graph_module2.expect")
)
@unittest.skipIf(torch.__version__ < "1.4.0", "not supported")
def test_graph_module3(self):
class MyModule(nn.Module):
def __init__(self):
super().__init__()
self.module = nn.ModuleList([
nn.Linear(5, 3),
nn.Linear(3, 1)
])
def forward(self, x):
x = self.module[0](x)
x = self.module[1](x)
return x
self._test_graph(
MyModule(),
torch.randn(4, 5),
os.path.join(os.path.dirname(__file__), "expect", "test_graph_module3.expect")
)
if __name__ == '__main__':
main()
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models.vgg import vgg16
from torchvision.models.resnet import resnet18
from unittest import TestCase, main
from nni.compression.torch import L1FilterPruner
from nni.compression.speedup.torch import ModelSpeedup
class BackboneModel1(nn.Module):
def __init__(self):
super().__init__()
self.conv1 = nn.Conv2d(1, 1, 1, 1)
def forward(self, x):
return self.conv1(x)
class BackboneModel2(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv1 = nn.Conv2d(1, 20, 5, 1)
self.conv2 = nn.Conv2d(20, 50, 5, 1)
self.bn1 = nn.BatchNorm2d(self.conv1.out_channels)
self.bn2 = nn.BatchNorm2d(self.conv2.out_channels)
self.fc1 = nn.Linear(4 * 4 * 50, 500)
self.fc2 = nn.Linear(500, 10)
def forward(self, x):
x = F.relu(self.bn1(self.conv1(x)))
x = F.max_pool2d(x, 2, 2)
x = F.relu(self.bn2(self.conv2(x)))
x = F.max_pool2d(x, 2, 2)
x = x.view(x.size(0), -1)
x = F.relu(self.fc1(x))
x = self.fc2(x)
return x
class BigModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.backbone1 = BackboneModel1()
self.backbone2 = BackboneModel2()
self.fc3 = nn.Sequential(
nn.Linear(10, 10),
nn.BatchNorm1d(10),
nn.ReLU(inplace=True),
nn.Linear(10, 2)
)
def forward(self, x):
x = self.backbone1(x)
x = self.backbone2(x)
x = self.fc3(x)
return x
SPARSITY = 0.5
def prune_model_l1(model):
config_list = [{
'sparsity': SPARSITY,
'op_types': ['Conv2d']
}]
pruner = L1FilterPruner(model, config_list)
pruner.compress()
pruner.export_model(model_path='./11_model.pth', mask_path='./l1_mask.pth')
class SpeedupTestCase(TestCase):
def test_speedup_vgg16(self):
prune_model_l1(vgg16())
model = vgg16()
model.train()
ms = ModelSpeedup(model, torch.randn(2, 3, 32, 32), './l1_mask.pth')
ms.speedup_model()
orig_model = vgg16()
assert model.training
assert model.features[2].out_channels == int(orig_model.features[2].out_channels * SPARSITY)
assert model.classifier[0].in_features == int(orig_model.classifier[0].in_features * SPARSITY)
#def test_speedup_resnet(self):
#TODO support resnet
#model = resnet18()
def test_speedup_bigmodel(self):
prune_model_l1(BigModel())
model = BigModel()
model.train()
ms = ModelSpeedup(model, torch.randn(2, 1, 28, 28), './l1_mask.pth')
ms.speedup_model()
orig_model = BigModel()
assert model.training
assert model.backbone2.conv1.out_channels == int(orig_model.backbone2.conv1.out_channels * SPARSITY)
assert model.backbone2.conv2.in_channels == int(orig_model.backbone2.conv2.in_channels * SPARSITY)
assert model.backbone2.conv2.out_channels == int(orig_model.backbone2.conv2.out_channels * SPARSITY)
assert model.backbone2.fc1.in_features == int(orig_model.backbone2.fc1.in_features * SPARSITY)
def tearDown(self):
os.remove('./11_model.pth')
os.remove('./l1_mask.pth')
if __name__ == '__main__':
main()
...@@ -21,7 +21,7 @@ class App extends React.Component<{}, AppState> { ...@@ -21,7 +21,7 @@ class App extends React.Component<{}, AppState> {
private timerId!: number | undefined; private timerId!: number | undefined;
private dataFormatimer!: number; private dataFormatimer!: number;
private firstLoad: boolean = false; // when click refresh selector options private firstLoad: boolean = false; // when click refresh selector options
constructor(props: {}) { constructor(props: {}) {
super(props); super(props);
this.state = { this.state = {
...@@ -49,8 +49,8 @@ class App extends React.Component<{}, AppState> { ...@@ -49,8 +49,8 @@ class App extends React.Component<{}, AppState> {
} }
getFinalDataFormat = (): void => { getFinalDataFormat = (): void => {
for(let i = 0; this.state.isillegalFinal === false; i++){ for (let i = 0; this.state.isillegalFinal === false; i++) {
if(TRIALS.succeededTrials()[0] !== undefined && TRIALS.succeededTrials()[0].final !== undefined){ if (TRIALS.succeededTrials()[0] !== undefined && TRIALS.succeededTrials()[0].final !== undefined) {
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion // eslint-disable-next-line @typescript-eslint/no-non-null-assertion
const oneSucceedTrial = JSON.parse(JSON.parse(TRIALS.succeededTrials()[0].final!.data)); const oneSucceedTrial = JSON.parse(JSON.parse(TRIALS.succeededTrials()[0].final!.data));
if (typeof oneSucceedTrial === 'number' || oneSucceedTrial.hasOwnProperty('default')) { if (typeof oneSucceedTrial === 'number' || oneSucceedTrial.hasOwnProperty('default')) {
...@@ -71,14 +71,14 @@ class App extends React.Component<{}, AppState> { ...@@ -71,14 +71,14 @@ class App extends React.Component<{}, AppState> {
} }
changeInterval = (interval: number): void => { changeInterval = (interval: number): void => {
window.clearTimeout(this.timerId); window.clearTimeout(this.timerId);
if (interval === 0) { if (interval === 0) {
return; return;
} }
// setState will trigger page refresh at once. // setState will trigger page refresh at once.
// setState is asyc, interval not update to (this.state.interval) at once. // setState is asyc, interval not update to (this.state.interval) at once.
this.setState({interval}, () => { this.setState({ interval }, () => {
this.firstLoad = true; this.firstLoad = true;
this.refresh(); this.refresh();
}); });
...@@ -96,7 +96,7 @@ class App extends React.Component<{}, AppState> { ...@@ -96,7 +96,7 @@ class App extends React.Component<{}, AppState> {
// overview best trial module // overview best trial module
changeEntries = (entries: string): void => { changeEntries = (entries: string): void => {
this.setState({bestTrialEntries: entries}); this.setState({ bestTrialEntries: entries });
} }
render(): React.ReactNode { render(): React.ReactNode {
...@@ -106,15 +106,25 @@ class App extends React.Component<{}, AppState> { ...@@ -106,15 +106,25 @@ class App extends React.Component<{}, AppState> {
if (experimentUpdateBroadcast === 0 || trialsUpdateBroadcast === 0) { if (experimentUpdateBroadcast === 0 || trialsUpdateBroadcast === 0) {
return null; // TODO: render a loading page return null; // TODO: render a loading page
} }
const errorList = [
{ errorWhere: TRIALS.jobListError(), errorMessage: TRIALS.getJobErrorMessage() },
{ errorWhere: EXPERIMENT.experimentError(), errorMessage: EXPERIMENT.getExperimentMessage() },
{ errorWhere: EXPERIMENT.statusError(), errorMessage: EXPERIMENT.getStatusMessage() },
{ errorWhere: TRIALS.MetricDataError(), errorMessage: TRIALS.getMetricDataErrorMessage() },
{ errorWhere: TRIALS.latestMetricDataError(), errorMessage: TRIALS.getLatestMetricDataErrorMessage() },
{ errorWhere: TRIALS.metricDataRangeError(), errorMessage: TRIALS.metricDataRangeErrorMessage() }
];
const reactPropsChildren = React.Children.map(this.props.children, child => const reactPropsChildren = React.Children.map(this.props.children, child =>
React.cloneElement( React.cloneElement(
child as React.ReactElement<any>, { child as React.ReactElement<any>, {
interval, interval,
columnList, changeColumn: this.changeColumn, columnList, changeColumn: this.changeColumn,
experimentUpdateBroadcast, experimentUpdateBroadcast,
trialsUpdateBroadcast, trialsUpdateBroadcast,
metricGraphMode, changeMetricGraphMode: this.changeMetricGraphMode, metricGraphMode, changeMetricGraphMode: this.changeMetricGraphMode,
bestTrialEntries, changeEntries: this.changeEntries bestTrialEntries, changeEntries: this.changeEntries
}) })
); );
...@@ -127,6 +137,16 @@ class App extends React.Component<{}, AppState> { ...@@ -127,6 +137,16 @@ class App extends React.Component<{}, AppState> {
</div> </div>
<Stack className="contentBox"> <Stack className="contentBox">
<Stack className="content"> <Stack className="content">
{/* if api has error field, show error message */}
{
errorList.map((item, key) => {
return (
item.errorWhere && <div key={key} className="warning">
<MessageInfo info={item.errorMessage} typeInfo="error" />
</div>
);
})
}
{isillegalFinal && <div className="warning"> {isillegalFinal && <div className="warning">
<MessageInfo info={expWarningMessage} typeInfo="warning" /> <MessageInfo info={expWarningMessage} typeInfo="warning" />
</div>} </div>}
...@@ -149,18 +169,20 @@ class App extends React.Component<{}, AppState> { ...@@ -149,18 +169,20 @@ class App extends React.Component<{}, AppState> {
if (trialsUpdated) { if (trialsUpdated) {
this.setState(state => ({ trialsUpdateBroadcast: state.trialsUpdateBroadcast + 1 })); this.setState(state => ({ trialsUpdateBroadcast: state.trialsUpdateBroadcast + 1 }));
} }
} else { } else {
this.firstLoad = false; this.firstLoad = false;
} }
if (['DONE', 'ERROR', 'STOPPED'].includes(EXPERIMENT.status)) { // experiment status and /trial-jobs api's status could decide website update
if (['DONE', 'ERROR', 'STOPPED'].includes(EXPERIMENT.status) || TRIALS.jobListError()) {
// experiment finished, refresh once more to ensure consistency // experiment finished, refresh once more to ensure consistency
this.setState({ interval: 0 }); this.setState({ interval: 0 });
this.lastRefresh(); this.lastRefresh();
return; return;
} }
this.timerId = window.setTimeout(this.refresh, this.state.interval * 1000); this.timerId = window.setTimeout(this.refresh, this.state.interval * 1000);
} }
......
...@@ -18,7 +18,7 @@ class MessageInfo extends React.Component<MessageInfoProps, {}> { ...@@ -18,7 +18,7 @@ class MessageInfo extends React.Component<MessageInfoProps, {}> {
return ( return (
<MessageBar <MessageBar
messageBarType={MessageBarType[typeInfo]} messageBarType={MessageBarType[typeInfo]}
isMultiline={false} isMultiline={true}
className={className} className={className}
> >
{info} {info}
......
...@@ -22,7 +22,7 @@ class ProgressBar extends React.Component<ProItemProps, {}> { ...@@ -22,7 +22,7 @@ class ProgressBar extends React.Component<ProItemProps, {}> {
<div> <div>
<Stack horizontal className={`probar ${bgclass}`}> <Stack horizontal className={`probar ${bgclass}`}>
<div className="name">{who}</div> <div className="name">{who}</div>
<div className="showProgress" style={{ width: '80%' }}> <div className="showProgress" style={{ width: '78%' }}>
<ProgressIndicator <ProgressIndicator
barHeight={30} barHeight={30}
percentComplete={percent} percentComplete={percent}
...@@ -32,7 +32,7 @@ class ProgressBar extends React.Component<ProItemProps, {}> { ...@@ -32,7 +32,7 @@ class ProgressBar extends React.Component<ProItemProps, {}> {
<StackItem className="right" grow={70}>{maxString}</StackItem> <StackItem className="right" grow={70}>{maxString}</StackItem>
</Stack> </Stack>
</div> </div>
<div className="description" style={{ width: '20%' }}>{description}</div> <div className="description" style={{ width: '22%' }}>{description}</div>
</Stack> </Stack>
<br /> <br />
</div> </div>
......
import * as React from 'react'; import * as React from 'react';
import ReactEcharts from 'echarts-for-react'; import ReactEcharts from 'echarts-for-react';
import { filterByStatus } from '../../static/function'; import { filterByStatus } from '../../static/function';
import { EXPERIMENT } from '../../static/datamodel';
import { Stack, PrimaryButton, Dropdown, IDropdownOption, } from 'office-ui-fabric-react'; // eslint-disable-line no-unused-vars import { Stack, PrimaryButton, Dropdown, IDropdownOption, } from 'office-ui-fabric-react'; // eslint-disable-line no-unused-vars
import { ParaObj, Dimobj, TableObj } from '../../static/interface'; // eslint-disable-line no-unused-vars import { ParaObj, Dimobj, TableObj } from '../../static/interface'; // eslint-disable-line no-unused-vars
import 'echarts/lib/chart/parallel'; import 'echarts/lib/chart/parallel';
...@@ -98,7 +99,13 @@ class Para extends React.Component<ParaProps, ParaState> { ...@@ -98,7 +99,13 @@ class Para extends React.Component<ParaProps, ParaState> {
// according acc to sort ydata // sort to find top percent dataset // according acc to sort ydata // sort to find top percent dataset
if (paraYdata.length !== 0) { if (paraYdata.length !== 0) {
const len = paraYdata[0].length - 1; const len = paraYdata[0].length - 1;
paraYdata.sort((a, b) => b[len] - a[len]); // show top trials
if (EXPERIMENT.optimizeMode === 'minimize') {
paraYdata.sort((a, b) => a[len] - b[len]);
}
if (EXPERIMENT.optimizeMode === 'maximize') {
paraYdata.sort((a, b) => b[len] - a[len]);
}
} }
const paraData = { const paraData = {
parallelAxis: parallelAxis, parallelAxis: parallelAxis,
......
...@@ -3,6 +3,19 @@ import axios from 'axios'; ...@@ -3,6 +3,19 @@ import axios from 'axios';
import { MANAGER_IP } from './const'; import { MANAGER_IP } from './const';
import { MetricDataRecord, FinalType, TableObj } from './interface'; import { MetricDataRecord, FinalType, TableObj } from './interface';
async function requestAxios(url: string) {
const response = await axios.get(url);
if (response.status === 200) {
if (response.data.error !== undefined) {
throw new Error(`API ${url} ${response.data.error}`);
} else {
return response.data as any;
}
} else {
throw new Error(`API ${url} ${response.status} error`);
}
}
const convertTime = (num: number): string => { const convertTime = (num: number): string => {
if (num <= 0) { if (num <= 0) {
return '0'; return '0';
...@@ -219,5 +232,5 @@ export { ...@@ -219,5 +232,5 @@ export {
convertTime, convertDuration, getFinalResult, getFinal, downFile, convertTime, convertDuration, getFinalResult, getFinal, downFile,
intermediateGraphOption, killJob, filterByStatus, filterDuration, intermediateGraphOption, killJob, filterByStatus, filterDuration,
formatAccuracy, formatTimestamp, metricAccuracy, parseMetrics, formatAccuracy, formatTimestamp, metricAccuracy, parseMetrics,
isArrayType isArrayType, requestAxios
}; };
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment