"...runtime/git@developer.sourcefind.cn:change/sglang.git" did not exist on "1ab6be1b2666eb77cc4f849e8bf7dfb7e1856f48"
Commit c86a93db authored by Vishnu Banna's avatar Vishnu Banna
Browse files

optimization package pr comments

parent 17f4ae11
...@@ -85,32 +85,6 @@ class Yolo(tf.keras.Model): ...@@ -85,32 +85,6 @@ class Yolo(tf.keras.Model):
@classmethod @classmethod
def from_config(cls, config): def from_config(cls, config):
return cls(**config) return cls(**config)
def get_weight_groups(self, train_vars):
"""Sort the list of trainable variables into groups for optimization.
Args:
train_vars: a list of tf.Variables that need to get sorted into their
respective groups.
Returns:
weights: a list of tf.Variables for the weights.
bias: a list of tf.Variables for the bias.
other: a list of tf.Variables for the other operations.
"""
bias = []
weights = []
other = []
for var in train_vars:
if "bias" in var.name:
bias.append(var)
elif "beta" in var.name:
bias.append(var)
elif "kernel" in var.name or "weight" in var.name:
weights.append(var)
else:
other.append(var)
return weights, bias, other
def fuse(self): def fuse(self):
"""Fuses all Convolution and Batchnorm layers to get better latency.""" """Fuses all Convolution and Batchnorm layers to get better latency."""
......
...@@ -388,9 +388,7 @@ class YoloTask(base_task.Task): ...@@ -388,9 +388,7 @@ class YoloTask(base_task.Task):
optimizer = opt_factory.build_optimizer(opt_factory.build_learning_rate()) optimizer = opt_factory.build_optimizer(opt_factory.build_learning_rate())
optimizer.set_bias_lr( optimizer.set_bias_lr(
opt_factory.get_bias_lr_schedule(self._task_config.smart_bias_lr)) opt_factory.get_bias_lr_schedule(self._task_config.smart_bias_lr))
weights, biases, others = self._model.get_weight_groups( optimizer.search_and_set_variable_groups(self._model.trainable_variables)
self._model.trainable_variables)
optimizer.set_variable_groups(weights, biases, others)
else: else:
optimizer = opt_factory.build_optimizer(opt_factory.build_learning_rate()) optimizer = opt_factory.build_optimizer(opt_factory.build_learning_rate())
opt_factory._use_ema = ema opt_factory._use_ema = ema
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment