"pytorch/vscode:/vscode.git/clone" did not exist on "b8aa893b2e1adca16b3a266ff06672d46c068676"
Commit beeeed17 authored by Vishnu Banna

optimization package pr comments

parent 4682f5c8
@@ -7,10 +7,10 @@ import tensorflow as tf
import re
import logging
__all__ = ['SGDTorch']
def _var_key(var):
try:
from keras.optimizer_v2.optimizer_v2 import _var_key
except:
def _var_key(var):
"""Key for representing a primary variable, for looking up slots.
In graph mode the name is derived from the var shared name.
In eager mode the name is derived from the var unique id.
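The fallback branch above re-declares `_var_key` locally when the private Keras import is unavailable. For context, the body that typically accompanies this docstring in Keras (not shown in the hunk) looks roughly like the sketch below; it relies on private TF attributes, so treat it as an approximation rather than the exact code added here:

```python
def _var_key(var):
  """Key for representing a primary variable, for looking up slots."""
  # Approximation of the Keras helper: distributed variables are first
  # resolved to their primary container, then keyed by shared name
  # (graph mode) or unique id (eager mode).
  if hasattr(var, "_distributed_container"):
    var = var._distributed_container()
  if getattr(var, "_in_graph_mode", False):
    return var._shared_name
  return var._unique_id
```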
@@ -113,6 +113,8 @@ class SGDTorch(tf.keras.optimizers.Optimizer):
self._wset = set()
self._bset = set()
self._oset = set()
if self.sim_torch:
logging.info(f"Pytorch SGD simulation: ")
logging.info(f"Weight Decay: {weight_decay}")
@@ -122,8 +124,15 @@ class SGDTorch(tf.keras.optimizers.Optimizer):
def set_other_lr(self, lr):
self._set_hyper("other_learning_rate", lr)
def _search(self, var, keys):
"""Search all all keys for matches. Return True on match."""
for r in keys:
if re.search(r, var.name) is not None:
return True
return False
def search_and_set_variable_groups(self, variables):
"""Search all variable for matches ot each group.
"""Search all variable for matches at each group.
Args:
variables: List[tf.Variable] from model.trainable_variables
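The new `_search` helper is a plain regex scan over `tf.Variable` names. A minimal, framework-free illustration of the matching behavior (the key patterns below are assumptions for the example, not values taken from this diff):

```python
import re

weight_keys = [r"kernel"]   # assumed example patterns
bias_keys = [r"bias"]

def search(name, keys):
  """Return True if any regex key matches the variable name."""
  return any(re.search(r, name) is not None for r in keys)

print(search("conv2d_3/kernel:0", weight_keys))   # True  -> weights group
print(search("conv2d_3/bias:0", bias_keys))       # True  -> biases group
print(search("dense/embeddings:0", weight_keys))  # False -> falls to others
```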
@@ -132,27 +141,20 @@ class SGDTorch(tf.keras.optimizers.Optimizer):
biases = []
others = []
def search(var, keys):
"""Search all all keys for matches. Return True on match."""
for r in keys:
if re.search(r, var.name) is not None:
return True
return False
for var in variables:
# search for weights
if search(var, self._weight_keys):
if self._search(var, self._weight_keys):
weights.append(var)
continue
# search for biases
if search(var, self._bias_keys):
if self._search(var, self._bias_keys):
biases.append(var)
continue
# if all searches fail, add to other group
others.append(var)
self.set_variable_groups(weights, biases, others)
return
return weights, biases, others
def set_variable_groups(self, weights, biases, others):
"""Alterantive to search and set allowing user to manually set each group.
@@ -181,6 +183,21 @@ class SGDTorch(tf.keras.optimizers.Optimizer):
self._variables_set = True
return
def _get_variable_group(self, var, coefficients):
if self._variables_set:
# check which preset group holds this variable.
if (_var_key(var) in self._wset):
return True, False, False
elif (_var_key(var) in self._bset):
return False, True, False
else:
# search the variables at run time.
if self._search(var, self._weight_keys):
return True, False, False
elif self._search(var, self._bias_keys):
return False, True, False
return False, False, True
def _create_slots(self, var_list):
"""Create a momentum variable for each variable."""
if self._momentum:
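`_get_variable_group` prefers the preset membership sets built by `set_variable_groups` (keyed through `_var_key`) and only falls back to the regex search when groups were never set explicitly. A rough stand-alone sketch of that control flow, with plain strings standing in for `_var_key(var)` (the class and values below are illustrative only):

```python
import re

class GroupLookup:
  """Toy stand-in for the preset-vs-search logic (illustrative only)."""

  def __init__(self, weight_keys=(r"kernel",), bias_keys=(r"bias",)):
    self.weight_keys, self.bias_keys = weight_keys, bias_keys
    self.wset, self.bset = set(), set()
    self.variables_set = False

  def set_variable_groups(self, weights, biases):
    # Preset path: remember explicit membership for O(1) lookups later.
    self.wset, self.bset = set(weights), set(biases)
    self.variables_set = True

  def get_variable_group(self, key):
    if self.variables_set:
      if key in self.wset:
        return True, False, False    # weights
      if key in self.bset:
        return False, True, False    # biases
    else:
      # Runtime path: regex search on the name, mirroring _search.
      if any(re.search(r, key) for r in self.weight_keys):
        return True, False, False
      if any(re.search(r, key) for r in self.bias_keys):
        return False, True, False
    return False, False, True        # others

lookup = GroupLookup()
print(lookup.get_variable_group("conv2d/kernel:0"))  # search path -> weights
lookup.set_variable_groups({"conv2d/kernel:0"}, {"conv2d/bias:0"})
print(lookup.get_variable_group("conv2d/bias:0"))    # preset path -> biases
```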
@@ -189,11 +206,6 @@ class SGDTorch(tf.keras.optimizers.Optimizer):
if var.trainable:
self.add_slot(var, "momentum")
if not self._variables_set:
# Fall back to automatically set the variables in case the user did not.
self.search_and_set_variable_groups(var_list)
self._variables_set = False
def _get_momentum(self, iteration):
"""Get the momentum value."""
momentum = self._get_hyper("momentum")
@@ -239,7 +251,7 @@ class SGDTorch(tf.keras.optimizers.Optimizer):
return apply_state[(var_device, var_dtype)]
def _apply_tf(self, grad, var, weight_decay, momentum, lr):
"""Uses Tensorflow Optimizer with Weight decay SGDW."""
def decay_op(var, learning_rate, wd):
if self._weight_decay and wd > 0:
return var.assign_sub(
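The TF path applies decay through `decay_op`, i.e. by shrinking the variable directly rather than adding the penalty to the gradient (decoupled, SGDW-style). The exact expression is elided by the hunk; a hedged sketch of that kind of update, assuming an `lr * wd * w` form:

```python
import tensorflow as tf

var = tf.Variable([1.0, -2.0, 3.0])
lr, wd = 0.1, 0.01

# Decoupled weight decay: w <- w - lr * wd * w, applied outside the
# gradient computation (assumed form; the real expression is cut off above).
if wd > 0:
  var.assign_sub(var * wd * lr)

print(var.numpy())  # each entry shrunk by 0.1% before the gradient step
```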
@@ -263,6 +275,7 @@ class SGDTorch(tf.keras.optimizers.Optimizer):
var=var.handle, alpha=lr, delta=grad, use_locking=self._use_locking)
def _apply(self, grad, var, weight_decay, momentum, lr):
"""Uses Pytorch Optimizer with Weight decay SGDW."""
dparams = grad
groups = []
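By contrast, `_apply` simulates the PyTorch optimizer. Its body is split across hunks here, but the standard `torch.optim.SGD` recipe it refers to (momentum with zero dampening, weight decay folded into the gradient) can be sketched framework-free as below. Treat this as the reference rule, not the literal code in this file; the docstring says SGDW, so whether the decay here is coupled or decoupled depends on the elided lines.

```python
def torch_sgd_step(param, grad, buf, lr, momentum, weight_decay):
  """One PyTorch-style SGD step (dampening=0, nesterov=False)."""
  d_p = grad + weight_decay * param   # decay folded into the gradient
  buf = momentum * buf + d_p          # momentum buffer update
  param = param - lr * buf            # parameter update
  return param, buf

p, buf = 1.0, 0.0
for g in (0.5, 0.4, 0.3):
  p, buf = torch_sgd_step(p, g, buf, lr=0.1, momentum=0.9, weight_decay=1e-4)
print(p)
```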
@@ -288,19 +301,12 @@ class SGDTorch(tf.keras.optimizers.Optimizer):
groups.append(weight_update)
return tf.group(*groups)
def _get_vartype(self, var, coefficients):
if (_var_key(var) in self._wset):
return True, False, False
elif (_var_key(var) in self._bset):
return False, True, False
return False, False, True
def _run_sgd(self, grad, var, apply_state=None):
var_device, var_dtype = var.device, var.dtype.base_dtype
coefficients = ((apply_state or {}).get((var_device, var_dtype)) or
self._fallback_apply_state(var_device, var_dtype))
weights, bias, others = self._get_vartype(var, coefficients)
weights, bias, others = self._get_variable_group(var, coefficients)
weight_decay = tf.zeros_like(coefficients["weight_decay"])
lr = coefficients["lr_t"]
if weights:
......@@ -314,8 +320,6 @@ class SGDTorch(tf.keras.optimizers.Optimizer):
lr = coefficients["other_lr_t"]
momentum = coefficients["momentum"]
tf.print(lr)
if self.sim_torch:
return self._apply(grad, var, weight_decay, momentum, lr)
else:
......
@@ -388,9 +388,9 @@ class YoloTask(base_task.Task):
optimizer = opt_factory.build_optimizer(opt_factory.build_learning_rate())
optimizer.set_bias_lr(
opt_factory.get_bias_lr_schedule(self._task_config.smart_bias_lr))
weights, biases, others = self._model.get_weight_groups(
self._model.trainable_variables)
optimizer.set_variable_groups(weights, biases, others)
# weights, biases, others = self._model.get_weight_groups(
# self._model.trainable_variables)
# optimizer.set_variable_groups(weights, biases, others)
else:
optimizer = opt_factory.build_optimizer(opt_factory.build_learning_rate())
opt_factory._use_ema = ema
......