Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
bf8be1e7
"vscode:/vscode.git/clone" did not exist on "c6c361d80ada8117e926bd24f71f50bb5da9f0b3"
Unverified
Commit
bf8be1e7
authored
Aug 28, 2020
by
Yuge Zhang
Committed by
GitHub
Aug 28, 2020
Browse files
Merge pull request #2837 from microsoft/v1.8
Merge v1.8 back to master
parents
320407b1
e06a9dda
Changes
44
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
305 additions
and
127 deletions
+305
-127
src/nni_manager/training_service/local/localTrainingService.ts
...ni_manager/training_service/local/localTrainingService.ts
+3
-1
src/nni_manager/training_service/remote_machine/remoteMachineTrainingService.ts
...ng_service/remote_machine/remoteMachineTrainingService.ts
+25
-23
src/nni_manager/training_service/reusable/aml/amlClient.ts
src/nni_manager/training_service/reusable/aml/amlClient.ts
+21
-15
src/nni_manager/training_service/reusable/test/amlClient.test.ts
..._manager/training_service/reusable/test/amlClient.test.ts
+29
-0
src/sdk/pynni/nni/compression/tensorflow/compressor.py
src/sdk/pynni/nni/compression/tensorflow/compressor.py
+38
-31
src/sdk/pynni/nni/compression/tensorflow/pruning/one_shot.py
src/sdk/pynni/nni/compression/tensorflow/pruning/one_shot.py
+10
-6
src/sdk/pynni/nni/compression/torch/pruning/sensitivity_pruner.py
...pynni/nni/compression/torch/pruning/sensitivity_pruner.py
+18
-2
src/sdk/pynni/nni/compression/torch/speedup/infer_shape.py
src/sdk/pynni/nni/compression/torch/speedup/infer_shape.py
+5
-1
src/sdk/pynni/nni/compression/torch/utils/mask_conflict.py
src/sdk/pynni/nni/compression/torch/utils/mask_conflict.py
+4
-1
src/sdk/pynni/nni/compression/torch/utils/sensitivity_analysis.py
...pynni/nni/compression/torch/utils/sensitivity_analysis.py
+7
-5
src/sdk/pynni/nni/nas/benchmarks/nasbench101/query.py
src/sdk/pynni/nni/nas/benchmarks/nasbench101/query.py
+1
-1
src/sdk/pynni/nni/nas/benchmarks/nasbench201/query.py
src/sdk/pynni/nni/nas/benchmarks/nasbench201/query.py
+1
-1
src/sdk/pynni/nni/nas/pytorch/mutator.py
src/sdk/pynni/nni/nas/pytorch/mutator.py
+2
-2
src/sdk/pynni/nni/nas/pytorch/search_space_zoo/enas_cell.py
src/sdk/pynni/nni/nas/pytorch/search_space_zoo/enas_cell.py
+2
-3
src/sdk/pynni/nni/package_utils.py
src/sdk/pynni/nni/package_utils.py
+14
-26
src/sdk/pynni/tests/test_compressor_tf.py
src/sdk/pynni/tests/test_compressor_tf.py
+97
-0
src/sdk/pynni/tests/test_model_speedup.py
src/sdk/pynni/tests/test_model_speedup.py
+1
-1
src/webui/src/components/Modals/Compare.tsx
src/webui/src/components/Modals/Compare.tsx
+18
-3
src/webui/src/components/public-child/PaiTrialChild.tsx
src/webui/src/components/public-child/PaiTrialChild.tsx
+7
-3
src/webui/src/components/public-child/PaiTrialLog.tsx
src/webui/src/components/public-child/PaiTrialLog.tsx
+2
-2
No files found.
src/nni_manager/training_service/local/localTrainingService.ts
View file @
bf8be1e7
...
@@ -491,7 +491,7 @@ class LocalTrainingService implements TrainingService {
...
@@ -491,7 +491,7 @@ class LocalTrainingService implements TrainingService {
if
(
process
.
platform
===
'
win32
'
)
{
if
(
process
.
platform
===
'
win32
'
)
{
script
.
push
(
`cd $env:NNI_CODE_DIR`
);
script
.
push
(
`cd $env:NNI_CODE_DIR`
);
script
.
push
(
script
.
push
(
`cmd.exe /c
${
localTrialConfig
.
command
}
2>"
${
path
.
join
(
workingDirectory
,
'
stderr
'
)}
"`
,
`cmd.exe /c
${
localTrialConfig
.
command
}
2>
&1 | Out-File
"
${
path
.
join
(
workingDirectory
,
'
stderr
'
)}
"
-encoding utf8
`
,
`$NOW_DATE = [int64](([datetime]::UtcNow)-(get-date "1/1/1970")).TotalSeconds`
,
`$NOW_DATE = [int64](([datetime]::UtcNow)-(get-date "1/1/1970")).TotalSeconds`
,
`$NOW_DATE = "$NOW_DATE" + (Get-Date -Format fff).ToString()`
,
`$NOW_DATE = "$NOW_DATE" + (Get-Date -Format fff).ToString()`
,
`Write $LASTEXITCODE " " $NOW_DATE | Out-File "
${
path
.
join
(
workingDirectory
,
'
.nni
'
,
'
state
'
)}
" -NoNewline -encoding utf8`
);
`Write $LASTEXITCODE " " $NOW_DATE | Out-File "
${
path
.
join
(
workingDirectory
,
'
.nni
'
,
'
state
'
)}
" -NoNewline -encoding utf8`
);
...
@@ -523,6 +523,8 @@ class LocalTrainingService implements TrainingService {
...
@@ -523,6 +523,8 @@ class LocalTrainingService implements TrainingService {
const
runScriptContent
:
string
[]
=
[];
const
runScriptContent
:
string
[]
=
[];
if
(
process
.
platform
!==
'
win32
'
)
{
if
(
process
.
platform
!==
'
win32
'
)
{
runScriptContent
.
push
(
'
#!/bin/bash
'
);
runScriptContent
.
push
(
'
#!/bin/bash
'
);
}
else
{
runScriptContent
.
push
(
`$env:PATH="
${
process
.
env
.
path
}
"`
)
}
}
for
(
const
variable
of
variables
)
{
for
(
const
variable
of
variables
)
{
runScriptContent
.
push
(
setEnvironmentVariable
(
variable
));
runScriptContent
.
push
(
setEnvironmentVariable
(
variable
));
...
...
src/nni_manager/training_service/remote_machine/remoteMachineTrainingService.ts
View file @
bf8be1e7
...
@@ -87,6 +87,21 @@ class RemoteMachineTrainingService implements TrainingService {
...
@@ -87,6 +87,21 @@ class RemoteMachineTrainingService implements TrainingService {
this
.
log
.
info
(
'
ssh connection initialized!
'
);
this
.
log
.
info
(
'
ssh connection initialized!
'
);
// set sshConnectionPromises to [] to avoid log information duplicated
// set sshConnectionPromises to [] to avoid log information duplicated
this
.
sshConnectionPromises
=
[];
this
.
sshConnectionPromises
=
[];
// initialize gpuScheduler
this
.
gpuScheduler
=
new
GPUScheduler
(
this
.
machineExecutorManagerMap
);
if
(
this
.
trialConfig
===
undefined
)
{
throw
new
Error
(
"
trial config not initialized!
"
);
}
// Copy codeDir to remote machine
for
(
const
[
rmMeta
,
executorManager
]
of
this
.
machineExecutorManagerMap
.
entries
())
{
const
executor
:
ShellExecutor
=
await
executorManager
.
getExecutor
(
this
.
initExecutorId
);
if
(
executor
!==
undefined
)
{
this
.
machineCopyExpCodeDirPromiseMap
.
set
(
rmMeta
,
executor
.
copyDirectoryToRemote
(
this
.
trialConfig
.
codeDir
,
executor
.
getRemoteCodePath
(
getExperimentId
()))
);
}
}
}
}
while
(
!
this
.
stopping
)
{
while
(
!
this
.
stopping
)
{
while
(
this
.
jobQueue
.
length
>
0
)
{
while
(
this
.
jobQueue
.
length
>
0
)
{
...
@@ -310,7 +325,6 @@ class RemoteMachineTrainingService implements TrainingService {
...
@@ -310,7 +325,6 @@ class RemoteMachineTrainingService implements TrainingService {
break
;
break
;
case
TrialConfigMetadataKey
.
MACHINE_LIST
:
case
TrialConfigMetadataKey
.
MACHINE_LIST
:
await
this
.
setupConnections
(
value
);
await
this
.
setupConnections
(
value
);
this
.
gpuScheduler
=
new
GPUScheduler
(
this
.
machineExecutorManagerMap
);
break
;
break
;
case
TrialConfigMetadataKey
.
TRIAL_CONFIG
:
{
case
TrialConfigMetadataKey
.
TRIAL_CONFIG
:
{
const
remoteMachineTrailConfig
:
TrialConfig
=
<
TrialConfig
>
JSON
.
parse
(
value
);
const
remoteMachineTrailConfig
:
TrialConfig
=
<
TrialConfig
>
JSON
.
parse
(
value
);
...
@@ -327,20 +341,8 @@ class RemoteMachineTrainingService implements TrainingService {
...
@@ -327,20 +341,8 @@ class RemoteMachineTrainingService implements TrainingService {
try
{
try
{
// Validate to make sure codeDir doesn't have too many files
// Validate to make sure codeDir doesn't have too many files
await
validateCodeDir
(
remoteMachineTrailConfig
.
codeDir
);
await
validateCodeDir
(
remoteMachineTrailConfig
.
codeDir
);
// Copy codeDir to remote machine
for
(
const
[
rmMeta
,
executorManager
]
of
this
.
machineExecutorManagerMap
.
entries
())
{
const
executor
:
ShellExecutor
=
await
executorManager
.
getExecutor
(
this
.
initExecutorId
);
if
(
executor
!==
undefined
)
{
this
.
machineCopyExpCodeDirPromiseMap
.
set
(
rmMeta
,
executor
.
copyDirectoryToRemote
(
remoteMachineTrailConfig
.
codeDir
,
executor
.
getRemoteCodePath
(
getExperimentId
()))
);
}
}
}
catch
(
error
)
{
}
catch
(
error
)
{
this
.
log
.
error
(
error
);
this
.
log
.
error
(
error
);
return
Promise
.
reject
(
new
Error
(
error
));
return
Promise
.
reject
(
new
Error
(
error
));
}
}
...
@@ -426,6 +428,11 @@ class RemoteMachineTrainingService implements TrainingService {
...
@@ -426,6 +428,11 @@ class RemoteMachineTrainingService implements TrainingService {
const
rmMetaList
:
RemoteMachineMeta
[]
=
<
RemoteMachineMeta
[]
>
JSON
.
parse
(
machineList
);
const
rmMetaList
:
RemoteMachineMeta
[]
=
<
RemoteMachineMeta
[]
>
JSON
.
parse
(
machineList
);
for
(
const
rmMeta
of
rmMetaList
)
{
for
(
const
rmMeta
of
rmMetaList
)
{
this
.
sshConnectionPromises
.
push
(
this
.
initRemoteMachineOnConnected
(
rmMeta
));
}
}
private
async
initRemoteMachineOnConnected
(
rmMeta
:
RemoteMachineMeta
):
Promise
<
void
>
{
rmMeta
.
occupiedGpuIndexMap
=
new
Map
<
number
,
number
>
();
rmMeta
.
occupiedGpuIndexMap
=
new
Map
<
number
,
number
>
();
const
executorManager
:
ExecutorManager
=
new
ExecutorManager
(
rmMeta
);
const
executorManager
:
ExecutorManager
=
new
ExecutorManager
(
rmMeta
);
this
.
log
.
info
(
`connecting to
${
rmMeta
.
username
}
@
${
rmMeta
.
ip
}
:
${
rmMeta
.
port
}
`
);
this
.
log
.
info
(
`connecting to
${
rmMeta
.
username
}
@
${
rmMeta
.
ip
}
:
${
rmMeta
.
port
}
`
);
...
@@ -433,12 +440,7 @@ class RemoteMachineTrainingService implements TrainingService {
...
@@ -433,12 +440,7 @@ class RemoteMachineTrainingService implements TrainingService {
this
.
log
.
debug
(
`reached
${
executor
.
name
}
`
);
this
.
log
.
debug
(
`reached
${
executor
.
name
}
`
);
this
.
machineExecutorManagerMap
.
set
(
rmMeta
,
executorManager
);
this
.
machineExecutorManagerMap
.
set
(
rmMeta
,
executorManager
);
this
.
log
.
debug
(
`initializing
${
executor
.
name
}
`
);
this
.
log
.
debug
(
`initializing
${
executor
.
name
}
`
);
this
.
sshConnectionPromises
.
push
(
this
.
initRemoteMachineOnConnected
(
rmMeta
,
executor
));
this
.
log
.
info
(
`connecting to
${
executor
.
name
}
`
);
}
}
private
async
initRemoteMachineOnConnected
(
rmMeta
:
RemoteMachineMeta
,
executor
:
ShellExecutor
):
Promise
<
void
>
{
// Create root working directory after executor is ready
// Create root working directory after executor is ready
const
nniRootDir
:
string
=
executor
.
joinPath
(
executor
.
getTempPath
(),
'
nni
'
);
const
nniRootDir
:
string
=
executor
.
joinPath
(
executor
.
getTempPath
(),
'
nni
'
);
await
executor
.
createFolder
(
executor
.
getRemoteExperimentRootDir
(
getExperimentId
()));
await
executor
.
createFolder
(
executor
.
getRemoteExperimentRootDir
(
getExperimentId
()));
...
...
src/nni_manager/training_service/reusable/aml/amlClient.ts
View file @
bf8be1e7
...
@@ -74,13 +74,11 @@ export class AMLClient {
...
@@ -74,13 +74,11 @@ export class AMLClient {
throw
Error
(
'
python shell client not initialized!
'
);
throw
Error
(
'
python shell client not initialized!
'
);
}
}
this
.
pythonShellClient
.
send
(
'
tracking_url
'
);
this
.
pythonShellClient
.
send
(
'
tracking_url
'
);
let
trackingUrl
=
''
;
this
.
pythonShellClient
.
on
(
'
message
'
,
(
status
:
any
)
=>
{
this
.
pythonShellClient
.
on
(
'
message
'
,
function
(
status
:
any
)
{
const
trackingUrl
=
this
.
parseContent
(
'
tracking_url
'
,
status
);
const
items
=
status
.
split
(
'
:
'
);
if
(
trackingUrl
!==
''
)
{
if
(
items
[
0
]
===
'
tracking_url
'
)
{
trackingUrl
=
items
.
splice
(
1
,
items
.
length
).
join
(
''
)
}
deferred
.
resolve
(
trackingUrl
);
deferred
.
resolve
(
trackingUrl
);
}
});
});
this
.
monitorError
(
this
.
pythonShellClient
,
deferred
);
this
.
monitorError
(
this
.
pythonShellClient
,
deferred
);
return
deferred
.
promise
;
return
deferred
.
promise
;
...
@@ -91,12 +89,11 @@ export class AMLClient {
...
@@ -91,12 +89,11 @@ export class AMLClient {
if
(
this
.
pythonShellClient
===
undefined
)
{
if
(
this
.
pythonShellClient
===
undefined
)
{
throw
Error
(
'
python shell client not initialized!
'
);
throw
Error
(
'
python shell client not initialized!
'
);
}
}
let
newStatus
=
oldStatus
;
this
.
pythonShellClient
.
send
(
'
update_status
'
);
this
.
pythonShellClient
.
send
(
'
update_status
'
);
this
.
pythonShellClient
.
on
(
'
message
'
,
function
(
status
:
any
)
{
this
.
pythonShellClient
.
on
(
'
message
'
,
(
status
:
any
)
=>
{
const
items
=
status
.
split
(
'
:
'
);
let
newStatus
=
this
.
parseContent
(
'
status
'
,
status
);
if
(
items
[
0
]
===
'
status
'
)
{
if
(
newStatus
===
'
'
)
{
newStatus
=
items
.
splice
(
1
,
items
.
length
).
join
(
''
)
newStatus
=
oldStatus
;
}
}
deferred
.
resolve
(
newStatus
);
deferred
.
resolve
(
newStatus
);
});
});
...
@@ -117,10 +114,10 @@ export class AMLClient {
...
@@ -117,10 +114,10 @@ export class AMLClient {
throw
Error
(
'
python shell client not initialized!
'
);
throw
Error
(
'
python shell client not initialized!
'
);
}
}
this
.
pythonShellClient
.
send
(
'
receive
'
);
this
.
pythonShellClient
.
send
(
'
receive
'
);
this
.
pythonShellClient
.
on
(
'
message
'
,
function
(
command
:
any
)
{
this
.
pythonShellClient
.
on
(
'
message
'
,
(
command
:
any
)
=>
{
const
items
=
command
.
split
(
'
:
'
)
const
message
=
this
.
parseContent
(
'
receive
'
,
command
);
if
(
items
[
0
]
===
'
receive
'
)
{
if
(
message
!==
'
'
)
{
deferred
.
resolve
(
JSON
.
parse
(
command
.
slice
(
8
)
))
deferred
.
resolve
(
JSON
.
parse
(
message
))
}
}
});
});
this
.
monitorError
(
this
.
pythonShellClient
,
deferred
);
this
.
monitorError
(
this
.
pythonShellClient
,
deferred
);
...
@@ -136,4 +133,13 @@ export class AMLClient {
...
@@ -136,4 +133,13 @@ export class AMLClient {
deferred
.
reject
(
error
);
deferred
.
reject
(
error
);
});
});
}
}
// Parse command content, command format is {head}:{content}
public
parseContent
(
head
:
string
,
command
:
string
):
string
{
const
items
=
command
.
split
(
'
:
'
);
if
(
items
[
0
]
===
head
)
{
return
command
.
slice
(
head
.
length
+
1
);
}
return
''
;
}
}
}
src/nni_manager/training_service/reusable/test/amlClient.test.ts
0 → 100644
View file @
bf8be1e7
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
import
*
as
chai
from
'
chai
'
;
import
{
cleanupUnitTest
,
prepareUnitTest
}
from
'
../../../common/utils
'
;
import
chaiAsPromised
=
require
(
"
chai-as-promised
"
);
import
{
AMLClient
}
from
'
../aml/amlClient
'
;
describe
(
'
Unit Test for amlClient
'
,
()
=>
{
before
(()
=>
{
chai
.
should
();
chai
.
use
(
chaiAsPromised
);
prepareUnitTest
();
});
after
(()
=>
{
cleanupUnitTest
();
});
it
(
'
test parseContent
'
,
async
()
=>
{
let
amlClient
:
AMLClient
=
new
AMLClient
(
''
,
''
,
''
,
''
,
''
,
''
,
''
,
''
);
chai
.
assert
.
equal
(
amlClient
.
parseContent
(
'
test
'
,
'
test:1234
'
),
'
1234
'
,
"
The content should be 1234
"
);
chai
.
assert
.
equal
(
amlClient
.
parseContent
(
'
test
'
,
'
abcd:1234
'
),
''
,
"
The content should be null
"
);
});
});
src/sdk/pynni/nni/compression/tensorflow/compressor.py
View file @
bf8be1e7
...
@@ -6,7 +6,10 @@ Abstract base classes for TensorFlow model compression.
...
@@ -6,7 +6,10 @@ Abstract base classes for TensorFlow model compression.
"""
"""
import
logging
import
logging
import
tensorflow
as
tf
import
tensorflow
as
tf
assert
tf
.
__version__
.
startswith
(
'2'
),
'NNI model compression only supports TensorFlow v2.x'
from
.
import
default_layers
from
.
import
default_layers
_logger
=
logging
.
getLogger
(
__name__
)
_logger
=
logging
.
getLogger
(
__name__
)
...
@@ -25,9 +28,9 @@ class LayerInfo:
...
@@ -25,9 +28,9 @@ class LayerInfo:
The layer's name. Note that it's local to sub-model and may differ from its attribute name.
The layer's name. Note that it's local to sub-model and may differ from its attribute name.
type : str
type : str
Name of the layer's class.
Name of the layer's class.
path : list of str
/
int
path : list of str
or tuple of (str,
int
)
The layer object's and its parents' attribute name / list index.
The layer object's and its parents' attribute name / list index.
For example, if the path is `['cells', 2, 'conv']`, then the layer can be accessed as `model.cells[2].conv`.
For example, if the path is `[
(
'cells', 2
)
, 'conv']`, then the layer can be accessed as `model.cells[2].conv`.
config : JSON object
config : JSON object
Selected configuration for this layer. The format is detailed in tutorial.
Selected configuration for this layer. The format is detailed in tutorial.
...
@@ -35,7 +38,7 @@ class LayerInfo:
...
@@ -35,7 +38,7 @@ class LayerInfo:
----------
----------
layer : tf.keras.layers.Layer
layer : tf.keras.layers.Layer
See attributes section.
See attributes section.
path : list of str
/
int
path : list of str
or tuple of (str,
int
)
See attributes section.
See attributes section.
"""
"""
...
@@ -75,6 +78,8 @@ class Compressor:
...
@@ -75,6 +78,8 @@ class Compressor:
def
__init__
(
self
,
LayerWrapperClass
,
model
,
config_list
):
def
__init__
(
self
,
LayerWrapperClass
,
model
,
config_list
):
assert
isinstance
(
model
,
tf
.
keras
.
Model
)
assert
isinstance
(
model
,
tf
.
keras
.
Model
)
if
isinstance
(
model
,
tf
.
keras
.
Sequential
):
raise
ValueError
(
'NNI model compression does not support `Sequential` model for now'
)
self
.
validate_config
(
model
,
config_list
)
self
.
validate_config
(
model
,
config_list
)
self
.
bound_model
=
model
self
.
bound_model
=
model
...
@@ -204,10 +209,12 @@ class PrunerLayerWrapper(tf.keras.Model):
...
@@ -204,10 +209,12 @@ class PrunerLayerWrapper(tf.keras.Model):
for
weight
in
self
.
layer
.
weights
:
for
weight
in
self
.
layer
.
weights
:
mask
=
self
.
masks
.
get
(
weight
.
name
)
mask
=
self
.
masks
.
get
(
weight
.
name
)
if
mask
is
not
None
:
if
mask
is
not
None
:
new_weights
.
append
(
tf
.
math
.
multiply
(
weight
,
mask
)
.
numpy
()
)
new_weights
.
append
(
tf
.
math
.
multiply
(
weight
,
mask
))
else
:
else
:
new_weights
.
append
(
weight
.
numpy
())
new_weights
.
append
(
weight
)
self
.
layer
.
set_weights
(
new_weights
)
if
new_weights
and
not
hasattr
(
new_weights
[
0
],
'numpy'
):
raise
RuntimeError
(
'NNI: Compressed model can only run in eager mode'
)
self
.
layer
.
set_weights
([
weight
.
numpy
()
for
weight
in
new_weights
])
return
self
.
layer
(
*
inputs
)
return
self
.
layer
(
*
inputs
)
...
@@ -244,26 +251,21 @@ def _locate_layers(model, cur_path=[]):
...
@@ -244,26 +251,21 @@ def _locate_layers(model, cur_path=[]):
# and to my knowledge `Layer.name` is only useful for read-only access.
# and to my knowledge `Layer.name` is only useful for read-only access.
# `cur_path`s format is documented in `LayerInfo.path`.
# `cur_path`s format is documented in `LayerInfo.path`.
# TODO: it can only find layers in `Model` and `list` for now.
# TODO: it can only find layers in `Model` and `list` for now.
assert
isinstance
(
model
,
tf
.
keras
.
Model
)
if
isinstance
(
model
,
tf
.
keras
.
Sequential
):
_logger
.
warning
(
'`Sequential` model is not supported yet, ignored.'
)
ret
=
{}
ret
=
{}
if
isinstance
(
model
,
tf
.
keras
.
Model
):
for
key
,
value
in
model
.
__dict__
.
items
():
for
key
,
value
in
model
.
__dict__
.
items
():
if
isinstance
(
value
,
tf
.
keras
.
Model
):
if
isinstance
(
value
,
tf
.
keras
.
Model
):
ret
.
update
(
_locate_layers
(
value
,
cur_path
+
[
key
]))
ret
.
update
(
_locate_layers
(
value
,
cur_path
+
[
key
]))
elif
isinstance
(
value
,
list
):
ret
.
update
(
_locate_layers
(
value
,
cur_path
+
[
key
]))
elif
isinstance
(
value
,
tf
.
keras
.
layers
.
Layer
):
elif
isinstance
(
value
,
tf
.
keras
.
layers
.
Layer
):
ret
[
id
(
value
)]
=
LayerInfo
(
value
,
cur_path
+
[
key
])
ret
[
id
(
value
)]
=
LayerInfo
(
value
,
cur_path
+
[
key
])
elif
isinstance
(
value
,
list
):
elif
isinstance
(
model
,
list
):
for
i
,
item
in
enumerate
(
value
):
for
i
,
item
in
enumerate
(
model
):
if
isinstance
(
item
,
tf
.
keras
.
Model
):
if
isinstance
(
item
,
tf
.
keras
.
Model
):
ret
.
update
(
_locate_layers
(
item
,
cur_path
+
[
i
]))
ret
.
update
(
_locate_layers
(
item
,
cur_path
+
[
(
key
,
i
)
]))
elif
isinstance
(
item
,
tf
.
keras
.
layers
.
Layer
):
elif
isinstance
(
item
,
tf
.
keras
.
layers
.
Layer
):
ret
[
id
(
item
)]
=
LayerInfo
(
item
,
cur_path
+
[
i
])
ret
[
id
(
item
)]
=
LayerInfo
(
item
,
cur_path
+
[(
key
,
i
)])
else
:
raise
ValueError
(
'Unexpected model type: {}'
.
format
(
type
(
model
)))
return
ret
return
ret
def
_select_config
(
layer_info
,
config_list
):
def
_select_config
(
layer_info
,
config_list
):
...
@@ -289,12 +291,17 @@ def _instrument_model(model, wrappers):
...
@@ -289,12 +291,17 @@ def _instrument_model(model, wrappers):
for
wrapper
in
reversed
(
wrappers
):
for
wrapper
in
reversed
(
wrappers
):
cur
=
model
cur
=
model
for
key
in
wrapper
.
layer_info
.
path
[:
-
1
]:
for
key
in
wrapper
.
layer_info
.
path
[:
-
1
]:
if
isinstance
(
key
,
int
):
if
isinstance
(
key
,
str
):
cur
=
cur
[
key
]
else
:
cur
=
getattr
(
cur
,
key
)
cur
=
getattr
(
cur
,
key
)
key
=
wrapper
.
layer_info
.
path
[
-
1
]
if
isinstance
(
key
,
int
):
cur
[
key
]
=
wrapper
else
:
else
:
name
,
index
=
key
cur
=
getattr
(
cur
,
name
)[
index
]
key
=
wrapper
.
layer_info
.
path
[
-
1
]
if
isinstance
(
key
,
str
):
setattr
(
cur
,
key
,
wrapper
)
setattr
(
cur
,
key
,
wrapper
)
else
:
name
,
index
=
key
getattr
(
cur
,
name
)[
index
]
=
wrapper
#if isinstance(cur, tf.keras.Sequential):
# cur._graph_initialized = False
# cur._layer_call_argspecs[wrapper] = cur._layer_call_argspecs[wrapper.layer]
src/sdk/pynni/nni/compression/tensorflow/pruning/one_shot.py
View file @
bf8be1e7
...
@@ -44,20 +44,24 @@ class LevelPrunerMasker(WeightMasker):
...
@@ -44,20 +44,24 @@ class LevelPrunerMasker(WeightMasker):
def
calc_masks
(
self
,
sparsity
,
wrapper
,
wrapper_idx
=
None
):
def
calc_masks
(
self
,
sparsity
,
wrapper
,
wrapper_idx
=
None
):
masks
=
{}
masks
=
{}
for
weight_variable
in
wrapper
.
layer
.
weights
:
for
weight_variable
in
wrapper
.
layer
.
weights
:
if
weight_variable
.
name
==
'bias'
:
if
'bias'
in
weight_variable
.
name
:
continue
continue
k
=
int
(
tf
.
size
(
weight_variable
).
numpy
()
*
sparsity
)
num_prune
=
int
(
tf
.
size
(
weight_variable
).
numpy
()
*
sparsity
)
if
k
==
0
:
if
num_prune
==
0
:
continue
continue
weight
=
weight_variable
.
read_value
()
weight
=
weight_variable
.
read_value
()
if
wrapper
.
masks
.
get
(
weight_variable
.
name
)
is
not
None
:
if
wrapper
.
masks
.
get
(
weight_variable
.
name
)
is
not
None
:
weight
=
tf
.
math
.
multiply
(
weight
,
wrapper
.
masks
[
weight_variable
.
name
])
weight
=
tf
.
math
.
multiply
(
weight
,
wrapper
.
masks
[
weight_variable
.
name
])
w_abs
=
tf
.
math
.
abs
(
tf
.
reshape
(
weight
,
[
-
1
]))
w_abs
=
tf
.
math
.
abs
(
weight
)
threshold
=
tf
.
math
.
top_k
(
w_abs
,
k
)[
0
][
0
]
k
=
tf
.
size
(
weight
)
-
num_prune
mask
=
tf
.
math
.
greater
(
w_abs
,
threshold
)
topk
=
tf
.
math
.
top_k
(
tf
.
reshape
(
w_abs
,
[
-
1
]),
k
)[
0
]
if
tf
.
size
(
topk
)
==
0
:
mask
=
tf
.
zeros_like
(
weight
)
else
:
mask
=
tf
.
math
.
greater_equal
(
w_abs
,
topk
[
-
1
])
masks
[
weight_variable
.
name
]
=
tf
.
cast
(
mask
,
weight
.
dtype
)
masks
[
weight_variable
.
name
]
=
tf
.
cast
(
mask
,
weight
.
dtype
)
return
masks
return
masks
...
...
src/sdk/pynni/nni/compression/torch/pruning/sensitivity_pruner.py
View file @
bf8be1e7
...
@@ -17,7 +17,7 @@ from ..utils.sensitivity_analysis import SensitivityAnalysis
...
@@ -17,7 +17,7 @@ from ..utils.sensitivity_analysis import SensitivityAnalysis
MAX_PRUNE_RATIO_PER_ITER
=
0.95
MAX_PRUNE_RATIO_PER_ITER
=
0.95
_logger
=
logging
.
getLogger
(
'Sensitivity_Pruner'
)
_logger
=
logging
.
getLogger
(
'Sensitivity_Pruner'
)
_logger
.
setLevel
(
logging
.
INFO
)
class
SensitivityPruner
(
Pruner
):
class
SensitivityPruner
(
Pruner
):
"""
"""
...
@@ -202,10 +202,10 @@ class SensitivityPruner(Pruner):
...
@@ -202,10 +202,10 @@ class SensitivityPruner(Pruner):
prune_ratios
=
sorted
(
sensitivities
[
layer
].
keys
())
prune_ratios
=
sorted
(
sensitivities
[
layer
].
keys
())
last_ratio
=
0
last_ratio
=
0
for
ratio
in
prune_ratios
:
for
ratio
in
prune_ratios
:
last_ratio
=
ratio
cur_acc
=
sensitivities
[
layer
][
ratio
]
cur_acc
=
sensitivities
[
layer
][
ratio
]
if
cur_acc
+
threshold
<
ori_acc
:
if
cur_acc
+
threshold
<
ori_acc
:
break
break
last_ratio
=
ratio
max_ratio
[
layer
]
=
last_ratio
max_ratio
[
layer
]
=
last_ratio
return
max_ratio
return
max_ratio
...
@@ -244,6 +244,7 @@ class SensitivityPruner(Pruner):
...
@@ -244,6 +244,7 @@ class SensitivityPruner(Pruner):
# MAX_PRUNE_RATIO_PER_ITER we rescal all prune
# MAX_PRUNE_RATIO_PER_ITER we rescal all prune
# ratios under this threshold
# ratios under this threshold
if
_Max
>
MAX_PRUNE_RATIO_PER_ITER
:
if
_Max
>
MAX_PRUNE_RATIO_PER_ITER
:
for
layername
in
ratios
:
for
layername
in
ratios
:
ratios
[
layername
]
=
ratios
[
layername
]
*
\
ratios
[
layername
]
=
ratios
[
layername
]
*
\
MAX_PRUNE_RATIO_PER_ITER
/
_Max
MAX_PRUNE_RATIO_PER_ITER
/
_Max
...
@@ -317,6 +318,7 @@ class SensitivityPruner(Pruner):
...
@@ -317,6 +318,7 @@ class SensitivityPruner(Pruner):
finetune_kwargs
=
{}
finetune_kwargs
=
{}
if
self
.
ori_acc
is
None
:
if
self
.
ori_acc
is
None
:
self
.
ori_acc
=
self
.
evaluator
(
*
eval_args
,
**
eval_kwargs
)
self
.
ori_acc
=
self
.
evaluator
(
*
eval_args
,
**
eval_kwargs
)
assert
isinstance
(
self
.
ori_acc
,
float
)
or
isinstance
(
self
.
ori_acc
,
int
)
if
not
resume_sensitivity
:
if
not
resume_sensitivity
:
self
.
sensitivities
=
self
.
analyzer
.
analysis
(
self
.
sensitivities
=
self
.
analyzer
.
analysis
(
val_args
=
eval_args
,
val_kwargs
=
eval_kwargs
)
val_args
=
eval_args
,
val_kwargs
=
eval_kwargs
)
...
@@ -330,6 +332,7 @@ class SensitivityPruner(Pruner):
...
@@ -330,6 +332,7 @@ class SensitivityPruner(Pruner):
iteration_count
=
0
iteration_count
=
0
if
self
.
checkpoint_dir
is
not
None
:
if
self
.
checkpoint_dir
is
not
None
:
os
.
makedirs
(
self
.
checkpoint_dir
,
exist_ok
=
True
)
os
.
makedirs
(
self
.
checkpoint_dir
,
exist_ok
=
True
)
modules_wrapper_final
=
None
while
cur_ratio
>
target_ratio
:
while
cur_ratio
>
target_ratio
:
iteration_count
+=
1
iteration_count
+=
1
# Each round have three steps:
# Each round have three steps:
...
@@ -343,9 +346,16 @@ class SensitivityPruner(Pruner):
...
@@ -343,9 +346,16 @@ class SensitivityPruner(Pruner):
# layers according to the sensitivity result
# layers according to the sensitivity result
proportion
=
self
.
sparsity_proportion_calc
(
proportion
=
self
.
sparsity_proportion_calc
(
ori_acc
,
self
.
acc_drop_threshold
,
self
.
sensitivities
)
ori_acc
,
self
.
acc_drop_threshold
,
self
.
sensitivities
)
new_pruneratio
=
self
.
normalize
(
proportion
,
self
.
sparsity_per_iter
)
new_pruneratio
=
self
.
normalize
(
proportion
,
self
.
sparsity_per_iter
)
cfg_list
=
self
.
create_cfg
(
new_pruneratio
)
cfg_list
=
self
.
create_cfg
(
new_pruneratio
)
if
not
cfg_list
:
_logger
.
error
(
'The threshold is too small, please set a larger threshold'
)
return
self
.
model
_logger
.
debug
(
'Pruner Config: %s'
,
str
(
cfg_list
))
_logger
.
debug
(
'Pruner Config: %s'
,
str
(
cfg_list
))
cfg_str
=
[
'%s:%.3f'
%
(
cfg
[
'op_names'
][
0
],
cfg
[
'sparsity'
])
for
cfg
in
cfg_list
]
_logger
.
info
(
'Current Sparsities: %s'
,
','
.
join
(
cfg_str
))
pruner
=
self
.
Pruner
(
self
.
model
,
cfg_list
)
pruner
=
self
.
Pruner
(
self
.
model
,
cfg_list
)
pruner
.
compress
()
pruner
.
compress
()
pruned_acc
=
self
.
evaluator
(
*
eval_args
,
**
eval_kwargs
)
pruned_acc
=
self
.
evaluator
(
*
eval_args
,
**
eval_kwargs
)
...
@@ -367,6 +377,7 @@ class SensitivityPruner(Pruner):
...
@@ -367,6 +377,7 @@ class SensitivityPruner(Pruner):
self
.
analyzer
.
already_pruned
[
name
]
=
sparsity
self
.
analyzer
.
already_pruned
[
name
]
=
sparsity
# update the cur_ratio
# update the cur_ratio
cur_ratio
=
1
-
self
.
current_sparsity
()
cur_ratio
=
1
-
self
.
current_sparsity
()
modules_wrapper_final
=
pruner
.
get_modules_wrapper
()
del
pruner
del
pruner
_logger
.
info
(
'Currently remained weights: %f'
,
cur_ratio
)
_logger
.
info
(
'Currently remained weights: %f'
,
cur_ratio
)
...
@@ -383,14 +394,19 @@ class SensitivityPruner(Pruner):
...
@@ -383,14 +394,19 @@ class SensitivityPruner(Pruner):
with
open
(
cfg_path
,
'w'
)
as
jf
:
with
open
(
cfg_path
,
'w'
)
as
jf
:
json
.
dump
(
cfg_list
,
jf
)
json
.
dump
(
cfg_list
,
jf
)
self
.
analyzer
.
export
(
sensitivity_path
)
self
.
analyzer
.
export
(
sensitivity_path
)
if
cur_ratio
>
target_ratio
:
if
cur_ratio
>
target_ratio
:
# If this is the last prune iteration, skip the time-consuming
# If this is the last prune iteration, skip the time-consuming
# sensitivity analysis
# sensitivity analysis
self
.
analyzer
.
load_state_dict
(
self
.
model
.
state_dict
())
self
.
analyzer
.
load_state_dict
(
self
.
model
.
state_dict
())
self
.
sensitivities
=
self
.
analyzer
.
analysis
(
self
.
sensitivities
=
self
.
analyzer
.
analysis
(
val_args
=
eval_args
,
val_kwargs
=
eval_kwargs
)
val_args
=
eval_args
,
val_kwargs
=
eval_kwargs
)
_logger
.
info
(
'After Pruning: %.2f weights remains'
,
cur_ratio
)
_logger
.
info
(
'After Pruning: %.2f weights remains'
,
cur_ratio
)
self
.
modules_wrapper
=
modules_wrapper_final
self
.
_wrap_model
()
return
self
.
model
return
self
.
model
def
calc_mask
(
self
,
wrapper
,
**
kwargs
):
def
calc_mask
(
self
,
wrapper
,
**
kwargs
):
...
...
src/sdk/pynni/nni/compression/torch/speedup/infer_shape.py
View file @
bf8be1e7
...
@@ -222,6 +222,10 @@ infer_from_inshape = {
...
@@ -222,6 +222,10 @@ infer_from_inshape = {
'ReLU'
:
lambda
module_masks
,
mask
:
relu_inshape
(
module_masks
,
mask
),
'ReLU'
:
lambda
module_masks
,
mask
:
relu_inshape
(
module_masks
,
mask
),
'ReLU6'
:
lambda
module_masks
,
mask
:
relu_inshape
(
module_masks
,
mask
),
'ReLU6'
:
lambda
module_masks
,
mask
:
relu_inshape
(
module_masks
,
mask
),
'aten::relu'
:
lambda
module_masks
,
mask
:
relu_inshape
(
module_masks
,
mask
),
'aten::relu'
:
lambda
module_masks
,
mask
:
relu_inshape
(
module_masks
,
mask
),
'aten::tanh'
:
lambda
module_masks
,
mask
:
relu_inshape
(
module_masks
,
mask
),
'aten::tanh_'
:
lambda
module_masks
,
mask
:
relu_inshape
(
module_masks
,
mask
),
'aten::hardtanh'
:
lambda
module_masks
,
mask
:
relu_inshape
(
module_masks
,
mask
),
'aten::hardtanh_'
:
lambda
module_masks
,
mask
:
relu_inshape
(
module_masks
,
mask
),
'aten::relu_'
:
lambda
module_masks
,
mask
:
relu_inshape
(
module_masks
,
mask
),
'aten::relu_'
:
lambda
module_masks
,
mask
:
relu_inshape
(
module_masks
,
mask
),
'Conv2d'
:
lambda
module_masks
,
mask
:
conv2d_inshape
(
module_masks
,
mask
),
'Conv2d'
:
lambda
module_masks
,
mask
:
conv2d_inshape
(
module_masks
,
mask
),
'MaxPool2d'
:
lambda
module_masks
,
mask
:
maxpool2d_inshape
(
module_masks
,
mask
),
'MaxPool2d'
:
lambda
module_masks
,
mask
:
maxpool2d_inshape
(
module_masks
,
mask
),
...
@@ -282,7 +286,7 @@ def cat_inshape(module_masks, mask, cat_info, last_visited):
...
@@ -282,7 +286,7 @@ def cat_inshape(module_masks, mask, cat_info, last_visited):
Parameters
Parameters
----------
----------
module_masks : ModuleMasks
module_masks : ModuleMasks
The ModuleMasks instance of the
batchnorm
2d
The ModuleMasks instance of the
Conv
2d
mask : CoarseMask
mask : CoarseMask
The mask of its input tensor
The mask of its input tensor
cat_info: dict
cat_info: dict
...
...
src/sdk/pynni/nni/compression/torch/utils/mask_conflict.py
View file @
bf8be1e7
...
@@ -118,11 +118,14 @@ class CatMaskPadding(MaskFix):
...
@@ -118,11 +118,14 @@ class CatMaskPadding(MaskFix):
continue
continue
# pad the mask for the non-pruned layers
# pad the mask for the non-pruned layers
for
layer
in
layers
:
for
layer
in
layers
:
if
layer
in
self
.
masks
:
continue
module
=
name_to_module
[
layer
]
module
=
name_to_module
[
layer
]
w_shape
=
module
.
weight
.
data
.
size
()
w_shape
=
module
.
weight
.
data
.
size
()
w_mask
=
torch
.
ones
(
w_shape
).
to
(
device
)
w_mask
=
torch
.
ones
(
w_shape
).
to
(
device
)
b_mask
=
None
b_mask
=
None
if
hasattr
(
module
,
'bias'
):
if
hasattr
(
module
,
'bias'
)
and
module
.
bias
is
not
None
:
# module.bias may be None
b_shape
=
module
.
bias
.
data
.
size
()
b_shape
=
module
.
bias
.
data
.
size
()
b_mask
=
torch
.
ones
(
b_shape
).
to
(
device
)
b_mask
=
torch
.
ones
(
b_shape
).
to
(
device
)
self
.
masks
[
layer
]
=
{
'weight'
:
w_mask
,
'bias'
:
b_mask
}
self
.
masks
[
layer
]
=
{
'weight'
:
w_mask
,
'bias'
:
b_mask
}
...
...
src/sdk/pynni/nni/compression/torch/utils/sensitivity_analysis.py
View file @
bf8be1e7
...
@@ -163,7 +163,7 @@ class SensitivityAnalysis:
...
@@ -163,7 +163,7 @@ class SensitivityAnalysis:
if
val_kwargs
is
None
:
if
val_kwargs
is
None
:
val_kwargs
=
{}
val_kwargs
=
{}
# Get the original validation metric(accuracy/loss) before pruning
# Get the original validation metric(accuracy/loss) before pruning
if
self
.
ori_metric
is
None
:
# Get the accuracy baseline before starting the analysis.
self
.
ori_metric
=
self
.
val_func
(
*
val_args
,
**
val_kwargs
)
self
.
ori_metric
=
self
.
val_func
(
*
val_args
,
**
val_kwargs
)
namelist
=
list
(
self
.
target_layer
.
keys
())
namelist
=
list
(
self
.
target_layer
.
keys
())
if
specified_layers
is
not
None
:
if
specified_layers
is
not
None
:
...
@@ -172,19 +172,21 @@ class SensitivityAnalysis:
...
@@ -172,19 +172,21 @@ class SensitivityAnalysis:
for
name
in
namelist
:
for
name
in
namelist
:
self
.
sensitivities
[
name
]
=
{}
self
.
sensitivities
[
name
]
=
{}
for
sparsity
in
self
.
sparsities
:
for
sparsity
in
self
.
sparsities
:
# here the sparsity is the relative sparsity of the
# the remained weights
# Calculate the actual prune ratio based on the already pruned ratio
# Calculate the actual prune ratio based on the already pruned ratio
sparsity
=
(
real_
sparsity
=
(
1.0
-
self
.
already_pruned
[
name
])
*
sparsity
+
self
.
already_pruned
[
name
]
1.0
-
self
.
already_pruned
[
name
])
*
sparsity
+
self
.
already_pruned
[
name
]
# TODO In current L1/L2 Filter Pruner, the 'op_types' is still necessary
# TODO In current L1/L2 Filter Pruner, the 'op_types' is still necessary
# I think the L1/L2 Pruner should specify the op_types automaticlly
# I think the L1/L2 Pruner should specify the op_types automaticlly
# according to the op_names
# according to the op_names
cfg
=
[{
'sparsity'
:
sparsity
,
'op_names'
:
[
cfg
=
[{
'sparsity'
:
real_
sparsity
,
'op_names'
:
[
name
],
'op_types'
:
[
'Conv2d'
]}]
name
],
'op_types'
:
[
'Conv2d'
]}]
pruner
=
self
.
Pruner
(
self
.
model
,
cfg
)
pruner
=
self
.
Pruner
(
self
.
model
,
cfg
)
pruner
.
compress
()
pruner
.
compress
()
val_metric
=
self
.
val_func
(
*
val_args
,
**
val_kwargs
)
val_metric
=
self
.
val_func
(
*
val_args
,
**
val_kwargs
)
logger
.
info
(
'Layer: %s Sparsity: %.2f Validation Metric: %.4f'
,
logger
.
info
(
'Layer: %s Sparsity: %.2f Validation Metric: %.4f'
,
name
,
sparsity
,
val_metric
)
name
,
real_
sparsity
,
val_metric
)
self
.
sensitivities
[
name
][
sparsity
]
=
val_metric
self
.
sensitivities
[
name
][
sparsity
]
=
val_metric
pruner
.
_unwrap_model
()
pruner
.
_unwrap_model
()
...
...
src/sdk/pynni/nni/nas/benchmarks/nasbench101/query.py
View file @
bf8be1e7
...
@@ -15,7 +15,7 @@ def query_nb101_trial_stats(arch, num_epochs, isomorphism=True, reduction=None,
...
@@ -15,7 +15,7 @@ def query_nb101_trial_stats(arch, num_epochs, isomorphism=True, reduction=None,
arch : dict or None
arch : dict or None
If a dict, it is in the format that is described in
If a dict, it is in the format that is described in
:class:`nni.nas.benchmark.nasbench101.Nb101TrialConfig`. Only trial stats
:class:`nni.nas.benchmark.nasbench101.Nb101TrialConfig`. Only trial stats
matched will be returned. If none, architecture
will be a wildcar
d.
matched will be returned. If none,
all
architecture
s in the database will be matche
d.
num_epochs : int or None
num_epochs : int or None
If int, matching results will be returned. Otherwise a wildcard.
If int, matching results will be returned. Otherwise a wildcard.
isomorphism : boolean
isomorphism : boolean
...
...
src/sdk/pynni/nni/nas/benchmarks/nasbench201/query.py
View file @
bf8be1e7
...
@@ -14,7 +14,7 @@ def query_nb201_trial_stats(arch, num_epochs, dataset, reduction=None, include_i
...
@@ -14,7 +14,7 @@ def query_nb201_trial_stats(arch, num_epochs, dataset, reduction=None, include_i
arch : dict or None
arch : dict or None
If a dict, it is in the format that is described in
If a dict, it is in the format that is described in
:class:`nni.nas.benchmark.nasbench201.Nb201TrialConfig`. Only trial stats
:class:`nni.nas.benchmark.nasbench201.Nb201TrialConfig`. Only trial stats
matched will be returned. If none, architecture
will be a wildcar
d.
matched will be returned. If none,
all
architecture
s in the database will be matche
d.
num_epochs : int or None
num_epochs : int or None
If int, matching results will be returned. Otherwise a wildcard.
If int, matching results will be returned. Otherwise a wildcard.
dataset : str or None
dataset : str or None
...
...
src/sdk/pynni/nni/nas/pytorch/mutator.py
View file @
bf8be1e7
...
@@ -162,7 +162,7 @@ class Mutator(BaseMutator):
...
@@ -162,7 +162,7 @@ class Mutator(BaseMutator):
if
self
.
_connect_all
:
if
self
.
_connect_all
:
return
self
.
_all_connect_tensor_reduction
(
mutable
.
reduction
,
return
self
.
_all_connect_tensor_reduction
(
mutable
.
reduction
,
[
op
(
*
args
,
**
kwargs
)
for
op
in
mutable
]),
\
[
op
(
*
args
,
**
kwargs
)
for
op
in
mutable
]),
\
torch
.
ones
(
len
(
mutable
))
torch
.
ones
(
len
(
mutable
))
.
bool
()
def
_map_fn
(
op
,
args
,
kwargs
):
def
_map_fn
(
op
,
args
,
kwargs
):
return
op
(
*
args
,
**
kwargs
)
return
op
(
*
args
,
**
kwargs
)
...
@@ -192,7 +192,7 @@ class Mutator(BaseMutator):
...
@@ -192,7 +192,7 @@ class Mutator(BaseMutator):
"""
"""
if
self
.
_connect_all
:
if
self
.
_connect_all
:
return
self
.
_all_connect_tensor_reduction
(
mutable
.
reduction
,
tensor_list
),
\
return
self
.
_all_connect_tensor_reduction
(
mutable
.
reduction
,
tensor_list
),
\
torch
.
ones
(
mutable
.
n_candidates
)
torch
.
ones
(
mutable
.
n_candidates
)
.
bool
()
mask
=
self
.
_get_decision
(
mutable
)
mask
=
self
.
_get_decision
(
mutable
)
assert
len
(
mask
)
==
mutable
.
n_candidates
,
\
assert
len
(
mask
)
==
mutable
.
n_candidates
,
\
"Invalid mask, expected {} to be of length {}."
.
format
(
mask
,
mutable
.
n_candidates
)
"Invalid mask, expected {} to be of length {}."
.
format
(
mask
,
mutable
.
n_candidates
)
...
...
src/sdk/pynni/nni/nas/pytorch/search_space_zoo/enas_cell.py
View file @
bf8be1e7
...
@@ -79,7 +79,6 @@ class ENASMicroLayer(nn.Module):
...
@@ -79,7 +79,6 @@ class ENASMicroLayer(nn.Module):
"""
"""
def
__init__
(
self
,
num_nodes
,
in_channels_pp
,
in_channels_p
,
out_channels
,
reduction
):
def
__init__
(
self
,
num_nodes
,
in_channels_pp
,
in_channels_p
,
out_channels
,
reduction
):
super
().
__init__
()
super
().
__init__
()
print
(
in_channels_pp
,
in_channels_p
,
out_channels
,
reduction
)
self
.
reduction
=
reduction
self
.
reduction
=
reduction
if
self
.
reduction
:
if
self
.
reduction
:
self
.
reduce0
=
FactorizedReduce
(
in_channels_pp
,
out_channels
,
affine
=
False
)
self
.
reduce0
=
FactorizedReduce
(
in_channels_pp
,
out_channels
,
affine
=
False
)
...
@@ -110,7 +109,7 @@ class ENASMicroLayer(nn.Module):
...
@@ -110,7 +109,7 @@ class ENASMicroLayer(nn.Module):
pprev: torch.Tensor
pprev: torch.Tensor
the output of the previous previous layer
the output of the previous previous layer
prev: torch.Tensor
prev: torch.Tensor
the output of the previous
previous
layer
the output of the previous layer
"""
"""
if
self
.
reduction
:
if
self
.
reduction
:
pprev
,
prev
=
self
.
reduce0
(
pprev
),
self
.
reduce1
(
prev
)
pprev
,
prev
=
self
.
reduce0
(
pprev
),
self
.
reduce1
(
prev
)
...
@@ -160,7 +159,7 @@ class ENASMacroLayer(mutables.MutableScope):
...
@@ -160,7 +159,7 @@ class ENASMacroLayer(mutables.MutableScope):
PoolBranch
(
'avg'
,
in_filters
,
out_filters
,
3
,
1
,
1
),
PoolBranch
(
'avg'
,
in_filters
,
out_filters
,
3
,
1
,
1
),
PoolBranch
(
'max'
,
in_filters
,
out_filters
,
3
,
1
,
1
)
PoolBranch
(
'max'
,
in_filters
,
out_filters
,
3
,
1
,
1
)
])
])
if
prev_labels
>
0
:
if
prev_labels
:
self
.
skipconnect
=
mutables
.
InputChoice
(
choose_from
=
prev_labels
,
n_chosen
=
None
)
self
.
skipconnect
=
mutables
.
InputChoice
(
choose_from
=
prev_labels
,
n_chosen
=
None
)
else
:
else
:
self
.
skipconnect
=
None
self
.
skipconnect
=
None
...
...
src/sdk/pynni/nni/package_utils.py
View file @
bf8be1e7
...
@@ -286,32 +286,9 @@ def create_customized_class_instance(class_params):
...
@@ -286,32 +286,9 @@ def create_customized_class_instance(class_params):
return
instance
return
instance
def
get_python_dir
(
sitepackages_path
):
if
sys
.
platform
==
"win32"
:
return
str
(
Path
(
sitepackages_path
))
else
:
return
str
(
Path
(
sitepackages_path
).
parents
[
2
])
def
get_nni_installation_parent_dir
():
def
get_nni_installation_parent_dir
():
''' Find nni installation parent directory
''' Find nni installation parent directory
'''
'''
def
try_installation_path_sequentially
(
*
sitepackages
):
'''Try different installation path sequentially util nni is found.
Return None if nothing is found
'''
def
_generate_installation_path
(
sitepackages_path
):
python_dir
=
get_python_dir
(
sitepackages_path
)
entry_file
=
os
.
path
.
join
(
python_dir
,
'nni'
,
'main.js'
)
if
os
.
path
.
isfile
(
entry_file
):
return
python_dir
return
None
for
sitepackage
in
sitepackages
:
python_dir
=
_generate_installation_path
(
sitepackage
)
if
python_dir
:
return
python_dir
return
None
if
os
.
getenv
(
'VIRTUAL_ENV'
):
if
os
.
getenv
(
'VIRTUAL_ENV'
):
# if 'virtualenv' package is used, `site` has not attr getsitepackages, so we will instead use VIRTUAL_ENV
# if 'virtualenv' package is used, `site` has not attr getsitepackages, so we will instead use VIRTUAL_ENV
# Note that conda venv will not have VIRTUAL_ENV
# Note that conda venv will not have VIRTUAL_ENV
...
@@ -321,12 +298,23 @@ def get_nni_installation_parent_dir():
...
@@ -321,12 +298,23 @@ def get_nni_installation_parent_dir():
# If system-wide python is used, we will give priority to using `local sitepackage`--"usersitepackages()" given
# If system-wide python is used, we will give priority to using `local sitepackage`--"usersitepackages()" given
# that nni exists there
# that nni exists there
if
python_sitepackage
.
startswith
(
'/usr'
)
or
python_sitepackage
.
startswith
(
'/Library'
):
if
python_sitepackage
.
startswith
(
'/usr'
)
or
python_sitepackage
.
startswith
(
'/Library'
):
python_dir
=
try_installation_path_sequentially
(
site
.
getusersitepackages
(),
site
.
getsitepackages
()
[
0
]
)
python_dir
=
_
try_installation_path_sequentially
(
site
.
getusersitepackages
(),
*
site
.
getsitepackages
())
else
:
else
:
python_dir
=
try_installation_path_sequentially
(
site
.
getsitepackages
()[
0
],
site
.
getusersitepackages
())
python_dir
=
_try_installation_path_sequentially
(
*
site
.
getsitepackages
(),
site
.
getusersitepackages
())
return
python_dir
return
python_dir
def
_try_installation_path_sequentially
(
*
sitepackages
):
'''Try different installation path sequentially util nni is found.
Return None if nothing is found
'''
for
sitepackage
in
sitepackages
:
path
=
Path
(
sitepackage
)
if
len
(
path
.
parents
)
>
2
and
(
path
.
parents
[
2
]
/
'nni'
/
'main.js'
).
is_file
():
return
str
(
path
.
parents
[
2
])
if
(
path
/
'nni'
/
'main.js'
).
is_file
():
return
str
(
path
)
return
None
def
get_nni_installation_path
():
def
get_nni_installation_path
():
''' Find nni installation directory
''' Find nni installation directory
'''
'''
...
...
src/sdk/pynni/tests/test_compressor_tf.py
0 → 100644
View file @
bf8be1e7
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import
unittest
import
numpy
as
np
import
tensorflow
as
tf
####
#
# This file tests pruners on 2 models: a classic CNN model, and a naive model with one linear layer
#
# The CNN model is used to test layer detecting and instrumenting.
#
# The naive model is used to test mask calculation.
# It has a single 10x10 linear layer without bias, and `reduce_sum` its result.
# To help predicting pruning result, the linear layer has fixed initial weights:
# [ [ 0.0, 1.0, 2.0, ..., 9.0 ], [0.1, 1.1, 2.1, ..., 9.1 ], ... , [0.9, 1.0, 2.9, ..., 9.9 ] ]
#
####
# This tensor is used as input of 10x10 linear layer, the first dimension is batch size
tensor1x10
=
tf
.
constant
([[
1.0
]
*
10
])
@
unittest
.
skipIf
(
tf
.
__version__
[
0
]
!=
'2'
,
'Skip TF 1.x setup'
)
class
TfCompressorTestCase
(
unittest
.
TestCase
):
def
test_layer_detection
(
self
):
# Conv and dense layers should be compressed, pool and flatten should not.
# This also tests instrumenting functionality.
self
.
_test_layer_detection_on_model
(
CnnModel
())
def
_test_layer_detection_on_model
(
self
,
model
):
pruner
=
pruners
[
'level'
](
model
)
pruner
.
compress
()
layer_types
=
sorted
(
wrapper
.
layer_info
.
type
for
wrapper
in
pruner
.
wrappers
)
assert
layer_types
==
[
'Conv2D'
,
'Dense'
,
'Dense'
],
layer_types
def
test_level_pruner
(
self
):
# prune 90% : 9.0 + 9.1 + ... + 9.9 = 94.5
model
=
build_naive_model
()
pruners
[
'level'
](
model
).
compress
()
x
=
model
(
tensor1x10
)
assert
x
.
numpy
()
==
94.5
try
:
from
tensorflow.keras
import
Model
,
Sequential
from
tensorflow.keras.layers
import
(
Conv2D
,
Dense
,
Flatten
,
MaxPool2D
)
from
nni.compression.tensorflow
import
LevelPruner
pruners
=
{
'level'
:
(
lambda
model
:
LevelPruner
(
model
,
[{
'sparsity'
:
0.9
,
'op_types'
:
[
'default'
]}])),
}
class
CnnModel
(
Model
):
def
__init__
(
self
):
super
().
__init__
()
self
.
conv
=
Conv2D
(
filters
=
10
,
kernel_size
=
3
,
activation
=
'relu'
)
self
.
pool
=
MaxPool2D
(
pool_size
=
2
)
self
.
flatten
=
Flatten
()
self
.
fc1
=
Dense
(
units
=
10
,
activation
=
'relu'
)
self
.
fc2
=
Dense
(
units
=
5
,
activation
=
'softmax'
)
def
call
(
self
,
x
):
x
=
self
.
conv
(
x
)
x
=
self
.
pool
(
x
)
x
=
self
.
flatten
(
x
)
x
=
self
.
fc1
(
x
)
x
=
self
.
fc2
(
x
)
return
x
class
NaiveModel
(
Model
):
def
__init__
(
self
):
super
().
__init__
()
self
.
fc
=
Dense
(
units
=
10
,
use_bias
=
False
)
def
call
(
self
,
x
):
return
tf
.
math
.
reduce_sum
(
self
.
fc
(
x
))
except
Exception
:
pass
def
build_naive_model
():
model
=
NaiveModel
()
model
.
build
(
tensor1x10
.
shape
)
weight
=
[[(
i
+
j
*
0.1
)
for
i
in
range
(
10
)]
for
j
in
range
(
10
)]
model
.
set_weights
([
np
.
array
(
weight
)])
return
model
if
__name__
==
'__main__'
:
unittest
.
main
()
src/sdk/pynni/tests/test_model_speedup.py
View file @
bf8be1e7
...
@@ -145,7 +145,7 @@ class SpeedupTestCase(TestCase):
...
@@ -145,7 +145,7 @@ class SpeedupTestCase(TestCase):
assert
model
.
backbone2
.
fc1
.
in_features
==
int
(
orig_model
.
backbone2
.
fc1
.
in_features
*
SPARSITY
)
assert
model
.
backbone2
.
fc1
.
in_features
==
int
(
orig_model
.
backbone2
.
fc1
.
in_features
*
SPARSITY
)
def
test_speedup_integration
(
self
):
def
test_speedup_integration
(
self
):
for
model_name
in
[
'resnet18'
,
'squeezenet1_1'
,
'mobilenet_v2'
,
'densenet121'
,
'inception_v3'
]:
for
model_name
in
[
'resnet18'
,
'squeezenet1_1'
,
'mobilenet_v2'
,
'densenet121'
,
'densenet169'
,
'inception_v3'
]:
Model
=
getattr
(
models
,
model_name
)
Model
=
getattr
(
models
,
model_name
)
net
=
Model
(
pretrained
=
True
,
progress
=
False
).
to
(
device
)
net
=
Model
(
pretrained
=
True
,
progress
=
False
).
to
(
device
)
speedup_model
=
Model
().
to
(
device
)
speedup_model
=
Model
().
to
(
device
)
...
...
src/webui/src/components/Modals/Compare.tsx
View file @
bf8be1e7
...
@@ -85,8 +85,10 @@ class Compare extends React.Component<CompareProps, {}> {
...
@@ -85,8 +85,10 @@ class Compare extends React.Component<CompareProps, {}> {
containLabel
:
true
containLabel
:
true
},
},
legend
:
{
legend
:
{
// more than 10 trials will hide legend
type
:
'
scroll
'
,
data
:
idsList
.
length
>
10
?
null
:
idsList
right
:
40
,
left
:
idsList
.
length
>
6
?
80
:
null
,
data
:
idsList
},
},
xAxis
:
{
xAxis
:
{
type
:
'
category
'
,
type
:
'
category
'
,
...
@@ -135,8 +137,17 @@ class Compare extends React.Component<CompareProps, {}> {
...
@@ -135,8 +137,17 @@ class Compare extends React.Component<CompareProps, {}> {
isComplexSearchSpace
=
(
typeof
parameterList
[
0
][
parameterKeys
[
0
]]
===
'
object
'
)
isComplexSearchSpace
=
(
typeof
parameterList
[
0
][
parameterKeys
[
0
]]
===
'
object
'
)
?
true
:
false
;
?
true
:
false
;
}
}
const
width
=
this
.
getWebUIWidth
();
let
scrollClass
;
if
(
width
>
1200
)
{
scrollClass
=
idList
.
length
>
3
?
'
flex
'
:
''
;
}
else
if
(
width
<
700
)
{
scrollClass
=
idList
.
length
>
1
?
'
flex
'
:
''
;
}
else
{
scrollClass
=
idList
.
length
>
2
?
'
flex
'
:
''
;
}
return
(
return
(
<
table
className
=
"
compare-modal-table
"
>
<
table
className
=
{
`
compare-modal-table
${
scrollClass
}
`
}
>
<
tbody
>
<
tbody
>
<
tr
>
<
tr
>
<
td
className
=
"column"
>
Id
</
td
>
<
td
className
=
"column"
>
Id
</
td
>
...
@@ -200,6 +211,10 @@ class Compare extends React.Component<CompareProps, {}> {
...
@@ -200,6 +211,10 @@ class Compare extends React.Component<CompareProps, {}> {
);
);
}
}
getWebUIWidth
=
():
number
=>
{
return
window
.
innerWidth
;
}
componentDidMount
():
void
{
componentDidMount
():
void
{
this
.
_isCompareMount
=
true
;
this
.
_isCompareMount
=
true
;
}
}
...
...
src/webui/src/components/public-child/PaiTrialChild.tsx
View file @
bf8be1e7
import
*
as
React
from
'
react
'
;
import
*
as
React
from
'
react
'
;
import
{
DOWNLOAD_IP
}
from
'
../../static/const
'
;
import
{
DOWNLOAD_IP
}
from
'
../../static/const
'
;
import
LogPathChild
from
'
./LogPathChild
'
;
interface
PaiTrialChildProps
{
interface
PaiTrialChildProps
{
logString
:
string
;
logString
:
string
;
...
@@ -21,7 +22,7 @@ class PaiTrialChild extends React.Component<PaiTrialChildProps, {}> {
...
@@ -21,7 +22,7 @@ class PaiTrialChild extends React.Component<PaiTrialChildProps, {}> {
{
{
logString
===
''
logString
===
''
?
?
<
div
/>
null
:
:
<
div
>
<
div
>
{
{
...
@@ -33,10 +34,13 @@ class PaiTrialChild extends React.Component<PaiTrialChildProps, {}> {
...
@@ -33,10 +34,13 @@ class PaiTrialChild extends React.Component<PaiTrialChildProps, {}> {
href
=
{
`
${
DOWNLOAD_IP
}
/trial_
${
id
}
.log`
}
href
=
{
`
${
DOWNLOAD_IP
}
/trial_
${
id
}
.log`
}
style
=
{
{
marginRight
:
10
}
}
style
=
{
{
marginRight
:
10
}
}
>
>
t
rial stdout
T
rial stdout
</
a
>
</
a
>
:
:
<
span
>
trial stdout:
{
logString
}
</
span
>
<
LogPathChild
eachLogpath
=
{
logString
}
logName
=
"Trial stdout:"
/>
}
}
</
div
>
</
div
>
}
}
...
...
src/webui/src/components/public-child/PaiTrialLog.tsx
View file @
bf8be1e7
...
@@ -42,7 +42,7 @@ class PaitrialLog extends React.Component<PaitrialLogProps, {}> {
...
@@ -42,7 +42,7 @@ class PaitrialLog extends React.Component<PaitrialLogProps, {}> {
>
>
Trial stdout
Trial stdout
</
a
>
</
a
>
<
a
target
=
"_blank"
rel
=
"noopener noreferrer"
href
=
{
logStr
.
split
(
'
,
'
)[
1
]
}
>
hdfsL
og
</
a
>
<
a
target
=
"_blank"
rel
=
"noopener noreferrer"
href
=
{
logStr
.
split
(
'
,
'
)[
1
]
}
>
NFS l
og
</
a
>
</
div
>
</
div
>
:
:
<
div
>
<
div
>
...
@@ -52,7 +52,7 @@ class PaitrialLog extends React.Component<PaitrialLogProps, {}> {
...
@@ -52,7 +52,7 @@ class PaitrialLog extends React.Component<PaitrialLogProps, {}> {
/>
/>
<
LogPathChild
<
LogPathChild
eachLogpath
=
{
logStr
.
split
(
'
,
'
)[
1
]
}
eachLogpath
=
{
logStr
.
split
(
'
,
'
)[
1
]
}
logName
=
"Log on
HD
FS:"
logName
=
"Log on
N
FS:"
/>
/>
</
div
>
</
div
>
}
}
...
...
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment