Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
4412001f
Commit
4412001f
authored
Sep 02, 2020
by
A. Unique TensorFlower
Browse files
Merge pull request #8604 from Ruomei:toupstream/clusteringexample
PiperOrigin-RevId: 329802437
parents
faea89d9
a87bb185
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
12 additions
and
11 deletions
+12
-11
official/benchmark/models/resnet_imagenet_main.py
official/benchmark/models/resnet_imagenet_main.py
+12
-11
No files found.
official/benchmark/models/resnet_imagenet_main.py
View file @
4412001f
...
@@ -37,15 +37,16 @@ from official.vision.image_classification.resnet import imagenet_preprocessing
...
@@ -37,15 +37,16 @@ from official.vision.image_classification.resnet import imagenet_preprocessing
from
official.vision.image_classification.resnet
import
resnet_model
from
official.vision.image_classification.resnet
import
resnet_model
def
cluster_last_three_conv2d_layers
(
model
):
def
_cluster_last_three_conv2d_layers
(
model
):
import
tensorflow_model_optimization
as
tfmot
"""Helper method to cluster last three conv2d layers."""
last_three_conv2d_layers
=
[
import
tensorflow_model_optimization
as
tfmot
# pylint: disable=g-import-not-at-top
last_three_conv2d_layers
=
[
layer
for
layer
in
model
.
layers
layer
for
layer
in
model
.
layers
if
isinstance
(
layer
,
tf
.
keras
.
layers
.
Conv2D
)
if
isinstance
(
layer
,
tf
.
keras
.
layers
.
Conv2D
)
][
-
3
:]
][
-
3
:]
cluster_weights
=
tfmot
.
clustering
.
keras
.
cluster_weights
cluster_weights
=
tfmot
.
clustering
.
keras
.
cluster_weights
C
entroid
I
nitialization
=
tfmot
.
clustering
.
keras
.
CentroidInitialization
c
entroid
_i
nitialization
=
tfmot
.
clustering
.
keras
.
CentroidInitialization
def
cluster_fn
(
layer
):
def
cluster_fn
(
layer
):
if
layer
not
in
last_three_conv2d_layers
:
if
layer
not
in
last_three_conv2d_layers
:
...
@@ -54,12 +55,12 @@ def cluster_last_three_conv2d_layers(model):
...
@@ -54,12 +55,12 @@ def cluster_last_three_conv2d_layers(model):
if
layer
==
last_three_conv2d_layers
[
0
]
or
\
if
layer
==
last_three_conv2d_layers
[
0
]
or
\
layer
==
last_three_conv2d_layers
[
1
]:
layer
==
last_three_conv2d_layers
[
1
]:
clustered
=
cluster_weights
(
layer
,
number_of_clusters
=
256
,
\
clustered
=
cluster_weights
(
layer
,
number_of_clusters
=
256
,
\
cluster_centroids_init
=
C
entroid
I
nitialization
.
LINEAR
)
cluster_centroids_init
=
c
entroid
_i
nitialization
.
LINEAR
)
print
(
"
Clustered {} with 256 clusters
"
.
format
(
layer
.
name
))
print
(
'
Clustered {} with 256 clusters
'
.
format
(
layer
.
name
))
else
:
else
:
clustered
=
cluster_weights
(
layer
,
number_of_clusters
=
32
,
\
clustered
=
cluster_weights
(
layer
,
number_of_clusters
=
32
,
\
cluster_centroids_init
=
C
entroid
I
nitialization
.
LINEAR
)
cluster_centroids_init
=
c
entroid
_i
nitialization
.
LINEAR
)
print
(
"
Clustered {} with 32 clusters
"
.
format
(
layer
.
name
))
print
(
'
Clustered {} with 32 clusters
'
.
format
(
layer
.
name
))
return
clustered
return
clustered
return
tf
.
keras
.
models
.
clone_model
(
model
,
clone_function
=
cluster_fn
)
return
tf
.
keras
.
models
.
clone_model
(
model
,
clone_function
=
cluster_fn
)
...
@@ -228,7 +229,7 @@ def run(flags_obj):
...
@@ -228,7 +229,7 @@ def run(flags_obj):
model
.
load_weights
(
flags_obj
.
pretrained_filepath
)
model
.
load_weights
(
flags_obj
.
pretrained_filepath
)
if
flags_obj
.
pruning_method
==
'polynomial_decay'
:
if
flags_obj
.
pruning_method
==
'polynomial_decay'
:
import
tensorflow_model_optimization
as
tfmot
import
tensorflow_model_optimization
as
tfmot
# pylint: disable=g-import-not-at-top
if
dtype
!=
tf
.
float32
:
if
dtype
!=
tf
.
float32
:
raise
NotImplementedError
(
raise
NotImplementedError
(
'Pruning is currently only supported on dtype=tf.float32.'
)
'Pruning is currently only supported on dtype=tf.float32.'
)
...
@@ -246,12 +247,12 @@ def run(flags_obj):
...
@@ -246,12 +247,12 @@ def run(flags_obj):
raise
NotImplementedError
(
'Only polynomial_decay is currently supported.'
)
raise
NotImplementedError
(
'Only polynomial_decay is currently supported.'
)
if
flags_obj
.
clustering_method
==
'selective_clustering'
:
if
flags_obj
.
clustering_method
==
'selective_clustering'
:
import
tensorflow_model_optimization
as
tfmot
import
tensorflow_model_optimization
as
tfmot
# pylint: disable=g-import-not-at-top
if
dtype
!=
tf
.
float32
or
\
if
dtype
!=
tf
.
float32
or
\
flags_obj
.
fp16_implementation
==
'graph_rewrite'
:
flags_obj
.
fp16_implementation
==
'graph_rewrite'
:
raise
NotImplementedError
(
raise
NotImplementedError
(
'Clustering is currently only supported on dtype=tf.float32.'
)
'Clustering is currently only supported on dtype=tf.float32.'
)
model
=
cluster_last_three_conv2d_layers
(
model
)
model
=
_
cluster_last_three_conv2d_layers
(
model
)
elif
flags_obj
.
clustering_method
:
elif
flags_obj
.
clustering_method
:
raise
NotImplementedError
(
raise
NotImplementedError
(
'Only selective_clustering is implemented.'
)
'Only selective_clustering is implemented.'
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment