Commit f24fc412 authored by Frederick Liu's avatar Frederick Liu Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 386029066
parent 1ec978a2
...@@ -377,11 +377,15 @@ def remove_ckpts(model_dir): ...@@ -377,11 +377,15 @@ def remove_ckpts(model_dir):
tf.io.gfile.remove(file_to_remove) tf.io.gfile.remove(file_to_remove)
def try_count_params(model: tf.keras.Model): def try_count_params(
model: Union[tf.Module, tf.keras.Model],
trainable_only: bool = False):
"""Count the number of parameters if model is possible. """Count the number of parameters if model is possible.
Args: Args:
model: Try to count the number of params in this model. model: Try to count the number of params in this model.
trainable_only: Whether to calculate trainable params only. This flag is
not used when the model has `count_params` attribute.
Returns: Returns:
The number of parameters or None. The number of parameters or None.
...@@ -395,7 +399,13 @@ def try_count_params(model: tf.keras.Model): ...@@ -395,7 +399,13 @@ def try_count_params(model: tf.keras.Model):
'because the model was not feed any input, e.g., the max ' 'because the model was not feed any input, e.g., the max '
'train step already reached before this run.') 'train step already reached before this run.')
return None return None
return None else:
total_params = 0
variables = model.trainable_variables if trainable_only else model.variables
for var in variables:
shape = tf.shape(var)
total_params += tf.math.reduce_prod(shape).numpy()
return total_params
def try_count_flops(model: Union[tf.Module, tf.keras.Model], def try_count_flops(model: Union[tf.Module, tf.keras.Model],
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment