OpenDAS / nni / Commits / 6568eaee

Unverified commit 6568eaee, authored May 18, 2020 by SparkSnail, committed by GitHub on May 18, 2020

Merge pull request #247 from microsoft/master

merge master

Parents: d90433da, 1e2a2e29
Changes 45
Showing 20 changed files with 1561 additions and 665 deletions (+1561 -665)
docs/zh_CN/tutorials.rst (+0 -20)
examples/model_compress/speedup_zh_CN.md (+0 -96)
src/nni_manager/training_service/local/localTrainingService.ts (+5 -3)
src/sdk/pynni/nni/_graph_utils.py (+490 -0)
src/sdk/pynni/nni/compression/speedup/torch/compress_modules.py (+1 -0)
src/sdk/pynni/nni/compression/speedup/torch/compressor.py (+13 -389)
src/sdk/pynni/nni/compression/speedup/torch/infer_shape.py (+10 -0)
src/sdk/pynni/nni/nas/pytorch/_graph_utils.py (+0 -134)
src/sdk/pynni/nni/nas/pytorch/mutator.py (+2 -2)
src/sdk/pynni/nni/nas/pytorch/pdarts/mutator.py (+2 -1)
src/sdk/pynni/tests/expect/test_graph_module1.expect (+152 -0)
src/sdk/pynni/tests/expect/test_graph_module2.expect (+309 -0)
src/sdk/pynni/tests/expect/test_graph_module3.expect (+250 -0)
src/sdk/pynni/tests/test_graph_utils.py (+158 -0)
src/sdk/pynni/tests/test_model_speedup.py (+107 -0)
src/webui/src/App.tsx (+37 -15)
src/webui/src/components/Modals/MessageInfo.tsx (+1 -1)
src/webui/src/components/overview/ProgressItem.tsx (+2 -2)
src/webui/src/components/trial-detail/Para.tsx (+8 -1)
src/webui/src/static/function.ts (+14 -1)
docs/zh_CN/tutorials.rst (deleted, 100644 → 0)
```rst
######################
Tutorials
######################

.. toctree::
    :maxdepth: 2

    Installation<Tutorial/Installation>
    Implement Trials<./TrialExample/Trials>
    Tuner<tuners>
    Assessor<assessors>
    NAS (Beta) <nas>
    Model Compression (Beta) <model_compression>
    Feature Engineering (Beta) <feature_engineering>
    Web UI<Tutorial/WebUI>
    Training Services<training_services>
    How to Use Docker<Tutorial/HowToUseDocker>
    Advanced Features<advanced>
    How to Debug<Tutorial/HowToDebug>
    NNI on Windows<Tutorial/NniOnWindows>
```
\ No newline at end of file
examples/model_compress/speedup_zh_CN.md (deleted, 100644 → 0)
# Speed Up a Masked Model

*This feature is still in preview.*

## Introduction

Pruning algorithms usually use weight masks to simulate real pruning. A mask can be used to check the model performance of a specific pruning (or sparsity) algorithm, but it brings no real speedup. Since model speedup is the ultimate goal of model pruning, we provide this tool, which converts an existing model into a smaller one based on a user-provided mask (the mask comes from a pruning algorithm).

There are two kinds of pruning algorithms. One is fine-grained pruning, which does not change the shape of the weights or of the input/output tensors; sparse kernels are required to speed up a fine-grained pruned layer. The other is coarse-grained pruning (e.g., channels), where the shapes of the weights and the input/output tensors usually change. To speed up this kind of pruning, no sparse kernel is needed; the pruned layers only need to be replaced with smaller ones. Since support for sparse kernels in the open-source community is still limited, we currently only support coarse-grained pruning; fine-grained pruning will be supported in the future.
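To make the distinction concrete, here is a small illustrative sketch (not part of the original document) of the two kinds of masks on a `Conv2d` weight; a coarse channel mask is equivalent to a layer with fewer filters, which is exactly what this tool swaps in:

```python
import torch

conv_weight = torch.randn(8, 3, 3, 3)  # (out_channels, in_channels, kH, kW)

# Fine-grained mask: arbitrary zeros, weight shape unchanged -> needs sparse kernels.
fine_mask = (torch.rand_like(conv_weight) > 0.5).float()

# Coarse-grained (channel) mask: zero out whole filters, here filters 2 and 5.
coarse_mask = torch.ones_like(conv_weight)
coarse_mask[[2, 5]] = 0.

# A channel-masked Conv2d behaves like a Conv2d with only the kept filters,
# so it can be replaced by a smaller layer with 6 output channels.
kept_filters = coarse_mask.sum(dim=(1, 2, 3)) > 0
print(int(kept_filters.sum()))  # 6
```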
## Design and Implementation

To speed up a model, the pruned layers should be replaced: with a smaller layer for a coarse-grained mask, or with a sparse kernel for a fine-grained mask. A coarse-grained mask usually changes the shape of the weights or of the input/output tensors, so shape inference should be performed to check whether other layers, which were not pruned themselves, need to change shape because of the propagated shape changes. Our design therefore has two main steps: first, run shape inference to find all the modules that should be replaced; second, replace the modules. The first step requires the topology of the model (i.e., its connections); we use `jit.trace` to obtain the model graph for PyTorch.

For each module, four functions should be prepared: three for shape inference and one for module replacement. The three shape inference functions are: inferring the input/output shape given the weight shape, inferring the weight/output shape given the input shape, and inferring the weight/input shape given the output shape. The module replacement function returns a newly created module, which is smaller.
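As a minimal sketch of the first step (illustrative; any traceable model works), `jit.trace` records an op-level graph whose value names expose the connections that shape inference walks:

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU(), nn.Conv2d(8, 4, 3))
trace = torch.jit.trace(model, torch.randn(1, 3, 32, 32))

# Each traced node carries its op kind plus named inputs/outputs, which is
# enough to recover which layer feeds which.
for node in trace.graph.nodes():
    print(node.kind(),
          [i.debugName() for i in node.inputs()],
          [o.debugName() for o in node.outputs()])
```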
## Usage

```python
from nni.compression.speedup.torch import ModelSpeedup

# model: the model to be sped up
# dummy_input: an example input of the model, passed to `jit.trace`
# masks_file: the mask file created by a pruning algorithm
m_speedup = ModelSpeedup(model, dummy_input.to(device), masks_file)
m_speedup.speedup_model()

dummy_input = dummy_input.to(device)
start = time.time()
out = model(dummy_input)
print('elapsed time: ', time.time() - start)
```
For a complete example, please refer to [here](https://github.com/microsoft/nni/tree/master/examples/model_compress/model_speedup.py).

NOTE: the current implementation only works with torch 1.3.1 and torchvision 0.4.2.
## Limitations

Since every module requires four functions for shape inference and module replacement, implementing them all is a lot of work; we have currently only implemented the functions needed by the examples. If you want to speed up your own model and it is not supported yet, contributions are welcome.

For PyTorch, only module replacement is provided; functions called inside `forward` are not supported yet. One workaround is to turn the function into a PyTorch module, as sketched below.
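For example (an illustrative workaround, not part of the original doc), a `torch.flatten` call in `forward` can be lifted into a small `nn.Module`, so the speedup tool sees it as a replaceable layer:

```python
import torch
import torch.nn as nn

class Flatten(nn.Module):
    """Wraps the torch.flatten call so it shows up as a module in the trace."""
    def forward(self, x):
        return torch.flatten(x, 1)

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, 3)
        self.flatten = Flatten()  # instead of calling torch.flatten in forward
        self.fc = nn.Linear(8 * 30 * 30, 10)

    def forward(self, x):
        return self.fc(self.flatten(self.conv(x)))
```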
## Speedup Results of the Examples

The experiment code can be found [here](https://github.com/microsoft/nni/tree/master/examples/model_compress/model_speedup.py).

### slim pruner example

On one V100 GPU, input tensor: `torch.randn(64, 3, 32, 32)`

| Times | Mask Latency | Speedup Latency |
| ----- | ------------ | --------------- |
| 1     | 0.01197      | 0.005107        |
| 2     | 0.02019      | 0.008769        |
| 4     | 0.02733      | 0.014809        |
| 8     | 0.04310      | 0.027441        |
| 16    | 0.07731      | 0.05008         |
| 32    | 0.14464      | 0.10027         |

### fpgm pruner example

On CPU, input tensor: `torch.randn(64, 1, 28, 28)`, with high variance across runs

| Times | Mask Latency | Speedup Latency |
| ----- | ------------ | --------------- |
| 1     | 0.01383      | 0.01839         |
| 2     | 0.01167      | 0.003558        |
| 4     | 0.01636      | 0.01088         |
| 40    | 0.14412      | 0.08268         |
| 40    | 1.29385      | 0.14408         |
| 40    | 0.41035      | 0.46162         |
| 400   | 6.29020      | 5.82143         |

### l1filter pruner example

On one V100 GPU, input tensor: `torch.randn(64, 3, 32, 32)`

| Times | Mask Latency | Speedup Latency |
| ----- | ------------ | --------------- |
| 1     | 0.01026      | 0.003677        |
| 2     | 0.01657      | 0.008161        |
| 4     | 0.02458      | 0.020018        |
| 8     | 0.03498      | 0.025504        |
| 16    | 0.06757      | 0.047523        |
| 32    | 0.10487      | 0.086442        |

### APoZ pruner example

On one V100 GPU, input tensor: `torch.randn(64, 3, 32, 32)`

| Times | Mask Latency | Speedup Latency |
| ----- | ------------ | --------------- |
| 1     | 0.01389      | 0.004208        |
| 2     | 0.01628      | 0.008310        |
| 4     | 0.02521      | 0.014008        |
| 8     | 0.03386      | 0.023923        |
| 16    | 0.06042      | 0.046183        |
| 32    | 0.12421      | 0.087113        |
\ No newline at end of file
src/nni_manager/training_service/local/localTrainingService.ts
```diff
@@ -334,9 +334,11 @@ class LocalTrainingService implements TrainingService {
                 throw new Error(`Could not find stream in trial ${trialJob.id}`);
             }
             //Refer https://github.com/Juul/tail-stream/issues/20
+            setTimeout(() => {
                 stream.end(0);
                 stream.emit('end');
                 this.jobStreamMap.delete(trialJob.id);
+            }, 5000);
         }
     }
     if (trialJob.gpuIndices !== undefined && trialJob.gpuIndices.length > 0 && this.gpuScheduler !== undefined) {
```
src/sdk/pynni/nni/_graph_utils.py (new file, 0 → 100644)
```python
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import logging
import queue
import re
from collections import defaultdict
import torch
from torch.utils.tensorboard._pytorch_graph import NodePy, NodePyIO, NodePyOP, GraphPy

CLASSTYPE_KIND = 'ClassType'
GETATTR_KIND = 'prim::GetAttr'

_logger = logging.getLogger(__name__)


def build_module_graph(model, dummy_input):
    return TorchModuleGraph(model, dummy_input)

def build_graph(model, dummy_input, verbose=False):
    g = TorchProtoGraph(model, dummy_input, verbose)
    return g.graph_def, g.stepstats

def parse_traced_name(module_name):
    prefix = 'TracedModule['
    suffix = ']'
    if module_name.startswith(prefix) and module_name.endswith(suffix):
        module_name = module_name[len(prefix):-len(suffix)]
    return module_name


class TorchGraph:
    """
    This class is to extract pytorch model topology graph by tracing
    """
    def __init__(self, model, dummy_input):
        """
        Parameters
        ----------
        model : pytorch model
            The model user wants to speed up
        dummy_input : pytorch tensor
            The dummy input for ```jit.trace```, users should put it on right device before pass in
        """
        assert torch.__version__ >= '1.3.1'
        self.bound_model = model
        self._trace(model, dummy_input)

    def _trace(self, model, dummy_input):
        with torch.onnx.set_training(model, False):
            self.trace = torch.jit.trace(model, dummy_input)
            torch._C._jit_pass_inline(self.trace.graph)


class TorchProtoGraph(TorchGraph):
    """
    Generates model graph for pytorch models in protobuf, this implementation
    is borrowed from pytorch v1.4.0, and fixed following issues:
    https://github.com/pytorch/pytorch/issues/33691
    https://github.com/pytorch/pytorch/issues/33670
    """
    def __init__(self, model, dummy_input, verbose=False):
        super().__init__(model, dummy_input)

        from tensorboard.compat.proto.config_pb2 import RunMetadata
        from tensorboard.compat.proto.graph_pb2 import GraphDef
        from tensorboard.compat.proto.step_stats_pb2 import StepStats, DeviceStepStats
        from tensorboard.compat.proto.versions_pb2 import VersionDef

        list_of_nodes = self.parse(self.trace.graph, self.trace, dummy_input)
        if verbose:
            print(self.trace.graph)
        self.stepstats = RunMetadata(step_stats=StepStats(
            dev_stats=[DeviceStepStats(device="/device:CPU:0")]))
        self.graph_def = GraphDef(node=list_of_nodes, versions=VersionDef(producer=22))

    def parse(self, graph, trace, args=None, omit_useless_nodes=True):
        """This method parses an optimized PyTorch model graph and produces
        a list of nodes and node stats for eventual conversion to TensorBoard
        protobuf format.

        Args:
            graph (PyTorch module): The model graph to be parsed.
            trace (PyTorch JIT TracedModule): The model trace to be parsed.
            args (tuple): input tensor[s] for the model.
            omit_useless_nodes (boolean): Whether to remove nodes from the graph.
        """
        nodes_py = GraphPy()
        for node in graph.inputs():
            if omit_useless_nodes:
                if not node.uses():  # number of user of the node (= number of outputs/ fanout)
                    continue
            if node.type().kind() != CLASSTYPE_KIND:
                nodes_py.append(NodePyIO(node, 'input'))

        attr_to_scope = dict()
        node_to_name = lambda d: str(d).split(":")[0].strip()
        for node in graph.nodes():
            if node.kind() == GETATTR_KIND:
                attr_name = node.s('name')
                node_name = node_to_name(node)
                parent = node.input().node()
                if parent.kind() == GETATTR_KIND:
                    # If the parent node is not the top-level "self" node
                    parent_scope = attr_to_scope[node_to_name(parent)]
                    attr_scope = parent_scope.split('/')[-1]
                    attr_to_scope[node_name] = '{}/{}.{}'.format(parent_scope, attr_scope, attr_name)
                else:
                    attr_to_scope[node_name] = '__module.{}'.format(attr_name)
                # We don't need classtype nodes; scope will provide this information
                if node.output().type().kind() != CLASSTYPE_KIND:
                    node_py = NodePyOP(node)
                    node_py.scopeName = attr_to_scope[node_name]
                    nodes_py.append(node_py)
            else:
                nodes_py.append(NodePyOP(node))

        for i, node in enumerate(graph.outputs()):  # Create sink nodes for output ops
            node_py = NodePyIO(node, 'output')
            node_py.debugName = "output.{}".format(i + 1)
            node_py.inputs = [node.debugName()]
            nodes_py.append(node_py)

        alias_to_name = dict()
        base_name = parse_traced_name(trace._name)
        for name, module in trace.named_modules(prefix='__module'):
            mod_name = parse_traced_name(module._name)
            attr_name = name.split('.')[-1]
            alias_to_name[name] = '{}[{}]'.format(mod_name, attr_name)

        for node in nodes_py.nodes_op:
            module_aliases = node.scopeName.split('/')[-1].split('.')
            module_name = ''
            for i, alias in enumerate(module_aliases):
                if i == 0:
                    module_name = alias
                    node.scopeName = base_name
                else:
                    module_name += '.' + alias
                    node.scopeName += '/' + (alias_to_name[module_name]
                                             if module_name in alias_to_name else alias)

        nodes_py.populate_namespace_from_OP_to_IO()
        return nodes_py.to_proto()


class NodePyGroup(NodePy):
    """
    This class is used to represent a graph node which consists of multiple jit traced nodes. In a pytorch trace graph,
    there are multiple nodes are traced for one torch.nn.Module object, we group them together to form a single node to
    represent the torch.nn.Module object. We also group some functional call trace nodes together to form a new node.
    """
    def __init__(self, name, node_type, op_type, node_cpps, inputs=None, outputs=None):
        """
        Parameters:
        -----------
        name: str
            node name, such as `conv1`, `backbone.classifier`
        node_type: str
            `module` or `func`
        op_type: str
            operation type, such as `Conv2d`, `aten::view`
        node_cpps: list of torch._C.Node
            jit trace nodes which are included in this new node
        inputs: list of str
            All the inputs of this node, each element is debugName of one input
        outputs: list of str
            All the outputs of this node, each element is debugName of one output
        """
        super(NodePyGroup, self).__init__(name, [])
        self.node_cpps = node_cpps
        self.name = name
        self.op_type = op_type
        self.type = node_type
        self.nodes = []
        self.auxiliary = None
        self.add_nodes(node_cpps)
        self.inputs = inputs
        self.outputs = outputs

    def add_nodes(self, node_cpps):
        for node_cpp in node_cpps:
            nodepy = NodePyOP(node_cpp)
            nodepy.name = str(node_cpp).split(':')[0].strip().replace('%', '')
            self.nodes.append(nodepy)

    def sub_node_names(self):
        return [x.name for x in self.nodes]

    def __repr__(self):
        return 'name: {}, type: {}, op_type: {}, sub_nodes: {}, inputs: {}, outputs: {}, aux: {}'.format(
            self.name, self.type, self.op_type, self.sub_node_names(),
            self.inputs, self.outputs, self.auxiliary)


class TorchModuleGraph(TorchGraph):
    """
    Generates model graph, each node is created from single or multiple jit trace nodes.
    """
    def __init__(self, model, dummy_input):
        super().__init__(model, dummy_input)
        self.global_count = 0
        self.name_to_node, self.input_to_node, self.output_to_node = self._build_graph()

    def _expand_non_prim_node(self, node, nodes, input_to_node, output_to_node):
        """
        For trace graph nodes, some nodes are not in modules, these nodes are usually generated by
        the functions directly called in module ```forward```. For such nodes, some of them are
        trivial op which are label by ```prim::```, some of them are not such ops which is call
        non-prim ops. This function is to merge neighbor prim ops to a non-prim op, to construct
        a node.

        Parameters
        ----------
        node : trace graph node
            The non-prim node to expand
        nodes : list of trace graph node
            All the trace graph nodes within the same scope as the non-prim node
        input_to_node : dict
            key: input name, value: a node that uses this input
        output_to_node : dict
            key: output name, value: a node that generates this output

        Returns
        -------
        node
            the expanded non-prim node
        """
        # TODO: scope name could be empty
        node_name = '.'.join([self._get_module_name(node.scopeName()), node.kind(), str(self.global_count)])
        _logger.debug("expand non-prim node, node name: %s", node_name)
        self.global_count += 1
        op_type = node.kind()
        node_group = [node]
        inputs = list()
        outputs = list()
        node_queue = queue.Queue()
        node_queue.put(node)
        while not node_queue.empty():
            curr_node = node_queue.get()
            for _input in curr_node.inputs():
                input_name = _input.debugName()
                if input_name in output_to_node and output_to_node[input_name] in nodes:
                    predecessor_node = output_to_node[input_name]
                    if predecessor_node.kind().startswith('prim::'):
                        node_group.append(predecessor_node)
                        node_queue.put(predecessor_node)
                    else:
                        inputs.append(input_name)
                else:
                    inputs.append(input_name)
        for output in node.outputs():
            outputs.append(output.debugName())
        nodepy = NodePyGroup(node_name, 'func', op_type, node_group, inputs=inputs, outputs=outputs)
        return nodepy

    def _build_module_node_group(self, module_name, op_type, node_cpps, input_to_node, output_to_node):
        graph = self.trace.graph
        inputs, outputs = [], []
        for n in node_cpps:
            for i in n.inputs():
                name = i.debugName()
                if not name in output_to_node and i in graph.inputs():
                    inputs.append(name)
                elif output_to_node[name] not in node_cpps:
                    inputs.append(name)
            for o in n.outputs():
                name = o.debugName()
                if not name in input_to_node and o in graph.outputs():
                    outputs.append(name)
                elif input_to_node[name] not in node_cpps:
                    outputs.append(name)
        return NodePyGroup(module_name, 'module', op_type, node_cpps, inputs, outputs)

    def _extract_shape_info(self, node):
        """
        Extract the shape information of ```aten::view``` node

        Parameters
        ----------
        node : trace graph node
            It should be ```aten::view``` node

        Returns
        -------
        dict
            Include shape of input tensor and shape of output tensor
        """
        t_input = None
        for _input in node.inputs():
            t_input = _input
            break
        t_output = node.output()
        assert isinstance(t_input.type(), torch._C.TensorType)
        assert isinstance(t_output.type(), torch._C.TensorType)
        in_shape = t_input.type().sizes()
        out_shape = t_output.type().sizes()
        return {'in_shape': in_shape, 'out_shape': out_shape}

    def _extract_leaf_modules(self):
        """
        Extract leaf modules from the given graph. Leaf module means it does not have submodules.
        To extract leaf modules because only leaf module can be replaced. And shape inference can
        be done in leaf module level. Other shape inference is done in lower level i.e.,
        operation level.

        Returns
        -------
        list
            a list of scope name of all the leaf modules
        """
        def is_parent(name1, name2):
            """
            check if name1 is parent node of name2, for example:
            name1: aa.bb, name2: aa.bb.cc, return True
            name1: aa.b, name2: aa.bb, return False
            """
            parts1, parts2 = name1.split('.'), name2.split('.')
            if len(parts1) >= len(parts2):
                return False
            for i in range(len(parts1)):
                if parts2[i] != parts1[i]:
                    return False
            return True
        module_names = sorted([x[0] for x in self.trace.named_modules() if x[0]])
        leaf_nodes = []
        for i, name in enumerate(module_names):
            if i + 1 >= len(module_names) or not is_parent(name, module_names[i + 1]):
                leaf_nodes.append(name)
        return leaf_nodes

    def _get_module_name(self, scope_name):
        """
        Retrieve module name from scope name.

        Parameters:
        -----------
        scope_name: str
            scope_name of a graph node, for example:
            for pytorch 1.3.1: MyModel/BackboneModel[backbone]/Conv2d[conv2]
            for pytorch 1.4.0: __module.backbone/__module.backbone.conv2

        Returns:
        -------
        str
            module name, such as backbone.conv2
        """
        if torch.__version__ >= '1.4.0':
            return scope_name.split('/')[-1].replace('__module.', '')
        else:
            return '.'.join(re.findall(r'\[(.*?)\]', scope_name))

    def _build_index(self, nodes_op):
        name_to_node = dict()
        input_to_node = defaultdict(list)
        output_to_node = dict()
        for node in nodes_op:
            name_to_node[node.name] = node
            for _input in node.inputs:
                input_to_node[_input].append(node)
            for output in node.outputs:
                assert not output in output_to_node, \
                    "One output cannot be generated by multiple nodes"
                output_to_node[output] = node
        return name_to_node, input_to_node, output_to_node

    def _build_graph(self):
        """
        Build graph using our defined format from jit trace.
        There are basically three steps: first, construct necessary information (data structures),
        second, extract all the modules to convert to node, Third, extract all functions to convert
        to node.

        Returns
        -------
        dict
            use name to index nodes, key: node name, value: node
        dict
            use input (its name) to index nodes,
            key: input, value: list of nodes that take this input
        dict
            use output (its name) to index nodes,
            key: output, value: node that generates this output
        """
        omit_useless_nodes = True
        graph = self.trace.graph
        _logger.debug(graph)
        # build output mapping, from output debugName to its node
        output_to_node = {x.debugName(): n for n in graph.nodes() for x in n.outputs()}
        # build input mapping, from input debugName to its node
        input_to_node = {x.debugName(): n for n in graph.nodes() for x in n.inputs()}
        # build module mapping, from module name to all nodes (as list) under this module scope
        module_to_nodes = defaultdict(list)
        # the mapping of function (non-module in forward) to nodes, key is scope name
        func_to_nodes = defaultdict(list)

        nodes_py = GraphPy()
        for node in graph.inputs():
            if omit_useless_nodes:
                if not node.uses():  # number of user of the node (= number of outputs/ fanout)
                    continue
            if node.type().kind() != 'ClassType':
                nodes_py.append(NodePyIO(node, 'input'))

        self.leaf_modules = self._extract_leaf_modules()
        module_to_type = {name: parse_traced_name(module._name)
                          for name, module in self.trace.named_modules()}

        # associate module name with their trace graph nodes
        for node in graph.nodes():
            module_name = self._get_module_name(node.scopeName())
            if module_name in self.leaf_modules:
                module_to_nodes[module_name].append(node)
            else:
                func_to_nodes[node.scopeName()].append(node)
        # build node group for module
        for module_name, node_cpps in module_to_nodes.items():
            node_group = self._build_module_node_group(module_name, module_to_type[module_name],
                                                       node_cpps, input_to_node, output_to_node)
            _logger.debug('node_group: %s', node_group)
            nodes_py.nodes_op.append(node_group)

        # each scope_name may have multiple funcs, we split them and create node for each of them
        # build node group for torch.nn.functional
        for _, nodes in func_to_nodes.items():
            # extract non prim:: nodes
            non_prim_nodes = list()
            for node in nodes:
                if not node.kind().startswith('prim::'):
                    non_prim_nodes.append(node)
            # for each non prim node, expand it
            for node in non_prim_nodes:
                node_group = self._expand_non_prim_node(node, nodes, input_to_node, output_to_node)
                nodes_py.nodes_op.append(node_group)
                # get shape infor for view (aten::view) func
                if node_group.op_type in ['aten::view', 'aten::flatten']:
                    node_group.auxiliary = self._extract_shape_info(node)

        for node in graph.outputs():  # Create sink nodes for output ops
            node_py = NodePyIO(node, 'output')
            nodes_py.append(node_py)

        self.nodes_py = nodes_py
        # build index
        return self._build_index(self.nodes_py.nodes_op)

    def find_predecessors(self, module_name):
        """
        Find predecessor node of the given node

        Parameters
        ----------
        module_name : str
            The name of the node

        Returns
        -------
        list
            a list of nodes who are the given node's predecessor
        """
        predecessors = []
        for _input in self.name_to_node[module_name].inputs:
            if not _input in self.output_to_node:
                _logger.debug("cannot find node with %s as its output", _input)
            else:
                node_py = self.output_to_node[_input]
                predecessors.append(node_py.name)
        return predecessors

    def find_successors(self, module_name):
        """
        Find successor nodes of the given node

        Parameters
        ----------
        module_name : str
            The name of the node

        Returns
        -------
        list
            a list of nodes who are the given node's successor
        """
        successors = []
        for output in self.name_to_node[module_name].outputs:
            assert output in self.input_to_node, "No node with input {}".format(output)
            nodes_py = self.input_to_node[output]
            for node_py in nodes_py:
                successors.append(node_py.name)
        return successors
```
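For orientation (a hedged usage sketch, not part of the commit), the new graph utility is driven through `build_module_graph` and then queried for module-level topology; the model and node names below are illustrative:

```python
import torch
import torchvision.models as models
from nni._graph_utils import build_module_graph

model = models.resnet18().eval()
dummy_input = torch.randn(1, 3, 224, 224)  # put on the same device as the model

graph = build_module_graph(model, dummy_input)
# Each NodePyGroup bundles the jit-trace nodes of one leaf module or one
# functional call, so predecessor/successor queries work per module.
print(graph.find_predecessors('layer1.0.conv2'))
print(graph.find_successors('layer1.0.conv2'))
```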
src/sdk/pynni/nni/compression/speedup/torch/compress_modules.py
```diff
@@ -12,6 +12,7 @@ replace_module = {
     'Conv2d': lambda module, mask: replace_conv2d(module, mask),
     'MaxPool2d': lambda module, mask: no_replace(module, mask),
     'AvgPool2d': lambda module, mask: no_replace(module, mask),
+    'AdaptiveAvgPool2d': lambda module, mask: no_replace(module, mask),
     'ReLU': lambda module, mask: no_replace(module, mask),
     'Linear': lambda module, mask: replace_linear(module, mask)
 }
```
src/sdk/pynni/nni/compression/speedup/torch/compressor.py
```diff
@@ -2,9 +2,8 @@
 # Licensed under the MIT license.
 
 import logging
-import queue
-import re
 import torch
+from nni._graph_utils import build_module_graph
 from .compress_modules import replace_module
 from .infer_shape import ModuleMasks, infer_from_mask, infer_from_inshape, infer_from_outshape
@@ -33,38 +32,6 @@ def get_module_by_name(model, module_name):
     leaf_module = getattr(model, name_list[-1])
     return model, leaf_module
 
-class GNode:
-    """
-    It is used to represent a node in model graph, in this graph a module is a node,
-    a function out of module (in ```forward``` function) could also be a node.
-    """
-    def __init__(self, node_name, node_type, op_type, inputs, outputs, nodes):
-        """
-        Parameters
-        ----------
-        node_name : str
-            It is module name if the node is a module, it is ```scope_name.node_kind.seq``` if it is a func
-        node_type : str
-            It only has two options: `module` or `func`
-        op_type : str
-            The operation type of the module or func
-        inputs : list of str
-            All the inputs of this node, each element is debugName of one input
-        outputs : list of str
-            All the outputs of this node, each element is debugName of one output
-        nodes : list of node
-            All the trace graph nodes included in this module or func
-        """
-        self.name = node_name
-        self.type = node_type
-        self.op_type = op_type
-        self.inputs = inputs
-        self.outputs = outputs
-        self.nodes = nodes
-        # store supplementary information for different op types
-        # for example, for ```view``` it stores the shape of its input and output
-        self.auxiliary = None
-
 class ModelSpeedup:
     """
     This class is to speedup the model with provided weight mask
@@ -84,347 +51,9 @@ class ModelSpeedup:
             the device on which masks are placed, same to map_location in ```torch.load```
         """
         self.bound_model = model
+        self.dummy_input = dummy_input
         self.masks = torch.load(masks_file, map_location)
-        self.is_training = model.training
-        # to obtain forward graph, model should be in ```eval``` mode
-        if self.is_training:
-            model.eval()
-        self.trace_graph = torch.jit.trace(model, dummy_input)
-        if self.is_training:
-            model.train()
         self.inferred_masks = dict() # key: module_name, value: ModuleMasks
-        self.g_nodes = list()
-        self.global_count = 0
-        self.name_to_gnode, self.input_to_gnode, self.output_to_gnode = self._build_graph()
-
-    def _build_index_for_gnodes(self, g_nodes):
-        """
-        Build indexes for quick search
-
-        Parameters
-        ----------
-        g_nodes : list of GNode
-            All the g_node in processed model graph
-
-        Returns
-        -------
-        dict
-            use name to index g_nodes, key: node name, value: g_node
-        dict
-            use input (its name) to index g_nodes,
-            key: input, value: list of g_nodes that take this input
-        dict
-            use output (its name) to index g_nodes,
-            key: output, value: g_node that generates this output
-        """
-        name_to_gnode = dict()
-        input_to_gnode = dict()
-        output_to_gnode = dict()
-        for node in g_nodes:
-            name_to_gnode[node.name] = node
-            for _input in node.inputs:
-                if _input in input_to_gnode:
-                    input_to_gnode[_input].append(node)
-                else:
-                    input_to_gnode[_input] = [node]
-            for output in node.outputs:
-                assert not output in output_to_gnode, \
-                    "One output cannot be generated by multiple nodes"
-                output_to_gnode[output] = node
-        return name_to_gnode, input_to_gnode, output_to_gnode
-
-    def _expand_non_prim_node(self, node, nodes, input_to_node, output_to_node):
-        """
-        For trace graph nodes, some nodes are not in modules, these nodes are usually generated by
-        the functions directly called in module ```forward```. For such nodes, some of them are
-        trivial op which are label by ```prim::```, some of them are not such ops which is call
-        non-prim ops. This function is to merge neighbor prim ops to a non-prim op, to construct
-        a GNode.
-
-        Parameters
-        ----------
-        node : trace graph node
-            The non-prim node to expand
-        nodes : list of trace graph node
-            All the trace graph nodes within the same scope as the non-prim node
-        input_to_node : dict
-            key: input name, value: a node that uses this input
-        output_to_node : dict
-            key: output name, value: a node that generates this output
-
-        Returns
-        -------
-        GNode
-            the expanded non-prim node in GNode format
-        """
-        # TODO: scope name could be empty
-        node_name = '.'.join([node.scopeName(), node.kind(), str(self.global_count)])
-        _logger.debug("expand non-prim node, node name: %s", node_name)
-        self.global_count += 1
-        op_type = node.kind()
-        node_group = [node]
-        inputs = list()
-        outputs = list()
-        node_queue = queue.Queue()
-        node_queue.put(node)
-        while not node_queue.empty():
-            curr_node = node_queue.get()
-            for _input in curr_node.inputs():
-                input_name = _input.debugName()
-                if input_name in output_to_node and output_to_node[input_name] in nodes:
-                    predecessor_node = output_to_node[input_name]
-                    if predecessor_node.kind().startswith('prim::'):
-                        node_group.append(predecessor_node)
-                        node_queue.put(predecessor_node)
-                    else:
-                        inputs.append(input_name)
-                else:
-                    inputs.append(input_name)
-        for output in node.outputs():
-            outputs.append(output.debugName())
-        g_node = GNode(node_name, 'func', op_type, inputs, outputs, node_group)
-        return g_node
-
-    def _extract_shape_info(self, node):
-        """
-        Extract the shape information of ```aten::view``` node
-
-        Parameters
-        ----------
-        node : trace graph node
-            It should be ```aten::view``` node
-
-        Returns
-        -------
-        dict
-            Include shape of input tensor and shape of output tensor
-        """
-        t_input = None
-        for _input in node.inputs():
-            t_input = _input
-            break
-        t_output = node.output()
-        assert isinstance(t_input.type(), torch._C.TensorType)
-        assert isinstance(t_output.type(), torch._C.TensorType)
-        in_shape = t_input.type().sizes()
-        out_shape = t_output.type().sizes()
-        return {'in_shape': in_shape, 'out_shape': out_shape}
-
-    def _extract_leaf_modules(self, graph):
-        """
-        Extract leaf modules from the given graph. Leaf module means it does not have submodules.
-        To extract leaf modules because only leaf module can be replaced. And shape inference can
-        be done in leaf module level. Other shape inference is done in lower level i.e.,
-        operation level.
-
-        Parameters
-        ----------
-        graph : jit trace graph
-            the graph generated from jit trace
-
-        Returns
-        -------
-        list
-            a list of scope name of all the leaf modules
-        """
-        class SNode:
-            def __init__(self, name):
-                self.sname = name
-                self.childs = {}
-        root = None
-        for node in graph.nodes():
-            scope_name = node.scopeName()
-            if scope_name == '':
-                continue
-            segs = scope_name.split('/')
-            if root is None:
-                root = SNode(segs[0])
-            curr = root
-            for seg in segs[1:]:
-                if not seg in curr.childs:
-                    curr.childs[seg] = SNode(seg)
-                curr = curr.childs[seg]
-        leaf_nodes = []
-        def traverse_tree(node, scope_name):
-            if scope_name == '':
-                sn = node.sname
-            else:
-                sn = scope_name + '/' + node.sname
-            if not node.childs:
-                if node.sname[-1] == ']':
-                    leaf_nodes.append(sn)
-            else:
-                for key in node.childs:
-                    traverse_tree(node.childs[key], sn)
-        traverse_tree(root, '')
-        return leaf_nodes
-
-    def _build_graph(self):
-        """
-        Build graph using our defined format from jit trace.
-        There are basically three steps: first, construct necessary information (data structures),
-        second, extract all the modules to convert to GNode, Third, extract all functions to convert
-        to GNode.
-
-        Returns
-        -------
-        dict
-            use name to index g_nodes, key: node name, value: g_node
-        dict
-            use input (its name) to index g_nodes,
-            key: input, value: list of g_nodes that take this input
-        dict
-            use output (its name) to index g_nodes,
-            key: output, value: g_node that generates this output
-        """
-        graph = self.trace_graph.graph
-        # if torch 1.4.0 is used, consider run torch._C._jit_pass_inline(graph) here
-        _logger.debug(graph)
-        # build output mapping, from output debugName to its node
-        output_to_node = dict()
-        # build input mapping, from input debugName to its node
-        input_to_node = dict()
-        # build module mapping, from module name to all nodes (as list) under this module scope
-        module_to_nodes = dict()
-        # module name to its type
-        module_to_type = dict()
-        # the mapping of function (non-module in forward) to nodes, key is scope name
-        func_to_nodes = dict()
-
-        graph_inputs = list()
-        graph_outputs = list()
-        for _input in graph.inputs():
-            graph_inputs.append(_input.debugName())
-        for output in graph.outputs():
-            graph_outputs.append(output.debugName())
-
-        leaf_modules = self._extract_leaf_modules(graph)
-        _logger.debug(leaf_modules)
-        for node in graph.nodes():
-            # populate output_to_node and input_to_node
-            for output in node.outputs():
-                output_name = output.debugName()
-                output_to_node[output_name] = node
-            for _input in node.inputs():
-                input_name = _input.debugName()
-                input_to_node[input_name] = node
-            scope_name = node.scopeName() # example: scope_name, 'MyCell/Linear[linear]'
-            # if module_name is empty, it is not a module
-            if not scope_name in leaf_modules:
-                if scope_name == '':
-                    continue
-                else:
-                    if scope_name in func_to_nodes:
-                        func_to_nodes[scope_name].append(node)
-                    else:
-                        func_to_nodes[scope_name] = [node]
-            else:
-                module_name_slices = re.findall(r'\[(.*?)\]', scope_name)
-                module_name = '.'.join(module_name_slices)
-                scope_slice = scope_name.split('/')[-1]
-                module_type = scope_slice.split('[')[0]
-                module_to_type[module_name] = module_type
-                if module_name in module_to_nodes:
-                    module_to_nodes[module_name].append(node)
-                else:
-                    module_to_nodes[module_name] = [node]
-
-        # construct GNode from module
-        for module_name, nodes in module_to_nodes.items():
-            inputs = set()
-            outputs = set()
-            for node in nodes:
-                for output in node.outputs():
-                    outputs.add(output.debugName())
-                for _input in node.inputs():
-                    inputs.add(_input.debugName())
-            m_inputs = list()
-            m_outputs = list()
-            for output in outputs:
-                # TODO: one input could be the input of multiple nodes
-                if not output in input_to_node and output in graph_outputs:
-                    m_outputs.append(output)
-                elif not input_to_node[output] in nodes:
-                    m_outputs.append(output)
-            for _input in inputs:
-                if not _input in output_to_node and _input in graph_inputs:
-                    m_inputs.append(_input)
-                elif not output_to_node[_input] in nodes:
-                    m_inputs.append(_input)
-            if module_name == '':
-                _logger.warning("module_name is empty string")
-            g_node = GNode(module_name, 'module', module_to_type[module_name], m_inputs, m_outputs, nodes)
-            self.g_nodes.append(g_node)
-
-        # each scope_name may have multiple funcs, we split them and create GNode for each of them
-        for scope_name, nodes in func_to_nodes.items():
-            # extract non prim:: nodes
-            non_prim_nodes = list()
-            for node in nodes:
-                if not node.kind().startswith('prim::'):
-                    non_prim_nodes.append(node)
-            # for each non prim node, expand it has a GNode
-            for node in non_prim_nodes:
-                g_node = self._expand_non_prim_node(node, nodes, input_to_node, output_to_node)
-                self.g_nodes.append(g_node)
-                # get shape infor for view (aten::view) func
-                if g_node.op_type == 'aten::view':
-                    g_node.auxiliary = self._extract_shape_info(node)
-
-        # build index for g_nodes
-        name_to_gnode, input_to_gnode, output_to_gnode = self._build_index_for_gnodes(self.g_nodes)
-        return name_to_gnode, input_to_gnode, output_to_gnode
-
-    def _find_predecessors(self, module_name):
-        """
-        Find predecessor GNode of the given GNode
-
-        Parameters
-        ----------
-        module_name : str
-            The name of the GNode
-
-        Returns
-        -------
-        list
-            a list of GNodes who are the given GNode's predecessor
-        """
-        predecessors = []
-        for _input in self.name_to_gnode[module_name].inputs:
-            if not _input in self.output_to_gnode:
-                _logger.debug("cannot find gnode with %s as its output", _input)
-            else:
-                g_node = self.output_to_gnode[_input]
-                predecessors.append(g_node.name)
-        return predecessors
-
-    def _find_successors(self, module_name):
-        """
-        Find successor GNodes of the given GNode
-
-        Parameters
-        ----------
-        module_name : str
-            The name of the GNode
-
-        Returns
-        -------
-        list
-            a list of GNodes who are the given GNode's successor
-        """
-        successors = []
-        for output in self.name_to_gnode[module_name].outputs:
-            assert output in self.input_to_gnode, "No gnode with input {}".format(output)
-            g_nodes = self.input_to_gnode[output]
-            for g_node in g_nodes:
-                successors.append(g_node.name)
-        return successors
+        self.torch_graph = build_module_graph(model, dummy_input)
 
     def infer_module_mask(self, module_name, mask=None, in_shape=None, out_shape=None):
         """
@@ -441,13 +70,13 @@ class ModelSpeedup:
         Parameters
         ----------
         module_name : str
-            The name of the GNode
+            The name of the node
         mask : tensor of mask or ModuleMasks
-            Mask of the weights in this GNode (i.e., module)
+            Mask of the weights in this node (i.e., module)
         in_shape : ModuleMasks
-            Input shape of this GNode
+            Input shape of this node
         out_shape : ModuleMasks
-            Output shape of this GNode
+            Output shape of this node
         """
         input_cmask = output_cmask = None
         if module_name in self.inferred_masks:
@@ -456,7 +85,7 @@ class ModelSpeedup:
             module_masks = ModuleMasks(module_name)
             self.inferred_masks[module_name] = module_masks
 
-        m_type = self.name_to_gnode[module_name].op_type
+        m_type = self.torch_graph.name_to_node[module_name].op_type
         _logger.debug("infer mask of module %s with op_type %s", module_name, m_type)
         if mask is not None:
             _logger.debug("mask is not None")
@@ -471,10 +100,10 @@ class ModelSpeedup:
                 raise RuntimeError(
                     "Has not supported infering output shape from input shape for module/function: `{}`, {}"
                     .format(m_type, module_name))
-            if m_type == 'aten::view':
+            if m_type in ['aten::view', 'aten::flatten']:
                 output_cmask = infer_from_inshape[m_type](module_masks,
                                                           in_shape,
-                                                          self.name_to_gnode[module_name].auxiliary)
+                                                          self.torch_graph.name_to_node[module_name].auxiliary)
             else:
                 output_cmask = infer_from_inshape[m_type](module_masks, in_shape)
         if out_shape is not None:
@@ -486,11 +115,11 @@ class ModelSpeedup:
             input_cmask = infer_from_outshape[m_type](module_masks, out_shape)
 
         if input_cmask:
-            predecessors = self._find_predecessors(module_name)
+            predecessors = self.torch_graph.find_predecessors(module_name)
             for _module_name in predecessors:
                 self.infer_module_mask(_module_name, out_shape=input_cmask)
         if output_cmask:
-            successors = self._find_successors(module_name)
+            successors = self.torch_graph.find_successors(module_name)
             for _module_name in successors:
                 self.infer_module_mask(_module_name, in_shape=output_cmask)
@@ -511,7 +140,7 @@ class ModelSpeedup:
         is that ```func``` should be not required to be replaced.
         """
         for module_name in self.inferred_masks:
-            g_node = self.name_to_gnode[module_name]
+            g_node = self.torch_graph.name_to_node[module_name]
            _logger.debug("replace %s, in %s type, with op_type %s",
                          module_name, g_node.type, g_node.op_type)
            if g_node.type == 'module':
@@ -526,7 +155,7 @@ class ModelSpeedup:
                 _logger.info("Warning: cannot replace (name: %s, op_type: %s) which is func type",
                              module_name, g_node.op_type)
             else:
-                raise RuntimeError("Unsupported GNode type: {}".format(g_node.type))
+                raise RuntimeError("Unsupported node type: {}".format(g_node.type))
@@ -540,8 +169,3 @@ class ModelSpeedup:
         _logger.info("replace compressed modules...")
         self.replace_compressed_modules()
         _logger.info("speedup done")
-        # resume the model mode to that before the model is speed up
-        if self.is_training:
-            self.bound_model.train()
-        else:
-            self.bound_model.eval()
\ No newline at end of file
```
src/sdk/pynni/nni/compression/speedup/torch/infer_shape.py
```diff
@@ -83,6 +83,9 @@ class CoarseMask:
                                                     cmask.mask_index[i])
         return self.mask_index
 
+    def __repr__(self):
+        return 'mask_index: {}'.format(self.mask_index)
+
 class ModuleMasks:
     """
     The masks of a module, including the masks for weights, inputs, output
@@ -128,6 +131,11 @@ class ModuleMasks:
         """
         self.output_mask = mask
 
+    def __repr__(self):
+        return 'input_mask: {}, output_mask: {}, param_masks: {}'.format(
+            self.input_mask, self.output_mask, self.param_masks)
+
 """
 Infer input and output shape of a module/function from its weight mask
 """
@@ -147,8 +155,10 @@ infer_from_inshape = {
     'aten::max_pool2d': lambda module_masks, mask: maxpool2d_inshape(module_masks, mask),
     'aten::avg_pool2d': lambda module_masks, mask: maxpool2d_inshape(module_masks, mask),
     'AvgPool2d': lambda module_masks, mask: maxpool2d_inshape(module_masks, mask),
+    'AdaptiveAvgPool2d': lambda module_masks, mask: maxpool2d_inshape(module_masks, mask),
     'aten::size': lambda module_masks, mask: size_inshape(module_masks, mask),
     'aten::view': lambda module_masks, mask, shape: view_inshape(module_masks, mask, shape),
+    'aten::flatten': lambda module_masks, mask, shape: view_inshape(module_masks, mask, shape), # support only start_dim=1
     'Linear': lambda module_masks, mask: linear_inshape(module_masks, mask),
     'BatchNorm2d': lambda module_masks, mask: batchnorm2d_inshape(module_masks, mask)
 }
```
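These dicts act as dispatch tables keyed by module/op type. A hedged, self-contained sketch of the pattern (illustrative handlers, not NNI code); note that the `view`/`flatten` handlers take the extra shape record captured at trace time:

```python
# Illustrative dispatch-table pattern behind infer_from_inshape.
infer_handlers = {
    'aten::avg_pool2d': lambda masks, in_mask: in_mask,      # mask passes through
    'aten::flatten': lambda masks, in_mask, shape: in_mask,  # also consults shapes
}

def infer(m_type, masks, in_mask, shape=None):
    # view/flatten handlers need the in/out shapes recorded by the tracer
    if m_type in ('aten::view', 'aten::flatten'):
        return infer_handlers[m_type](masks, in_mask, shape)
    return infer_handlers[m_type](masks, in_mask)

print(infer('aten::flatten', {}, 'channel-mask', {'in_shape': [64, 3, 32, 32]}))
```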
src/sdk/pynni/nni/nas/pytorch/_graph_utils.py (deleted, 100644 → 0)
```python
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

# pylint: skip-file
# This file is copied from PyTorch 1.4, with bug fixes.
# Likely to be removed in future.

import torch
from tensorboard.compat.proto.config_pb2 import RunMetadata
from tensorboard.compat.proto.graph_pb2 import GraphDef
from tensorboard.compat.proto.step_stats_pb2 import StepStats, DeviceStepStats
from tensorboard.compat.proto.versions_pb2 import VersionDef
from torch.utils.tensorboard._pytorch_graph import GraphPy, CLASSTYPE_KIND, GETATTR_KIND, NodePyIO, NodePyOP


def parse(graph, trace, args=None, omit_useless_nodes=True):
    """This method parses an optimized PyTorch model graph and produces
    a list of nodes and node stats for eventual conversion to TensorBoard
    protobuf format.

    Args:
        graph (PyTorch module): The model graph to be parsed.
        trace (PyTorch JIT TracedModule): The model trace to be parsed.
        args (tuple): input tensor[s] for the model.
        omit_useless_nodes (boolean): Whether to remove nodes from the graph.
    """
    n_inputs = len(args)

    scope = {}
    nodes_py = GraphPy()
    for node in graph.inputs():
        if omit_useless_nodes:
            if len(node.uses()) == 0:  # number of user of the node (= number of outputs/ fanout)
                continue
        if node.type().kind() != CLASSTYPE_KIND:
            nodes_py.append(NodePyIO(node, 'input'))

    attr_to_scope = dict()
    node_to_name = lambda d: str(d).split(":")[0].strip()
    for node in graph.nodes():
        if node.kind() == GETATTR_KIND:
            attr_name = node.s('name')
            node_name = node_to_name(node)
            parent = node.input().node()
            if parent.kind() == GETATTR_KIND:
                # If the parent node is not the top-level "self" node
                parent_attr_name = parent.s('name')
                parent_scope = attr_to_scope[node_to_name(parent)]
                attr_scope = parent_scope.split('/')[-1]
                attr_to_scope[node_name] = '{}/{}.{}'.format(parent_scope, attr_scope, attr_name)
            else:
                attr_to_scope[node_name] = '__module.{}'.format(attr_name)
            # We don't need classtype nodes; scope will provide this information
            if node.output().type().kind() != CLASSTYPE_KIND:
                node_py = NodePyOP(node)
                node_py.scopeName = attr_to_scope[node_name]
                nodes_py.append(node_py)
        else:
            nodes_py.append(NodePyOP(node))

    for i, node in enumerate(graph.outputs()):  # Create sink nodes for output ops
        node_py = NodePyIO(node, 'output')
        node_py.debugName = "output.{}".format(i + 1)
        node_py.inputs = [node.debugName()]
        nodes_py.append(node_py)

    def parse_traced_name(module_name):
        prefix = 'TracedModule['
        suffix = ']'
        if module_name.startswith(prefix) and module_name.endswith(suffix):
            module_name = module_name[len(prefix):-len(suffix)]
        return module_name

    alias_to_name = dict()
    base_name = parse_traced_name(trace._name)
    for name, module in trace.named_modules(prefix='__module'):
        mod_name = parse_traced_name(module._name)
        attr_name = name.split('.')[-1]
        alias_to_name[name] = '{}[{}]'.format(mod_name, attr_name)

    for node in nodes_py.nodes_op:
        module_aliases = node.scopeName.split('/')[-1].split('.')
        module_name = ''
        for i, alias in enumerate(module_aliases):
            if i == 0:
                module_name = alias
                node.scopeName = base_name
            else:
                module_name += '.' + alias
                node.scopeName += '/' + (alias_to_name[module_name]
                                         if module_name in alias_to_name else alias)

    nodes_py.populate_namespace_from_OP_to_IO()
    return nodes_py.to_proto()


def graph(model, args, verbose=False):
    """
    This method processes a PyTorch model and produces a `GraphDef` proto
    that can be logged to TensorBoard.

    Args:
        model (PyTorch module): The model to be parsed.
        args (tuple): input tensor[s] for the model.
        verbose (bool): Whether to print out verbose information while
            processing.
    """
    with torch.onnx.set_training(model, False):  # TODO: move outside of torch.onnx?
        try:
            trace = torch.jit.trace(model, args)
            graph = trace.graph
            torch._C._jit_pass_inline(graph)
        except RuntimeError as e:
            print(e)
            print('Error occurs, No graph saved')
            raise e

    if verbose:
        print(graph)
    list_of_nodes = parse(graph, trace, args)
    # We are hardcoding that this was run on CPU even though it might have actually
    # run on GPU. Note this is what is shown in TensorBoard and has no bearing
    # on actual execution.
    # TODO: See if we can extract GPU vs CPU information from the PyTorch model
    # and pass it correctly to TensorBoard.
    #
    # Definition of StepStats and DeviceStepStats can be found at
    # https://github.com/tensorflow/tensorboard/blob/master/tensorboard/plugins/graph/tf_graph_common/test/graph-test.ts
    # and
    # https://github.com/tensorflow/tensorboard/blob/master/tensorboard/compat/proto/step_stats.proto
    stepstats = RunMetadata(step_stats=StepStats(
        dev_stats=[DeviceStepStats(device="/device:CPU:0")]))
    return GraphDef(node=list_of_nodes, versions=VersionDef(producer=22)), stepstats
    # The producer version has been reverse engineered from standard
    # TensorBoard logged data.
```
src/sdk/pynni/nni/nas/pytorch/mutator.py
```diff
@@ -107,12 +107,12 @@ class Mutator(BaseMutator):
         """
         if not torch.__version__.startswith("1.4"):
             logger.warning("Graph is only tested with PyTorch 1.4. Other versions might not work.")
-        from ._graph_utils import graph
+        from nni._graph_utils import build_graph
         from google.protobuf import json_format
         # protobuf should be installed as long as tensorboard is installed
         try:
             self._connect_all = True
-            graph_def, _ = graph(self.model, inputs, verbose=False)
+            graph_def, _ = build_graph(self.model, inputs, verbose=False)
             result = json_format.MessageToDict(graph_def)
         finally:
             self._connect_all = False
```
src/sdk/pynni/nni/nas/pytorch/pdarts/mutator.py
```diff
@@ -55,7 +55,8 @@ class PdartsMutator(DartsMutator):
                     del module[index]
                 assert len(module) <= len(choices), "Failed to remove dropped choices."
 
-    def sample_final(self):
+    def export(self):
+        # Cannot rely on super().export() because P-DARTS has deleted some of the choices and has misaligned length.
         results = super().sample_final()
         for mutable in self.mutables:
             if isinstance(mutable, LayerChoice):
```
src/sdk/pynni/tests/expect/test_graph_module1.expect (new file, 0 → 100644)
node {
name: "input/input"
op: "IO Node"
attr {
key: "_output_shapes"
value {
list {
shape {
dim {
size: 1
}
dim {
size: 3
}
}
}
}
}
attr {
key: "attr"
value {
s: ""
}
}
}
node {
name: "output/output.1"
op: "IO Node"
input: "myLinear/Linear[l]/22"
attr {
key: "_output_shapes"
value {
list {
shape {
dim {
size: 1
}
dim {
size: 5
}
}
}
}
}
attr {
key: "attr"
value {
s: ""
}
}
}
node {
name: "myLinear/Linear[l]/bias/17"
op: "prim::GetAttr"
input: "myLinear/Linear[l]/weight/14"
attr {
key: "attr"
value {
s: "{ name : bias }"
}
}
}
node {
name: "myLinear/Linear[l]/weight/18"
op: "prim::GetAttr"
input: "myLinear/Linear[l]/weight/14"
attr {
key: "attr"
value {
s: "{ name : weight }"
}
}
}
node {
name: "myLinear/Linear[l]/19"
op: "aten::t"
input: "myLinear/Linear[l]/weight/18"
attr {
key: "_output_shapes"
value {
list {
shape {
dim {
size: 3
}
dim {
size: 5
}
}
}
}
}
attr {
key: "attr"
value {
s: "{}"
}
}
}
node {
name: "myLinear/Linear[l]/20"
op: "prim::Constant"
attr {
key: "attr"
value {
s: "{ value : 1}"
}
}
}
node {
name: "myLinear/Linear[l]/21"
op: "prim::Constant"
attr {
key: "attr"
value {
s: "{ value : 1}"
}
}
}
node {
name: "myLinear/Linear[l]/22"
op: "aten::addmm"
input: "myLinear/Linear[l]/bias/17"
input: "input/input"
input: "myLinear/Linear[l]/19"
input: "myLinear/Linear[l]/20"
input: "myLinear/Linear[l]/21"
attr {
key: "_output_shapes"
value {
list {
shape {
dim {
size: 1
}
dim {
size: 5
}
}
}
}
}
attr {
key: "attr"
value {
s: "{}"
}
}
}
versions {
producer: 22
}
src/sdk/pynni/tests/expect/test_graph_module2.expect (new file, 0 → 100644)
node {
name: "input/input.1"
op: "IO Node"
attr {
key: "_output_shapes"
value {
list {
shape {
dim {
size: 4
}
dim {
size: 5
}
}
}
}
}
attr {
key: "attr"
value {
s: ""
}
}
}
node {
name: "output/output.1"
op: "IO Node"
input: "input/input.1"
attr {
key: "_output_shapes"
value {
list {
shape {
dim {
size: 4
}
dim {
size: 5
}
}
}
}
}
attr {
key: "attr"
value {
s: ""
}
}
}
node {
name: "MyModule/Linear[weight]/bias/49"
op: "prim::GetAttr"
input: "MyModule/Linear[weight]/weight/35"
attr {
key: "attr"
value {
s: "{ name : bias }"
}
}
}
node {
name: "MyModule/Linear[weight]/weight/50"
op: "prim::GetAttr"
input: "MyModule/Linear[weight]/weight/35"
attr {
key: "attr"
value {
s: "{ name : weight }"
}
}
}
node {
name: "MyModule/Linear[weight]/51"
op: "aten::t"
input: "MyModule/Linear[weight]/weight/50"
attr {
key: "_output_shapes"
value {
list {
shape {
dim {
size: 5
}
dim {
size: 3
}
}
}
}
}
attr {
key: "attr"
value {
s: "{}"
}
}
}
node {
name: "MyModule/Linear[weight]/52"
op: "prim::Constant"
attr {
key: "attr"
value {
s: "{ value : 1}"
}
}
}
node {
name: "MyModule/Linear[weight]/53"
op: "prim::Constant"
attr {
key: "attr"
value {
s: "{ value : 1}"
}
}
}
node {
name: "MyModule/Linear[weight]/54"
op: "aten::addmm"
input: "MyModule/Linear[weight]/bias/49"
input: "input/input.1"
input: "MyModule/Linear[weight]/51"
input: "MyModule/Linear[weight]/52"
input: "MyModule/Linear[weight]/53"
attr {
key: "_output_shapes"
value {
list {
shape {
dim {
size: 4
}
dim {
size: 3
}
}
}
}
}
attr {
key: "attr"
value {
s: "{}"
}
}
}
node {
name: "MyModule/Linear[bias]/bias/55"
op: "prim::GetAttr"
input: "MyModule/Linear[bias]/weight/38"
attr {
key: "attr"
value {
s: "{ name : bias }"
}
}
}
node {
name: "MyModule/Linear[bias]/weight/56"
op: "prim::GetAttr"
input: "MyModule/Linear[bias]/weight/38"
attr {
key: "attr"
value {
s: "{ name : weight }"
}
}
}
node {
name: "MyModule/Linear[bias]/57"
op: "aten::t"
input: "MyModule/Linear[bias]/weight/56"
attr {
key: "_output_shapes"
value {
list {
shape {
dim {
size: 5
}
dim {
size: 3
}
}
}
}
}
attr {
key: "attr"
value {
s: "{}"
}
}
}
node {
name: "MyModule/Linear[bias]/58"
op: "prim::Constant"
attr {
key: "attr"
value {
s: "{ value : 1}"
}
}
}
node {
name: "MyModule/Linear[bias]/59"
op: "prim::Constant"
attr {
key: "attr"
value {
s: "{ value : 1}"
}
}
}
node {
name: "MyModule/Linear[bias]/60"
op: "aten::addmm"
input: "MyModule/Linear[bias]/bias/55"
input: "input/input.1"
input: "MyModule/Linear[bias]/57"
input: "MyModule/Linear[bias]/58"
input: "MyModule/Linear[bias]/59"
attr {
key: "_output_shapes"
value {
list {
shape {
dim {
size: 4
}
dim {
size: 3
}
}
}
}
}
attr {
key: "attr"
value {
s: "{}"
}
}
}
node {
name: "MyModule/23"
op: "prim::ListConstruct"
input: "MyModule/Linear[weight]/54"
input: "MyModule/Linear[bias]/60"
attr {
key: "attr"
value {
s: "{}"
}
}
}
node {
name: "MyModule/24"
op: "prim::Constant"
attr {
key: "attr"
value {
s: "{ value : 1}"
}
}
}
node {
name: "MyModule/input"
op: "aten::cat"
input: "MyModule/23"
input: "MyModule/24"
attr {
key: "_output_shapes"
value {
list {
shape {
dim {
size: 4
}
dim {
size: 6
}
}
}
}
}
attr {
key: "attr"
value {
s: "{}"
}
}
}
node {
name: "MyModule/61"
op: "prim::Constant"
attr {
key: "attr"
value {
s: "{}"
}
}
}
versions {
producer: 22
}
src/sdk/pynni/tests/expect/test_graph_module3.expect
0 → 100644
node {
name: "input/input.1"
op: "IO Node"
attr {
key: "_output_shapes"
value {
list {
shape {
dim {
size: 4
}
dim {
size: 5
}
}
}
}
}
attr {
key: "attr"
value {
s: ""
}
}
}
node {
name: "output/output.1"
op: "IO Node"
input: "MyModule/ModuleList[module]/Linear[1]/46"
attr {
key: "_output_shapes"
value {
list {
shape {
dim {
size: 4
}
dim {
size: 1
}
}
}
}
}
attr {
key: "attr"
value {
s: ""
}
}
}
node {
name: "MyModule/ModuleList[module]/Linear[0]/bias/35"
op: "prim::GetAttr"
input: "MyModule/ModuleList[module]/Linear[0]/weight/26"
attr {
key: "attr"
value {
s: "{ name : bias }"
}
}
}
node {
name: "MyModule/ModuleList[module]/Linear[0]/weight/36"
op: "prim::GetAttr"
input: "MyModule/ModuleList[module]/Linear[0]/weight/26"
attr {
key: "attr"
value {
s: "{ name : weight }"
}
}
}
node {
name: "MyModule/ModuleList[module]/Linear[0]/37"
op: "aten::t"
input: "MyModule/ModuleList[module]/Linear[0]/weight/36"
attr {
key: "_output_shapes"
value {
list {
shape {
dim {
size: 5
}
dim {
size: 3
}
}
}
}
}
attr {
key: "attr"
value {
s: "{}"
}
}
}
node {
name: "MyModule/ModuleList[module]/Linear[0]/38"
op: "prim::Constant"
attr {
key: "attr"
value {
s: "{ value : 1}"
}
}
}
node {
name: "MyModule/ModuleList[module]/Linear[0]/39"
op: "prim::Constant"
attr {
key: "attr"
value {
s: "{ value : 1}"
}
}
}
node {
name: "MyModule/ModuleList[module]/Linear[0]/input"
op: "aten::addmm"
input: "MyModule/ModuleList[module]/Linear[0]/bias/35"
input: "input/input.1"
input: "MyModule/ModuleList[module]/Linear[0]/37"
input: "MyModule/ModuleList[module]/Linear[0]/38"
input: "MyModule/ModuleList[module]/Linear[0]/39"
attr {
key: "_output_shapes"
value {
list {
shape {
dim {
size: 4
}
dim {
size: 3
}
}
}
}
}
attr {
key: "attr"
value {
s: "{}"
}
}
}
node {
name: "MyModule/ModuleList[module]/Linear[1]/bias/41"
op: "prim::GetAttr"
input: "MyModule/ModuleList[module]/Linear[1]/weight/30"
attr {
key: "attr"
value {
s: "{ name : bias }"
}
}
}
node {
name: "MyModule/ModuleList[module]/Linear[1]/weight/42"
op: "prim::GetAttr"
input: "MyModule/ModuleList[module]/Linear[1]/weight/30"
attr {
key: "attr"
value {
s: "{ name : weight }"
}
}
}
node {
name: "MyModule/ModuleList[module]/Linear[1]/43"
op: "aten::t"
input: "MyModule/ModuleList[module]/Linear[1]/weight/42"
attr {
key: "_output_shapes"
value {
list {
shape {
dim {
size: 3
}
dim {
size: 1
}
}
}
}
}
attr {
key: "attr"
value {
s: "{}"
}
}
}
node {
name: "MyModule/ModuleList[module]/Linear[1]/44"
op: "prim::Constant"
attr {
key: "attr"
value {
s: "{ value : 1}"
}
}
}
node {
name: "MyModule/ModuleList[module]/Linear[1]/45"
op: "prim::Constant"
attr {
key: "attr"
value {
s: "{ value : 1}"
}
}
}
node {
name: "MyModule/ModuleList[module]/Linear[1]/46"
op: "aten::addmm"
input: "MyModule/ModuleList[module]/Linear[1]/bias/41"
input: "MyModule/ModuleList[module]/Linear[0]/input"
input: "MyModule/ModuleList[module]/Linear[1]/43"
input: "MyModule/ModuleList[module]/Linear[1]/44"
input: "MyModule/ModuleList[module]/Linear[1]/45"
attr {
key: "_output_shapes"
value {
list {
shape {
dim {
size: 4
}
dim {
size: 1
}
}
}
}
}
attr {
key: "attr"
value {
s: "{}"
}
}
}
versions {
producer: 22
}
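These `.expect` files are TensorBoard `GraphDef` protobufs serialized in text format; the test suite below parses them and compares the traced graph node by node. A minimal sketch of how one of these files can be loaded, assuming it is run in the directory containing the expect files:

# Minimal sketch: parse one of the .expect files above into a GraphDef.
# Assumes the script runs next to the expect files; uses the same
# tensorboard/protobuf imports as the test suite below.
from tensorboard.compat.proto.graph_pb2 import GraphDef
from google.protobuf import text_format

with open('test_graph_module3.expect', 'r') as f:
    graph = GraphDef()
    text_format.Parse(f.read(), graph)

# Each node records the traced op, its inputs, and inferred output shapes.
for node in graph.node:
    print(node.name, node.op, list(node.input))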
src/sdk/pynni/tests/test_graph_utils.py
0 → 100644
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import sys
import os
import math
import uuid
import shutil
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from tensorboard.compat.proto.graph_pb2 import GraphDef
from google.protobuf import text_format
import unittest
from unittest import TestCase, main

from nni._graph_utils import build_module_graph, build_graph


class BackboneModel1(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 1, 1, 1)

    def forward(self, x):
        return self.conv1(x)


class BackboneModel2(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5, 1)
        self.conv2 = nn.Conv2d(20, 50, 5, 1)
        self.bn1 = nn.BatchNorm2d(self.conv1.out_channels)
        self.bn2 = nn.BatchNorm2d(self.conv2.out_channels)
        self.fc1 = nn.Linear(4 * 4 * 50, 500)
        self.fc2 = nn.Linear(500, 10)

    def forward(self, x):
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.max_pool2d(x, 2, 2)
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x


class BigModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone1 = BackboneModel1()
        self.backbone2 = BackboneModel2()
        self.fc3 = nn.Linear(10, 2)

    def forward(self, x):
        x = self.backbone1(x)
        x = self.backbone2(x)
        x = self.fc3(x)
        return x


class GraphUtilsTestCase(TestCase):
    def test_build_module_graph(self):
        big_model = BigModel()
        g = build_module_graph(big_model, torch.randn(2, 1, 28, 28))
        print(g.name_to_node.keys())
        leaf_modules = set([
            'backbone1.conv1', 'backbone2.bn1', 'backbone2.bn2', 'backbone2.conv1',
            'backbone2.conv2', 'backbone2.fc1', 'backbone2.fc2', 'fc3'
        ])
        assert set(g.leaf_modules) == leaf_modules
        assert not leaf_modules - set(g.name_to_node.keys())
        assert g.find_successors('backbone2.conv1') == ['backbone2.bn1']
        assert g.find_successors('backbone2.conv2') == ['backbone2.bn2']
        assert g.find_predecessors('backbone2.bn1') == ['backbone2.conv1']
        assert g.find_predecessors('backbone2.bn2') == ['backbone2.conv2']

    def _test_graph(self, model, dummy_input, expected_file):
        actual_proto, _ = build_graph(model, dummy_input)
        assert os.path.exists(expected_file), expected_file
        with open(expected_file, "r") as f:
            expected_str = f.read()

        expected_proto = GraphDef()
        text_format.Parse(expected_str, expected_proto)

        self.assertEquals(len(expected_proto.node), len(actual_proto.node))
        for i in range(len(expected_proto.node)):
            expected_node = expected_proto.node[i]
            actual_node = actual_proto.node[i]
            self.assertEquals(expected_node.name, actual_node.name)
            self.assertEquals(expected_node.op, actual_node.op)
            self.assertEquals(expected_node.input, actual_node.input)
            self.assertEquals(expected_node.device, actual_node.device)
            self.assertEquals(
                sorted(expected_node.attr.keys()),
                sorted(actual_node.attr.keys()))

    @unittest.skipIf(torch.__version__ < "1.4.0", "not supported")
    def test_graph_module1(self):
        dummy_input = (torch.zeros(1, 3),)

        class myLinear(torch.nn.Module):
            def __init__(self):
                super(myLinear, self).__init__()
                self.l = torch.nn.Linear(3, 5)

            def forward(self, x):
                return self.l(x)

        self._test_graph(
            myLinear(), dummy_input,
            os.path.join(os.path.dirname(__file__), "expect", "test_graph_module1.expect")
        )

    @unittest.skipIf(torch.__version__ < "1.4.0", "not supported")
    def test_graph_module2(self):
        class MyModule(nn.Module):
            def __init__(self):
                super().__init__()
                self.weight = nn.Linear(5, 3)
                self.bias = nn.Linear(5, 3)
                self.module = nn.Linear(6, 1)

            def forward(self, x):
                tensors = [self.weight(x), self.bias(x)]
                self.module(torch.cat(tensors, dim=1))
                return x

        self._test_graph(
            MyModule(), torch.randn(4, 5),
            os.path.join(os.path.dirname(__file__), "expect", "test_graph_module2.expect")
        )

    @unittest.skipIf(torch.__version__ < "1.4.0", "not supported")
    def test_graph_module3(self):
        class MyModule(nn.Module):
            def __init__(self):
                super().__init__()
                self.module = nn.ModuleList([
                    nn.Linear(5, 3),
                    nn.Linear(3, 1)
                ])

            def forward(self, x):
                x = self.module[0](x)
                x = self.module[1](x)
                return x

        self._test_graph(
            MyModule(), torch.randn(4, 5),
            os.path.join(os.path.dirname(__file__), "expect", "test_graph_module3.expect")
        )


if __name__ == '__main__':
    main()
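Beyond the assertions above, the same API can be used directly to inspect a model's topology. A usage sketch, assuming only the `build_module_graph` interface exercised by the test (the `TinyNet` model here is illustrative, not part of the test suite):

# Usage sketch for nni._graph_utils.build_module_graph.
# TinyNet is a hypothetical example model, not taken from the tests.
import torch
import torch.nn as nn
from nni._graph_utils import build_module_graph

class TinyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, 3)
        self.bn = nn.BatchNorm2d(8)

    def forward(self, x):
        return self.bn(self.conv(x))

g = build_module_graph(TinyNet(), torch.randn(1, 3, 8, 8))
print(sorted(g.leaf_modules))      # expected: ['bn', 'conv']
print(g.find_successors('conv'))   # expected: ['bn']
print(g.find_predecessors('bn'))   # expected: ['conv']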
src/sdk/pynni/tests/test_model_speedup.py
0 → 100644
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models.vgg import vgg16
from torchvision.models.resnet import resnet18
from unittest import TestCase, main
from nni.compression.torch import L1FilterPruner
from nni.compression.speedup.torch import ModelSpeedup


class BackboneModel1(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 1, 1, 1)

    def forward(self, x):
        return self.conv1(x)


class BackboneModel2(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5, 1)
        self.conv2 = nn.Conv2d(20, 50, 5, 1)
        self.bn1 = nn.BatchNorm2d(self.conv1.out_channels)
        self.bn2 = nn.BatchNorm2d(self.conv2.out_channels)
        self.fc1 = nn.Linear(4 * 4 * 50, 500)
        self.fc2 = nn.Linear(500, 10)

    def forward(self, x):
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.max_pool2d(x, 2, 2)
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x


class BigModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone1 = BackboneModel1()
        self.backbone2 = BackboneModel2()
        self.fc3 = nn.Sequential(
            nn.Linear(10, 10),
            nn.BatchNorm1d(10),
            nn.ReLU(inplace=True),
            nn.Linear(10, 2)
        )

    def forward(self, x):
        x = self.backbone1(x)
        x = self.backbone2(x)
        x = self.fc3(x)
        return x


SPARSITY = 0.5

def prune_model_l1(model):
    config_list = [{
        'sparsity': SPARSITY,
        'op_types': ['Conv2d']
    }]
    pruner = L1FilterPruner(model, config_list)
    pruner.compress()
    pruner.export_model(model_path='./11_model.pth', mask_path='./l1_mask.pth')


class SpeedupTestCase(TestCase):
    def test_speedup_vgg16(self):
        prune_model_l1(vgg16())
        model = vgg16()
        model.train()
        ms = ModelSpeedup(model, torch.randn(2, 3, 32, 32), './l1_mask.pth')
        ms.speedup_model()

        orig_model = vgg16()
        assert model.training
        assert model.features[2].out_channels == int(orig_model.features[2].out_channels * SPARSITY)
        assert model.classifier[0].in_features == int(orig_model.classifier[0].in_features * SPARSITY)

    #def test_speedup_resnet(self):
        #TODO support resnet
        #model = resnet18()

    def test_speedup_bigmodel(self):
        prune_model_l1(BigModel())
        model = BigModel()
        model.train()
        ms = ModelSpeedup(model, torch.randn(2, 1, 28, 28), './l1_mask.pth')
        ms.speedup_model()

        orig_model = BigModel()
        assert model.training
        assert model.backbone2.conv1.out_channels == int(orig_model.backbone2.conv1.out_channels * SPARSITY)
        assert model.backbone2.conv2.in_channels == int(orig_model.backbone2.conv2.in_channels * SPARSITY)
        assert model.backbone2.conv2.out_channels == int(orig_model.backbone2.conv2.out_channels * SPARSITY)
        assert model.backbone2.fc1.in_features == int(orig_model.backbone2.fc1.in_features * SPARSITY)

    def tearDown(self):
        os.remove('./11_model.pth')
        os.remove('./l1_mask.pth')


if __name__ == '__main__':
    main()
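The test above also documents the intended end-to-end workflow: prune, export the masks, then speed up a fresh copy of the model. As a standalone sketch using the same APIs as the test (the file paths are arbitrary placeholders):

# End-to-end sketch: prune with L1FilterPruner, export masks, then speed up.
# Uses the same APIs exercised by the test above; paths are placeholders.
import torch
from torchvision.models.vgg import vgg16
from nni.compression.torch import L1FilterPruner
from nni.compression.speedup.torch import ModelSpeedup

model = vgg16()
pruner = L1FilterPruner(model, [{'sparsity': 0.5, 'op_types': ['Conv2d']}])
pruner.compress()
pruner.export_model(model_path='./pruned_vgg16.pth', mask_path='./mask_vgg16.pth')

# Masks only simulate pruning; ModelSpeedup replaces the masked layers of a
# fresh model with genuinely smaller ones, based on the exported mask file.
model_to_speedup = vgg16()
ms = ModelSpeedup(model_to_speedup, torch.randn(2, 3, 32, 32), './mask_vgg16.pth')
ms.speedup_model()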
src/webui/src/App.tsx
@@ -49,8 +49,8 @@ class App extends React.Component<{}, AppState> {

     getFinalDataFormat = (): void => {
-        for (let i = 0; this.state.isillegalFinal === false; i++){
-            if (TRIALS.succeededTrials()[0] !== undefined && TRIALS.succeededTrials()[0].final !== undefined){
+        for (let i = 0; this.state.isillegalFinal === false; i++) {
+            if (TRIALS.succeededTrials()[0] !== undefined && TRIALS.succeededTrials()[0].final !== undefined) {
                 // eslint-disable-next-line @typescript-eslint/no-non-null-assertion
                 const oneSucceedTrial = JSON.parse(JSON.parse(TRIALS.succeededTrials()[0].final!.data));
                 if (typeof oneSucceedTrial === 'number' || oneSucceedTrial.hasOwnProperty('default')) {
@@ -78,7 +78,7 @@ class App extends React.Component<{}, AppState> {
         }
         // setState will trigger page refresh at once.
         // setState is asyc, interval not update to (this.state.interval) at once.
         this.setState({ interval }, () => {
             this.firstLoad = true;
             this.refresh();
         });
@@ -96,7 +96,7 @@ class App extends React.Component<{}, AppState> {
     // overview best trial module
     changeEntries = (entries: string): void => {
         this.setState({ bestTrialEntries: entries });
     }

     render(): React.ReactNode {
@@ -106,6 +106,16 @@ class App extends React.Component<{}, AppState> {
         if (experimentUpdateBroadcast === 0 || trialsUpdateBroadcast === 0) {
             return null; // TODO: render a loading page
         }
+        const errorList = [
+            { errorWhere: TRIALS.jobListError(), errorMessage: TRIALS.getJobErrorMessage() },
+            { errorWhere: EXPERIMENT.experimentError(), errorMessage: EXPERIMENT.getExperimentMessage() },
+            { errorWhere: EXPERIMENT.statusError(), errorMessage: EXPERIMENT.getStatusMessage() },
+            { errorWhere: TRIALS.MetricDataError(), errorMessage: TRIALS.getMetricDataErrorMessage() },
+            { errorWhere: TRIALS.latestMetricDataError(), errorMessage: TRIALS.getLatestMetricDataErrorMessage() },
+            { errorWhere: TRIALS.metricDataRangeError(), errorMessage: TRIALS.metricDataRangeErrorMessage() }
+        ];
         const reactPropsChildren = React.Children.map(this.props.children, child =>
             React.cloneElement(child as React.ReactElement<any>, {
@@ -127,6 +137,16 @@ class App extends React.Component<{}, AppState> {
             </div>
             <Stack className="contentBox">
                 <Stack className="content">
+                    {/* if api has error field, show error message */}
+                    {errorList.map((item, key) => {
+                        return (
+                            item.errorWhere && <div key={key} className="warning">
+                                <MessageInfo info={item.errorMessage} typeInfo="error" />
+                            </div>
+                        );
+                    })}
                     {isillegalFinal && <div className="warning">
                         <MessageInfo info={expWarningMessage} typeInfo="warning" />
                     </div>}
@@ -149,11 +169,13 @@ class App extends React.Component<{}, AppState> {
             if (trialsUpdated) {
                 this.setState(state => ({ trialsUpdateBroadcast: state.trialsUpdateBroadcast + 1 }));
             }
         } else {
             this.firstLoad = false;
         }
-        if (['DONE', 'ERROR', 'STOPPED'].includes(EXPERIMENT.status)) {
+        // experiment status and /trial-jobs api's status could decide website update
+        if (['DONE', 'ERROR', 'STOPPED'].includes(EXPERIMENT.status) || TRIALS.jobListError()) {
             // experiment finished, refresh once more to ensure consistency
             this.setState({ interval: 0 });
             this.lastRefresh();
...
src/webui/src/components/Modals/MessageInfo.tsx
@@ -18,7 +18,7 @@ class MessageInfo extends React.Component<MessageInfoProps, {}> {
         return (
             <MessageBar
                 messageBarType={MessageBarType[typeInfo]}
-                isMultiline={false}
+                isMultiline={true}
                 className={className}
             >
                 {info}
...
src/webui/src/components/overview/ProgressItem.tsx
@@ -22,7 +22,7 @@ class ProgressBar extends React.Component<ProItemProps, {}> {
             <div>
                 <Stack horizontal className={`probar ${bgclass}`}>
                     <div className="name">{who}</div>
-                    <div className="showProgress" style={{ width: '80%' }}>
+                    <div className="showProgress" style={{ width: '78%' }}>
                         <ProgressIndicator
                             barHeight={30}
                             percentComplete={percent}
@@ -32,7 +32,7 @@ class ProgressBar extends React.Component<ProItemProps, {}> {
                         <StackItem className="right" grow={70}>{maxString}</StackItem>
                     </Stack>
                 </div>
-                <div className="description" style={{ width: '20%' }}>{description}</div>
+                <div className="description" style={{ width: '22%' }}>{description}</div>
             </Stack>
             <br/>
         </div>
...
src/webui/src/components/trial-detail/Para.tsx
 import * as React from 'react';
 import ReactEcharts from 'echarts-for-react';
 import { filterByStatus } from '../../static/function';
+import { EXPERIMENT } from '../../static/datamodel';
 import { Stack, PrimaryButton, Dropdown, IDropdownOption, } from 'office-ui-fabric-react'; // eslint-disable-line no-unused-vars
 import { ParaObj, Dimobj, TableObj } from '../../static/interface'; // eslint-disable-line no-unused-vars
 import 'echarts/lib/chart/parallel';
@@ -98,8 +99,14 @@ class Para extends React.Component<ParaProps, ParaState> {
         // according acc to sort ydata // sort to find top percent dataset
         if (paraYdata.length !== 0) {
             const len = paraYdata[0].length - 1;
-            paraYdata.sort((a, b) => b[len] - a[len]);
+            // show top trials
+            if (EXPERIMENT.optimizeMode === 'minimize') {
+                paraYdata.sort((a, b) => a[len] - b[len]);
+            }
+            if (EXPERIMENT.optimizeMode === 'maximize') {
+                paraYdata.sort((a, b) => b[len] - a[len]);
+            }
         }
         const paraData = {
             parallelAxis: parallelAxis,
             data: paraYdata
...
src/webui/src/static/function.ts
@@ -3,6 +3,19 @@ import axios from 'axios';
 import { MANAGER_IP } from './const';
 import { MetricDataRecord, FinalType, TableObj } from './interface';

+async function requestAxios(url: string) {
+    const response = await axios.get(url);
+    if (response.status === 200) {
+        if (response.data.error !== undefined) {
+            throw new Error(`API ${url} ${response.data.error}`);
+        } else {
+            return response.data as any;
+        }
+    } else {
+        throw new Error(`API ${url} ${response.status} error`);
+    }
+}
+
 const convertTime = (num: number): string => {
     if (num <= 0) {
         return '0';
@@ -219,5 +232,5 @@ export {
     convertTime, convertDuration, getFinalResult, getFinal, downFile,
     intermediateGraphOption, killJob, filterByStatus, filterDuration,
     formatAccuracy, formatTimestamp, metricAccuracy, parseMetrics,
-    isArrayType
+    isArrayType, requestAxios
 };