Commit d1988e3e authored by Vishnu Banna's avatar Vishnu Banna
Browse files

optimization package pr comments

parent beeeed17
...@@ -126,6 +126,8 @@ class SGDTorch(tf.keras.optimizers.Optimizer): ...@@ -126,6 +126,8 @@ class SGDTorch(tf.keras.optimizers.Optimizer):
def _search(self, var, keys): def _search(self, var, keys):
"""Search all all keys for matches. Return True on match.""" """Search all all keys for matches. Return True on match."""
if keys is not None:
# variable group is not ignored so search for the keys.
for r in keys: for r in keys:
if re.search(r, var.name) is not None: if re.search(r, var.name) is not None:
return True return True
...@@ -143,11 +145,11 @@ class SGDTorch(tf.keras.optimizers.Optimizer): ...@@ -143,11 +145,11 @@ class SGDTorch(tf.keras.optimizers.Optimizer):
for var in variables: for var in variables:
# search for weights # search for weights
if self.search(var, self._weight_keys): if self._search(var, self._weight_keys):
weights.append(var) weights.append(var)
continue continue
# search for biases # search for biases
if self.search(var, self._bias_keys): if self._search(var, self._bias_keys):
biases.append(var) biases.append(var)
continue continue
# if all searches fail, add to other group # if all searches fail, add to other group
......
...@@ -388,9 +388,9 @@ class YoloTask(base_task.Task): ...@@ -388,9 +388,9 @@ class YoloTask(base_task.Task):
optimizer = opt_factory.build_optimizer(opt_factory.build_learning_rate()) optimizer = opt_factory.build_optimizer(opt_factory.build_learning_rate())
optimizer.set_bias_lr( optimizer.set_bias_lr(
opt_factory.get_bias_lr_schedule(self._task_config.smart_bias_lr)) opt_factory.get_bias_lr_schedule(self._task_config.smart_bias_lr))
# weights, biases, others = self._model.get_weight_groups( weights, biases, others = self._model.get_weight_groups(
# self._model.trainable_variables) self._model.trainable_variables)
# optimizer.set_variable_groups(weights, biases, others) optimizer.set_variable_groups(weights, biases, others)
else: else:
optimizer = opt_factory.build_optimizer(opt_factory.build_learning_rate()) optimizer = opt_factory.build_optimizer(opt_factory.build_learning_rate())
opt_factory._use_ema = ema opt_factory._use_ema = ema
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment