Commit f59c0d29 authored by Vishnu Banna
Browse files

kmeans naming schemes and number of samples control

parent fb10543d
input resolution: [640, 640, 3]
boxes:
[[10.0, 17.0], [18.0, 55.0], [39.0, 26.0], [49.0, 76.0], [65.0, 181.0], [139.0, 93.0], [159.0, 305.0], [384.0, 210.0], [446.0, 486.0]]
\ No newline at end of file
......@@ -186,9 +186,10 @@ class AnchorBoxes(hyperparams.Config):
level_limits: Optional[List[int]] = None
anchors_per_scale: int = 3
generate_anchors: bool = False
scaling_mode: str = "sqrt_log"
generate_anchors: bool = True
scaling_mode: str = "sqrt"
box_generation_mode: str = "per_level"
num_samples: Optional[int] = None
def get(self, min_level, max_level):
"""Distribute them in order to each level.
......
......@@ -34,11 +34,13 @@ class AnchorKMeans:
def boxes(self):
return self._boxes.numpy()
def get_box_from_dataset(self, dataset, image_w=512):
def get_box_from_dataset(self, dataset, image_w=512, num_samples = None):
"""Load all the boxes in the dataset into memory."""
box_list = []
for i, sample in enumerate(dataset):
if num_samples is not None and i > num_samples:
break
width = sample["width"]
height = sample["height"]
boxes = sample['groundtruth_boxes']
......@@ -183,7 +185,8 @@ class AnchorKMeans:
anchors_per_scale = None,
scaling_mode = "sqrt_log",
box_generation_mode = "across_level",
image_resolution=[512, 512, 3]):
image_resolution=[512, 512, 3],
num_samples = None):
"""Run k-means on the boxes for a given input resolution.
Args:
......@@ -192,7 +195,7 @@ class AnchorKMeans:
k: `int` for the number for centroids to generate.
anchors_per_scale: `int` for how many anchor boxes to use per level.
scaling_mode: `str` for the type of box scaling to use when generating
anchor boxes. Must be in the set {sqrt_log, default}.
anchor boxes. Must be in the set {sqrt, default}.
box_generation_mode: `str` for the type of kmeans to use when generating
anchor boxes. Must be in the set {across_level, per_level}.
image_resolution: `List[int]` for the resolution of the boxes to run
......@@ -202,9 +205,9 @@ class AnchorKMeans:
boxes: `List[List[int]]` of shape [k, 2] for the anchor boxes to use for
box predictions.
"""
self.get_box_from_dataset(dataset)
self.get_box_from_dataset(dataset, num_samples = num_samples)
if scaling_mode == "sqrt_log":
if scaling_mode == "sqrt":
boxes_ls = tf.math.sqrt(self._boxes.numpy())
else:
boxes_ls = self._boxes.numpy()
......@@ -242,7 +245,7 @@ class AnchorKMeans:
dists = 1 - self.iou(boxes_ls, np.array(clusters))
assignments = tf.math.argmin(dists, axis=-1)
if scaling_mode == "sqrt_log":
if scaling_mode == "sqrt":
clusters = tf.square(clusters)
self._boxes *= tf.convert_to_tensor(image_resolution, self._boxes.dtype)
......@@ -260,9 +263,10 @@ class BoxGenInputReader(input_reader.InputReader):
def read(self,
k,
anchors_per_scale,
scaling_mode = "sqrt_log",
scaling_mode = "sqrt",
box_generation_mode = "across_level",
image_resolution=[512, 512, 3],
num_samples = None,
input_context=None):
"""Run k-means on the boxes for a given input resolution.
......@@ -275,6 +279,9 @@ class BoxGenInputReader(input_reader.InputReader):
anchor boxes. Must be in the set {across_level, per_level}.
image_resolution: `List[int]` for the resolution of the boxes to run
k-means for.
num_samples: `Optional[int]` for the number of samples to use for k-means.
About 5000 samples are typically all that are needed, but for the best
results use None to run the entire dataset.
Return:
boxes: `List[List[int]]` of shape [k, 2] for the anchor boxes to use for
......@@ -290,7 +297,8 @@ class BoxGenInputReader(input_reader.InputReader):
anchors_per_scale = anchors_per_scale,
image_resolution=image_resolution,
scaling_mode = scaling_mode,
box_generation_mode = box_generation_mode)
box_generation_mode = box_generation_mode,
num_samples = num_samples)
del kmeans_gen # free the memory
del dataset
......
......@@ -94,6 +94,7 @@ class YoloTask(base_task.Task):
input_context = input_context,
scaling_mode = anchor_cfg.scaling_mode,
box_generation_mode = anchor_cfg.box_generation_mode,
num_samples = anchor_cfg.num_samples
)
dataset.global_batch_size = gbs
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment