"torchvision/git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "d0063f3d83beac01e85f3027c4de6499a8985469"
Commit 24ade5b8 authored by Vishnu Banna's avatar Vishnu Banna
Browse files

yolo_model update

parent e528aa76
...@@ -17,7 +17,7 @@ ...@@ -17,7 +17,7 @@
import tensorflow as tf import tensorflow as tf
# Static base Yolo Models that do not require configuration # static base Yolo Models that do not require configuration
# similar to a backbone model id. # similar to a backbone model id.
# this is done greatly simplify the model config # this is done greatly simplify the model config
...@@ -80,31 +80,31 @@ class Yolo(tf.keras.Model): ...@@ -80,31 +80,31 @@ class Yolo(tf.keras.Model):
backbone=None, backbone=None,
decoder=None, decoder=None,
head=None, head=None,
detection_generator=None, filter=None,
**kwargs): **kwargs):
"""Detection initialization function. """Detection initialization function.
Args: Args:
backbone: `tf.keras.Model`, a backbone network. backbone: `tf.keras.Model` a backbone network.
decoder: `tf.keras.Model`, a decoder network. decoder: `tf.keras.Model` a decoder network.
head: `YoloHead`, the YOLO head. head: `RetinaNetHead`, the RetinaNet head.
detection_generator: `tf.keras.Model`, the detection generator. filter: the detection generator.
**kwargs: keyword arguments to be passed. **kwargs: keyword arguments to be passed.
""" """
super().__init__(**kwargs) super(Yolo, self).__init__(**kwargs)
self._config_dict = { self._config_dict = {
"backbone": backbone, 'backbone': backbone,
"decoder": decoder, 'decoder': decoder,
"head": head, 'head': head,
"detection_generator": detection_generator 'filter': filter
} }
# model components # model components
self._backbone = backbone self._backbone = backbone
self._decoder = decoder self._decoder = decoder
self._head = head self._head = head
self._detection_generator = detection_generator self._filter = filter
return
def call(self, inputs, training=False): def call(self, inputs, training=False):
maps = self._backbone(inputs) maps = self._backbone(inputs)
...@@ -114,7 +114,7 @@ class Yolo(tf.keras.Model): ...@@ -114,7 +114,7 @@ class Yolo(tf.keras.Model):
return {"raw_output": raw_predictions} return {"raw_output": raw_predictions}
else: else:
# Post-processing. # Post-processing.
predictions = self._detection_generator(raw_predictions) predictions = self._filter(raw_predictions)
predictions.update({"raw_output": raw_predictions}) predictions.update({"raw_output": raw_predictions})
return predictions return predictions
...@@ -131,8 +131,8 @@ class Yolo(tf.keras.Model): ...@@ -131,8 +131,8 @@ class Yolo(tf.keras.Model):
return self._head return self._head
@property @property
def detection_generator(self): def filter(self):
return self._detection_generator return self._filter
def get_config(self): def get_config(self):
return self._config_dict return self._config_dict
...@@ -140,3 +140,29 @@ class Yolo(tf.keras.Model): ...@@ -140,3 +140,29 @@ class Yolo(tf.keras.Model):
@classmethod @classmethod
def from_config(cls, config): def from_config(cls, config):
return cls(**config) return cls(**config)
def get_weight_groups(self, train_vars):
"""Sort the list of trainable variables into groups for optimization.
Args:
train_vars: a list of tf.Variables that need to get sorted into their
respective groups.
Returns:
weights: a list of tf.Variables for the weights.
bias: a list of tf.Variables for the bias.
other: a list of tf.Variables for the other operations.
"""
bias = []
weights = []
other = []
for var in train_vars:
if "bias" in var.name:
bias.append(var)
elif "beta" in var.name:
bias.append(var)
elif "kernel" in var.name or "weight" in var.name:
weights.append(var)
else:
other.append(var)
return weights, bias, other
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment