"...resnet50_tensorflow.git" did not exist on "443c074527f164955720dcde5c1830faf519f89f"
Commit 24ade5b8 authored by Vishnu Banna's avatar Vishnu Banna
Browse files

yolo_model update

parent e528aa76
......@@ -17,7 +17,7 @@
import tensorflow as tf
# Static base Yolo Models that do not require configuration
# static base Yolo Models that do not require configuration
# similar to a backbone model id.
# this is done greatly simplify the model config
......@@ -80,31 +80,31 @@ class Yolo(tf.keras.Model):
backbone=None,
decoder=None,
head=None,
detection_generator=None,
filter=None,
**kwargs):
"""Detection initialization function.
Args:
backbone: `tf.keras.Model`, a backbone network.
decoder: `tf.keras.Model`, a decoder network.
head: `YoloHead`, the YOLO head.
detection_generator: `tf.keras.Model`, the detection generator.
backbone: `tf.keras.Model` a backbone network.
decoder: `tf.keras.Model` a decoder network.
head: `RetinaNetHead`, the RetinaNet head.
filter: the detection generator.
**kwargs: keyword arguments to be passed.
"""
super().__init__(**kwargs)
super(Yolo, self).__init__(**kwargs)
self._config_dict = {
"backbone": backbone,
"decoder": decoder,
"head": head,
"detection_generator": detection_generator
'backbone': backbone,
'decoder': decoder,
'head': head,
'filter': filter
}
# model components
self._backbone = backbone
self._decoder = decoder
self._head = head
self._detection_generator = detection_generator
self._filter = filter
return
def call(self, inputs, training=False):
maps = self._backbone(inputs)
......@@ -114,7 +114,7 @@ class Yolo(tf.keras.Model):
return {"raw_output": raw_predictions}
else:
# Post-processing.
predictions = self._detection_generator(raw_predictions)
predictions = self._filter(raw_predictions)
predictions.update({"raw_output": raw_predictions})
return predictions
......@@ -131,8 +131,8 @@ class Yolo(tf.keras.Model):
return self._head
@property
def detection_generator(self):
return self._detection_generator
def filter(self):
return self._filter
def get_config(self):
return self._config_dict
......@@ -140,3 +140,29 @@ class Yolo(tf.keras.Model):
@classmethod
def from_config(cls, config):
return cls(**config)
def get_weight_groups(self, train_vars):
"""Sort the list of trainable variables into groups for optimization.
Args:
train_vars: a list of tf.Variables that need to get sorted into their
respective groups.
Returns:
weights: a list of tf.Variables for the weights.
bias: a list of tf.Variables for the bias.
other: a list of tf.Variables for the other operations.
"""
bias = []
weights = []
other = []
for var in train_vars:
if "bias" in var.name:
bias.append(var)
elif "beta" in var.name:
bias.append(var)
elif "kernel" in var.name or "weight" in var.name:
weights.append(var)
else:
other.append(var)
return weights, bias, other
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment