Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
bbdc9810
Commit
bbdc9810
authored
Apr 04, 2020
by
A. Unique TensorFlower
Browse files
Internal changes.
PiperOrigin-RevId: 304805715
parent
5cf005fd
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
25 additions
and
10 deletions
+25
-10
official/vision/image_classification/efficientnet/common_modules.py
...ision/image_classification/efficientnet/common_modules.py
+18
-0
official/vision/image_classification/efficientnet/efficientnet_config.py
.../image_classification/efficientnet/efficientnet_config.py
+1
-1
official/vision/image_classification/efficientnet/efficientnet_model.py
...n/image_classification/efficientnet/efficientnet_model.py
+6
-9
No files found.
official/vision/image_classification/efficientnet/common_modules.py
View file @
bbdc9810
...
...
@@ -98,3 +98,21 @@ def count_params(model, trainable_only=True):
else
:
return
int
(
np
.
sum
([
tf
.
keras
.
backend
.
count_params
(
p
)
for
p
in
model
.
trainable_weights
]))
def
load_weights
(
model
:
tf
.
keras
.
Model
,
model_weights_path
:
Text
,
weights_format
:
Text
=
'saved_model'
):
"""Load model weights from the given file path.
Args:
model: the model to load weights into
model_weights_path: the path of the model weights
weights_format: the model weights format. One of 'saved_model', 'h5',
or 'checkpoint'.
"""
if
weights_format
==
'saved_model'
:
loaded_model
=
tf
.
keras
.
models
.
load_model
(
model_weights_path
)
model
.
set_weights
(
loaded_model
.
get_weights
())
else
:
model
.
load_weights
(
model_weights_path
)
official/vision/image_classification/efficientnet/efficientnet_config.py
View file @
bbdc9810
...
...
@@ -50,7 +50,7 @@ class EfficientNetModelConfig(base_configs.ModelConfig):
model_params
:
Mapping
[
str
,
Any
]
=
dataclasses
.
field
(
default_factory
=
lambda
:
{
'model_name'
:
'efficientnet-b0'
,
'model_weights_path'
:
''
,
'
copy_to_local'
:
False
,
'
weights_format'
:
'saved_model'
,
'overrides'
:
{
'batch_norm'
:
'default'
,
'rescale_input'
:
True
,
...
...
official/vision/image_classification/efficientnet/efficientnet_model.py
View file @
bbdc9810
...
...
@@ -467,7 +467,7 @@ class EfficientNet(tf.keras.Model):
def
from_name
(
cls
,
model_name
:
Text
,
model_weights_path
:
Text
=
None
,
copy_to_local
:
bool
=
False
,
weights_format
:
Text
=
'saved_model'
,
overrides
:
Dict
[
Text
,
Any
]
=
None
):
"""Construct an EfficientNet model from a predefined model name.
...
...
@@ -476,7 +476,8 @@ class EfficientNet(tf.keras.Model):
Args:
model_name: the predefined model name
model_weights_path: the path to the weights (h5 file or saved model dir)
copy_to_local: copy the weights to a local tmp dir
weights_format: the model weights format. One of 'saved_model', 'h5',
or 'checkpoint'.
overrides: (optional) a dict containing keys that can override config
Returns:
...
...
@@ -496,12 +497,8 @@ class EfficientNet(tf.keras.Model):
model
=
cls
(
config
=
config
,
overrides
=
overrides
)
if
model_weights_path
:
if
copy_to_local
:
tmp_file
=
os
.
path
.
join
(
'/tmp'
,
model_name
+
'.h5'
)
model_weights_file
=
os
.
path
.
join
(
model_weights_path
,
'model.h5'
)
tf
.
io
.
gfile
.
copy
(
model_weights_file
,
tmp_file
,
overwrite
=
True
)
model_weights_path
=
tmp_file
model
.
load_weights
(
model_weights_path
)
common_modules
.
load_weights
(
model
,
model_weights_path
,
weights_format
=
weights_format
)
return
model
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment