".github/vscode:/vscode.git/clone" did not exist on "a0f844ed5a89159d4e637827ccdbd1e2d175798e"
Commit c86a93db authored by Vishnu Banna's avatar Vishnu Banna
Browse files

optimization package pr comments

parent 17f4ae11
......@@ -86,32 +86,6 @@ class Yolo(tf.keras.Model):
def from_config(cls, config):
return cls(**config)
def get_weight_groups(self, train_vars):
"""Sort the list of trainable variables into groups for optimization.
Args:
train_vars: a list of tf.Variables that need to get sorted into their
respective groups.
Returns:
weights: a list of tf.Variables for the weights.
bias: a list of tf.Variables for the bias.
other: a list of tf.Variables for the other operations.
"""
bias = []
weights = []
other = []
for var in train_vars:
if "bias" in var.name:
bias.append(var)
elif "beta" in var.name:
bias.append(var)
elif "kernel" in var.name or "weight" in var.name:
weights.append(var)
else:
other.append(var)
return weights, bias, other
def fuse(self):
"""Fuses all Convolution and Batchnorm layers to get better latency."""
print("Fusing Conv Batch Norm Layers.")
......
......@@ -388,9 +388,7 @@ class YoloTask(base_task.Task):
optimizer = opt_factory.build_optimizer(opt_factory.build_learning_rate())
optimizer.set_bias_lr(
opt_factory.get_bias_lr_schedule(self._task_config.smart_bias_lr))
weights, biases, others = self._model.get_weight_groups(
self._model.trainable_variables)
optimizer.set_variable_groups(weights, biases, others)
optimizer.search_and_set_variable_groups(self._model.trainable_variables)
else:
optimizer = opt_factory.build_optimizer(opt_factory.build_learning_rate())
opt_factory._use_ema = ema
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment