"vscode:/vscode.git/clone" did not exist on "4b6506718f0b5d6bd7fef7d94e15877a2b0f7d96"
Commit 3ebf3a66 authored by Vishnu Banna's avatar Vishnu Banna
Browse files

task update

parent b6c5f6e6
......@@ -155,24 +155,14 @@ class SGDTorch(tf.keras.optimizers.Optimizer):
# if all searches fail, add to other group
others.append(var)
self.set_variable_groups(weights, biases, others)
self._set_variable_groups(weights, biases, others)
return weights, biases, others
def set_variable_groups(self, weights, biases, others):
"""Alterantive to search and set allowing user to manually set each group.
This method is allows the user to bypass the weights, biases and others
search by key, and manually set the values for each group. This is the
safest alternative in cases where the variables cannot be grouped by
searching their names.
Args:
weights: List[tf.Variable] from model.trainable_variables
biases: List[tf.Variable] from model.trainable_variables
others: List[tf.Variable] from model.trainable_variables
"""
def _set_variable_groups(self, weights, biases, others):
"""Sets the variables to be used in each group."""
if self._variables_set:
logging.warning("set_variable_groups has been called again indicating"
logging.warning("_set_variable_groups has been called again indicating"
"that the variable groups have already been set, they"
"will be updated.")
self._wset.update(set([_var_key(w) for w in weights]))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment