Commit 3ebf3a66 authored by Vishnu Banna's avatar Vishnu Banna
Browse files

task update

parent b6c5f6e6
...@@ -155,24 +155,14 @@ class SGDTorch(tf.keras.optimizers.Optimizer): ...@@ -155,24 +155,14 @@ class SGDTorch(tf.keras.optimizers.Optimizer):
# if all searches fail, add to other group # if all searches fail, add to other group
others.append(var) others.append(var)
self.set_variable_groups(weights, biases, others) self._set_variable_groups(weights, biases, others)
return weights, biases, others return weights, biases, others
def set_variable_groups(self, weights, biases, others): def _set_variable_groups(self, weights, biases, others):
"""Alterantive to search and set allowing user to manually set each group. """Sets the variables to be used in each group."""
This method is allows the user to bypass the weights, biases and others
search by key, and manually set the values for each group. This is the
safest alternative in cases where the variables cannot be grouped by
searching their names.
Args:
weights: List[tf.Variable] from model.trainable_variables
biases: List[tf.Variable] from model.trainable_variables
others: List[tf.Variable] from model.trainable_variables
"""
if self._variables_set: if self._variables_set:
logging.warning("set_variable_groups has been called again indicating" logging.warning("_set_variable_groups has been called again indicating"
"that the variable groups have already been set, they" "that the variable groups have already been set, they"
"will be updated.") "will be updated.")
self._wset.update(set([_var_key(w) for w in weights])) self._wset.update(set([_var_key(w) for w in weights]))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment