Commit bbdc9810 authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Internal changes.

PiperOrigin-RevId: 304805715
parent 5cf005fd
...@@ -98,3 +98,21 @@ def count_params(model, trainable_only=True): ...@@ -98,3 +98,21 @@ def count_params(model, trainable_only=True):
else: else:
return int(np.sum([tf.keras.backend.count_params(p) return int(np.sum([tf.keras.backend.count_params(p)
for p in model.trainable_weights])) for p in model.trainable_weights]))
def load_weights(model: tf.keras.Model,
model_weights_path: Text,
weights_format: Text = 'saved_model'):
"""Load model weights from the given file path.
Args:
model: the model to load weights into
model_weights_path: the path of the model weights
weights_format: the model weights format. One of 'saved_model', 'h5',
or 'checkpoint'.
"""
if weights_format == 'saved_model':
loaded_model = tf.keras.models.load_model(model_weights_path)
model.set_weights(loaded_model.get_weights())
else:
model.load_weights(model_weights_path)
...@@ -50,7 +50,7 @@ class EfficientNetModelConfig(base_configs.ModelConfig): ...@@ -50,7 +50,7 @@ class EfficientNetModelConfig(base_configs.ModelConfig):
model_params: Mapping[str, Any] = dataclasses.field(default_factory=lambda: { model_params: Mapping[str, Any] = dataclasses.field(default_factory=lambda: {
'model_name': 'efficientnet-b0', 'model_name': 'efficientnet-b0',
'model_weights_path': '', 'model_weights_path': '',
'copy_to_local': False, 'weights_format': 'saved_model',
'overrides': { 'overrides': {
'batch_norm': 'default', 'batch_norm': 'default',
'rescale_input': True, 'rescale_input': True,
......
...@@ -467,7 +467,7 @@ class EfficientNet(tf.keras.Model): ...@@ -467,7 +467,7 @@ class EfficientNet(tf.keras.Model):
def from_name(cls, def from_name(cls,
model_name: Text, model_name: Text,
model_weights_path: Text = None, model_weights_path: Text = None,
copy_to_local: bool = False, weights_format: Text = 'saved_model',
overrides: Dict[Text, Any] = None): overrides: Dict[Text, Any] = None):
"""Construct an EfficientNet model from a predefined model name. """Construct an EfficientNet model from a predefined model name.
...@@ -476,7 +476,8 @@ class EfficientNet(tf.keras.Model): ...@@ -476,7 +476,8 @@ class EfficientNet(tf.keras.Model):
Args: Args:
model_name: the predefined model name model_name: the predefined model name
model_weights_path: the path to the weights (h5 file or saved model dir) model_weights_path: the path to the weights (h5 file or saved model dir)
copy_to_local: copy the weights to a local tmp dir weights_format: the model weights format. One of 'saved_model', 'h5',
or 'checkpoint'.
overrides: (optional) a dict containing keys that can override config overrides: (optional) a dict containing keys that can override config
Returns: Returns:
...@@ -496,12 +497,8 @@ class EfficientNet(tf.keras.Model): ...@@ -496,12 +497,8 @@ class EfficientNet(tf.keras.Model):
model = cls(config=config, overrides=overrides) model = cls(config=config, overrides=overrides)
if model_weights_path: if model_weights_path:
if copy_to_local: common_modules.load_weights(model,
tmp_file = os.path.join('/tmp', model_name + '.h5') model_weights_path,
model_weights_file = os.path.join(model_weights_path, 'model.h5') weights_format=weights_format)
tf.io.gfile.copy(model_weights_file, tmp_file, overwrite=True)
model_weights_path = tmp_file
model.load_weights(model_weights_path)
return model return model
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment