Commit d1988e3e authored by Vishnu Banna's avatar Vishnu Banna
Browse files

optimization package pr comments

parent beeeed17
......@@ -126,9 +126,11 @@ class SGDTorch(tf.keras.optimizers.Optimizer):
def _search(self, var, keys):
"""Search all all keys for matches. Return True on match."""
for r in keys:
if re.search(r, var.name) is not None:
return True
if keys is not None:
# variable group is not ignored so search for the keys.
for r in keys:
if re.search(r, var.name) is not None:
return True
return False
def search_and_set_variable_groups(self, variables):
......@@ -143,11 +145,11 @@ class SGDTorch(tf.keras.optimizers.Optimizer):
for var in variables:
# search for weights
if self.search(var, self._weight_keys):
if self._search(var, self._weight_keys):
weights.append(var)
continue
# search for biases
if self.search(var, self._bias_keys):
if self._search(var, self._bias_keys):
biases.append(var)
continue
# if all searches fail, add to other group
......
......@@ -388,9 +388,9 @@ class YoloTask(base_task.Task):
optimizer = opt_factory.build_optimizer(opt_factory.build_learning_rate())
optimizer.set_bias_lr(
opt_factory.get_bias_lr_schedule(self._task_config.smart_bias_lr))
# weights, biases, others = self._model.get_weight_groups(
# self._model.trainable_variables)
# optimizer.set_variable_groups(weights, biases, others)
weights, biases, others = self._model.get_weight_groups(
self._model.trainable_variables)
optimizer.set_variable_groups(weights, biases, others)
else:
optimizer = opt_factory.build_optimizer(opt_factory.build_learning_rate())
opt_factory._use_ema = ema
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment