Commit c86a93db authored by Vishnu Banna's avatar Vishnu Banna
Browse files

optimization package pr comments

parent 17f4ae11
...@@ -86,32 +86,6 @@ class Yolo(tf.keras.Model): ...@@ -86,32 +86,6 @@ class Yolo(tf.keras.Model):
def from_config(cls, config): def from_config(cls, config):
return cls(**config) return cls(**config)
def get_weight_groups(self, train_vars):
"""Sort the list of trainable variables into groups for optimization.
Args:
train_vars: a list of tf.Variables that need to get sorted into their
respective groups.
Returns:
weights: a list of tf.Variables for the weights.
bias: a list of tf.Variables for the bias.
other: a list of tf.Variables for the other operations.
"""
bias = []
weights = []
other = []
for var in train_vars:
if "bias" in var.name:
bias.append(var)
elif "beta" in var.name:
bias.append(var)
elif "kernel" in var.name or "weight" in var.name:
weights.append(var)
else:
other.append(var)
return weights, bias, other
def fuse(self): def fuse(self):
"""Fuses all Convolution and Batchnorm layers to get better latency.""" """Fuses all Convolution and Batchnorm layers to get better latency."""
print("Fusing Conv Batch Norm Layers.") print("Fusing Conv Batch Norm Layers.")
......
...@@ -388,9 +388,7 @@ class YoloTask(base_task.Task): ...@@ -388,9 +388,7 @@ class YoloTask(base_task.Task):
optimizer = opt_factory.build_optimizer(opt_factory.build_learning_rate()) optimizer = opt_factory.build_optimizer(opt_factory.build_learning_rate())
optimizer.set_bias_lr( optimizer.set_bias_lr(
opt_factory.get_bias_lr_schedule(self._task_config.smart_bias_lr)) opt_factory.get_bias_lr_schedule(self._task_config.smart_bias_lr))
weights, biases, others = self._model.get_weight_groups( optimizer.search_and_set_variable_groups(self._model.trainable_variables)
self._model.trainable_variables)
optimizer.set_variable_groups(weights, biases, others)
else: else:
optimizer = opt_factory.build_optimizer(opt_factory.build_learning_rate()) optimizer = opt_factory.build_optimizer(opt_factory.build_learning_rate())
opt_factory._use_ema = ema opt_factory._use_ema = ema
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment