Commit f59c0d29 authored by Vishnu Banna
Browse files

kmeans naming schemes and number of samples control

parent fb10543d
input resolution: [640, 640, 3]
boxes:
[[10.0, 17.0], [18.0, 55.0], [39.0, 26.0], [49.0, 76.0], [65.0, 181.0], [139.0, 93.0], [159.0, 305.0], [384.0, 210.0], [446.0, 486.0]]
\ No newline at end of file
...@@ -186,9 +186,10 @@ class AnchorBoxes(hyperparams.Config): ...@@ -186,9 +186,10 @@ class AnchorBoxes(hyperparams.Config):
level_limits: Optional[List[int]] = None level_limits: Optional[List[int]] = None
anchors_per_scale: int = 3 anchors_per_scale: int = 3
generate_anchors: bool = False generate_anchors: bool = True
scaling_mode: str = "sqrt_log" scaling_mode: str = "sqrt"
box_generation_mode: str = "per_level" box_generation_mode: str = "per_level"
num_samples: Optional[int] = None
def get(self, min_level, max_level): def get(self, min_level, max_level):
"""Distribute them in order to each level. """Distribute them in order to each level.
......
...@@ -34,11 +34,13 @@ class AnchorKMeans: ...@@ -34,11 +34,13 @@ class AnchorKMeans:
def boxes(self): def boxes(self):
return self._boxes.numpy() return self._boxes.numpy()
def get_box_from_dataset(self, dataset, image_w=512): def get_box_from_dataset(self, dataset, image_w=512, num_samples = None):
"""Load all the boxes in the dataset into memory.""" """Load all the boxes in the dataset into memory."""
box_list = [] box_list = []
for i, sample in enumerate(dataset): for i, sample in enumerate(dataset):
if num_samples is not None and i > num_samples:
break
width = sample["width"] width = sample["width"]
height = sample["height"] height = sample["height"]
boxes = sample['groundtruth_boxes'] boxes = sample['groundtruth_boxes']
...@@ -183,7 +185,8 @@ class AnchorKMeans: ...@@ -183,7 +185,8 @@ class AnchorKMeans:
anchors_per_scale = None, anchors_per_scale = None,
scaling_mode = "sqrt_log", scaling_mode = "sqrt_log",
box_generation_mode = "across_level", box_generation_mode = "across_level",
image_resolution=[512, 512, 3]): image_resolution=[512, 512, 3],
num_samples = None):
"""Run k-means on the boxes for a given input resolution. """Run k-means on the boxes for a given input resolution.
Args: Args:
...@@ -192,7 +195,7 @@ class AnchorKMeans: ...@@ -192,7 +195,7 @@ class AnchorKMeans:
k: `int` for the number for centroids to generate. k: `int` for the number for centroids to generate.
anchors_per_scale: `int` for how many anchor boxes to use per level. anchors_per_scale: `int` for how many anchor boxes to use per level.
scaling_mode: `str` for the type of box scaling to use when generating scaling_mode: `str` for the type of box scaling to use when generating
anchor boxes. Must be in the set {sqrt_log, default}. anchor boxes. Must be in the set {sqrt, default}.
box_generation_mode: `str` for the type of kmeans to use when generating box_generation_mode: `str` for the type of kmeans to use when generating
anchor boxes. Must be in the set {across_level, per_level}. anchor boxes. Must be in the set {across_level, per_level}.
image_resolution: `List[int]` for the resolution of the boxes to run image_resolution: `List[int]` for the resolution of the boxes to run
...@@ -202,9 +205,9 @@ class AnchorKMeans: ...@@ -202,9 +205,9 @@ class AnchorKMeans:
boxes: `List[List[int]]` of shape [k, 2] for the anchor boxes to use for boxes: `List[List[int]]` of shape [k, 2] for the anchor boxes to use for
box predictions. box predictions.
""" """
self.get_box_from_dataset(dataset) self.get_box_from_dataset(dataset, num_samples = num_samples)
if scaling_mode == "sqrt_log": if scaling_mode == "sqrt":
boxes_ls = tf.math.sqrt(self._boxes.numpy()) boxes_ls = tf.math.sqrt(self._boxes.numpy())
else: else:
boxes_ls = self._boxes.numpy() boxes_ls = self._boxes.numpy()
...@@ -242,7 +245,7 @@ class AnchorKMeans: ...@@ -242,7 +245,7 @@ class AnchorKMeans:
dists = 1 - self.iou(boxes_ls, np.array(clusters)) dists = 1 - self.iou(boxes_ls, np.array(clusters))
assignments = tf.math.argmin(dists, axis=-1) assignments = tf.math.argmin(dists, axis=-1)
if scaling_mode == "sqrt_log": if scaling_mode == "sqrt":
clusters = tf.square(clusters) clusters = tf.square(clusters)
self._boxes *= tf.convert_to_tensor(image_resolution, self._boxes.dtype) self._boxes *= tf.convert_to_tensor(image_resolution, self._boxes.dtype)
...@@ -260,9 +263,10 @@ class BoxGenInputReader(input_reader.InputReader): ...@@ -260,9 +263,10 @@ class BoxGenInputReader(input_reader.InputReader):
def read(self, def read(self,
k, k,
anchors_per_scale, anchors_per_scale,
scaling_mode = "sqrt_log", scaling_mode = "sqrt",
box_generation_mode = "across_level", box_generation_mode = "across_level",
image_resolution=[512, 512, 3], image_resolution=[512, 512, 3],
num_samples = None,
input_context=None): input_context=None):
"""Run k-means on the boxes for a given input resolution.
...@@ -275,6 +279,9 @@ class BoxGenInputReader(input_reader.InputReader): ...@@ -275,6 +279,9 @@ class BoxGenInputReader(input_reader.InputReader):
anchor boxes. Must be in the set {across_level, per_level}. anchor boxes. Must be in the set {across_level, per_level}.
image_resolution: `List[int]` for the resolution of the boxes to run image_resolution: `List[int]` for the resolution of the boxes to run
k-means for. k-means for.
num_samples: `Optional[int]` for the number of samples to use for kmeans,
typically about 5000 samples are all that are needed, but for the best
results use None to run the entire dataset.
Return: Return:
boxes: `List[List[int]]` of shape [k, 2] for the anchor boxes to use for boxes: `List[List[int]]` of shape [k, 2] for the anchor boxes to use for
...@@ -290,7 +297,8 @@ class BoxGenInputReader(input_reader.InputReader): ...@@ -290,7 +297,8 @@ class BoxGenInputReader(input_reader.InputReader):
anchors_per_scale = anchors_per_scale, anchors_per_scale = anchors_per_scale,
image_resolution=image_resolution, image_resolution=image_resolution,
scaling_mode = scaling_mode, scaling_mode = scaling_mode,
box_generation_mode = box_generation_mode) box_generation_mode = box_generation_mode,
num_samples = num_samples)
del kmeans_gen # free the memory del kmeans_gen # free the memory
del dataset del dataset
......
...@@ -94,6 +94,7 @@ class YoloTask(base_task.Task): ...@@ -94,6 +94,7 @@ class YoloTask(base_task.Task):
input_context = input_context, input_context = input_context,
scaling_mode = anchor_cfg.scaling_mode, scaling_mode = anchor_cfg.scaling_mode,
box_generation_mode = anchor_cfg.box_generation_mode, box_generation_mode = anchor_cfg.box_generation_mode,
num_samples = anchor_cfg.num_samples
) )
dataset.global_batch_size = gbs dataset.global_batch_size = gbs
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment