Commit 4682f5c8 authored by Vishnu Banna

optimization package pr comments

parent dba326cc
@@ -5,56 +5,48 @@ runtime:
task:
smart_bias_lr: 0.1
model:
darknet_based_model: True
input_size: [512, 512, 3]
darknet_based_model: False
input_size: [640, 640, 3]
backbone:
type: 'darknet'
darknet:
model_id: 'cspdarknet53'
model_id: 'altered_cspdarknet53'
max_level: 5
min_level: 3
decoder:
type: yolo_decoder
yolo_decoder:
version: v4
type: regular
activation: leaky
type: csp
head:
smart_bias: true
detection_generator:
box_type:
'all': original
'all': scaled
scale_xy:
'5': 1.05
'4': 1.1
'3': 1.2
max_boxes: 200
'all': 2.0
max_boxes: 300
nms_type: greedy
iou_thresh: 0.25
nms_thresh: 0.45
pre_nms_points: 500
loss:
use_scaled_loss: False
update_on_repeat: False
use_scaled_loss: true
update_on_repeat: true
box_loss_type:
'all': ciou
ignore_thresh:
'all': 0.7
iou_normalizer:
'all': 0.07
'all': 0.05
cls_normalizer:
'all': 1.0
'all': 0.3
object_normalizer:
'all': 1.0
'5': 0.28
'4': 0.70
'3': 2.80
objectness_smooth:
'all': 0.0
max_delta:
'all': 5.0
norm_activation:
activation: mish
norm_epsilon: 0.0001
norm_momentum: 0.99
use_sync_bn: true
'all': 1.0
num_classes: 80
anchor_boxes:
anchors_per_scale: 3
@@ -63,11 +55,8 @@ task:
box: [142, 110], box: [192, 243], box: [459, 401]]
train_data:
global_batch_size: 1
# dtype: float32
input_path: '/media/vbanna/DATA_SHARE/CV/datasets/COCO_raw/records/train*'
# is_training: true
# drop_remainder: true
# seed: 1000
input_path: '/media/vbanna/DATA_SHARE/CV/datasets/COCO_raw/records/train*'
shuffle_buffer_size: 10000
parser:
mosaic:
mosaic_frequency: 1.0
@@ -76,57 +65,13 @@ task:
mosaic_center: 0.25
aug_scale_min: 0.1
aug_scale_max: 1.9
max_num_instances: 200
letter_box: True
random_flip: True
aug_rand_translate: 0.1
random_pad: False
area_thresh: 0.1
validation_data:
# global_batch_size: 1
# dtype: float32
global_batch_size: 1
input_path: '/media/vbanna/DATA_SHARE/CV/datasets/COCO_raw/records/val*'
# is_training: false
# drop_remainder: true
# parser:
# max_num_instances: 200
# letter_box: True
# use_tie_breaker: True
# anchor_thresh: 0.213
# weight_decay: 0.000
# init_checkpoint: '../checkpoints/512-wd-baseline-e1'
# init_checkpoint_modules: 'all'
# annotation_file: null
trainer:
optimizer_config:
ema: null
# train_steps: 500500 # 160 epochs at 64 batchsize -> 500500 * 64/2
# validation_steps: 625
# steps_per_loop: 1850
# summary_interval: 1850
# validation_interval: 9250
# checkpoint_interval: 1850
# optimizer_config:
# ema:
# average_decay: 0.9998
# trainable_weights_only: False
# dynamic_decay: True
# learning_rate:
# type: stepwise
# stepwise:
# boundaries: [400000, 450000]
# name: PiecewiseConstantDecay
# values: [0.00131, 0.000131, 0.0000131]
# optimizer:
# type: sgd_torch
# sgd_torch:
# momentum: 0.949
# momentum_start: 0.949
# nesterov: True
# warmup_steps: 1000
# weight_decay: 0.0005
# sim_torch: true
# name: SGD
# warmup:
# type: 'linear'
# linear:
# warmup_steps: 1000 #learning rate rises from 0 to 0.0013 over 1000 steps
ema: null
\ No newline at end of file
@@ -57,3 +57,7 @@ class SGDTorchConfig(BaseOptimizerConfig):
warmup_steps: int = 1000
weight_decay: float = 0.0
sim_torch: bool = False
weight_keys: List[str] = dataclasses.field(
    default_factory=lambda: ["kernel", "weight"])
bias_keys: List[str] = dataclasses.field(
    default_factory=lambda: ["bias", "beta"])
\ No newline at end of file
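
These two new defaults are regex fragments that the optimizer searches for in each variable's name (see `search_and_set_variable_groups` below); `beta` routes batch-norm offsets into the bias group, while `gamma` falls through to the catch-all group. A minimal standalone sketch of that matching, with illustrative variable names:

```python
import re

# Illustrative variable names as they typically appear in a Keras model.
names = ["conv2d/kernel:0", "conv2d/bias:0",
         "batch_norm/gamma:0", "batch_norm/beta:0"]

weight_keys = ["kernel", "weight"]  # defaults from SGDTorchConfig above
bias_keys = ["bias", "beta"]

def group(name):
  # Mirror of the optimizer's search: first regex hit wins, weights first.
  if any(re.search(k, name) for k in weight_keys):
    return "weights"
  if any(re.search(k, name) for k in bias_keys):
    return "biases"
  return "other"

for n in names:
  print(n, "->", group(n))
# conv2d/kernel:0 -> weights
# conv2d/bias:0 -> biases
# batch_norm/gamma:0 -> other   (falls through to the catch-all group)
# batch_norm/beta:0 -> biases   (batch-norm offsets are treated as biases)
```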
@@ -4,6 +4,7 @@ from tensorflow.keras.optimizers.schedules import LearningRateSchedule
from tensorflow.python.training import gen_training_ops
import tensorflow as tf
import re
import logging
__all__ = ['SGDTorch']
@@ -49,8 +50,7 @@ class SGDTorch(tf.keras.optimizers.Optimizer):
# parameter groups. An example of this variable search can be found in
# official/vision/beta/projects/yolo/modeling/yolo_model.py.
weights, biases, other = model.get_groups()
opt.set_params(weights, biases, other)
optimizer.search_and_set_variable_groups(model.trainable_variables)
# if the learning rate schedule on the biases is different. If lr is not set,
# the default schedule used for the weights will be used on the biases.
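
The replacement API removes the need for the model to expose a `get_groups()` method. A sketch of the new wiring (the `learning_rate` constructor argument and the `model` object are assumed for illustration; the schedule values echo the commented-out trainer config above):

```python
import tensorflow as tf

lr = tf.keras.optimizers.schedules.PiecewiseConstantDecay(
    boundaries=[400000, 450000], values=[0.00131, 0.000131, 0.0000131])
opt = SGDTorch(learning_rate=lr, momentum=0.949, nesterov=True,
               weight_decay=0.0005,
               weight_keys=["kernel", "weight"], bias_keys=["bias", "beta"])
# One call now groups every trainable variable by regex instead of the old
# set_params(weights, biases, other) handoff.
opt.search_and_set_variable_groups(model.trainable_variables)
# Optional: give the biases their own learning rate or schedule; otherwise
# they inherit the weight schedule.
opt.set_bias_lr(0.1)
```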
@@ -73,6 +73,8 @@ class SGDTorch(tf.keras.optimizers.Optimizer):
nesterov=False,
sim_torch=False,
name="SGD",
weight_keys=["kernel", "weight"],
bias_keys=["bias", "beta"],
**kwargs):
super(SGDTorch, self).__init__(name, **kwargs)
@@ -105,6 +107,9 @@ class SGDTorch(tf.keras.optimizers.Optimizer):
self.sim_torch = sim_torch
# weights, biases, other
self._weight_keys = weight_keys
self._bias_keys = bias_keys
self._variables_set = False
self._wset = set()
self._bset = set()
self._oset = set()
@@ -117,21 +122,80 @@ class SGDTorch(tf.keras.optimizers.Optimizer):
def set_other_lr(self, lr):
self._set_hyper("other_learning_rate", lr)
def set_params(self, weights, biases, others):
self._wset = set([_var_key(w) for w in weights])
self._bset = set([_var_key(b) for b in biases])
self._oset = set([_var_key(o) for o in others])
def search_and_set_variable_groups(self, variables):
"""Search all variable for matches ot each group.
Args:
variables: List[tf.Variable] from model.trainable_variables
"""
weights = []
biases = []
others = []
def search(var, keys):
"""Search all all keys for matches. Return True on match."""
for r in keys:
if re.search(r, var.name) is not None:
return True
return False
for var in variables:
# search for weights
if search(var, self._weight_keys):
weights.append(var)
continue
# search for biases
if search(var, self._bias_keys):
biases.append(var)
continue
# if all searches fail, add to other group
others.append(var)
self.set_variable_groups(weights, biases, others)
return
def set_variable_groups(self, weights, biases, others):
"""Alterantive to search and set allowing user to manually set each group.
This method is allows the user to bypass the weights, biases and others
search by key, and manually set the values for each group. This is the
safest alternative in cases where the variables cannot be grouped by
searching their names.
Args:
weights: List[tf.Variable] from model.trainable_variables
biases: List[tf.Variable] from model.trainable_variables
others: List[tf.Variable] from model.trainable_variables
"""
if self._variables_set:
logging.warning("set_variable_groups has been called again indicating"
"that the variable groups have already been set, they"
"will be updated.")
self._wset.update(set([_var_key(w) for w in weights]))
self._bset.update(set([_var_key(b) for b in biases]))
self._oset.update(set([_var_key(o) for o in others]))
# Log the number of objects in each group.
logging.info(
f"Weights: {len(weights)} Biases: {len(biases)} Others: {len(others)}")
f"Weights: {len(self._wset)} Biases: {len(self._bset)} Others: {len(self._oset)}")
self._variables_set = True
return
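
For models whose variable names carry no kernel/bias hints, the manual path might look like the following; the rank-based grouping rule, `model`, and `optimizer` are made-up stand-ins, not this repo's API:

```python
# Hypothetical manual grouping: treat rank >= 2 variables as kernels and
# rank-1 variables as bias-like, then hand the groups to the optimizer.
weights, biases, others = [], [], []
for var in model.trainable_variables:
  if var.shape.rank >= 2:
    weights.append(var)
  elif var.shape.rank == 1:
    biases.append(var)
  else:
    others.append(var)
optimizer.set_variable_groups(weights, biases, others)
```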
def _create_slots(self, var_list):
"""Create a momentum variable for each variable."""
if self._momentum:
for var in var_list:
self.add_slot(var, "momentum")
# check if trainable to support GPU EMA.
if var.trainable:
self.add_slot(var, "momentum")
if not self._variables_set:
# Fall back to automatically set the variables in case the user did not.
self.search_and_set_variable_groups(var_list)
self._variables_set = False
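
Note that the fallback resets `_variables_set` to `False` after the automatic search, so a later explicit call to `set_variable_groups` will not trigger the "already set" warning; presumably the flag is only meant to mark groups the user set deliberately.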
def _get_momentum(self, iteration):
"""Get the momentum value."""
momentum = self._get_hyper("momentum")
momentum_start = self._get_hyper("momentum_start")
momentum_warm_up_steps = tf.cast(
@@ -250,6 +314,8 @@ class SGDTorch(tf.keras.optimizers.Optimizer):
lr = coefficients["other_lr_t"]
momentum = coefficients["momentum"]
if self.sim_torch:
return self._apply(grad, var, weight_decay, momentum, lr)
else:
......
@@ -388,10 +388,9 @@ class YoloTask(base_task.Task):
optimizer = opt_factory.build_optimizer(opt_factory.build_learning_rate())
optimizer.set_bias_lr(
opt_factory.get_bias_lr_schedule(self._task_config.smart_bias_lr))
weights, biases, others = self._model.get_weight_groups(
self._model.trainable_variables)
optimizer.set_params(weights, biases, others)
optimizer.set_variable_groups(weights, biases, others)
else:
optimizer = opt_factory.build_optimizer(opt_factory.build_learning_rate())
opt_factory._use_ema = ema
......