Commit 4c977771 authored by moonbings's avatar moonbings
Browse files

Update random split logic for reproducing

parent e83c78be
...@@ -38,10 +38,10 @@ class SynthDoG(templates.Template): ...@@ -38,10 +38,10 @@ class SynthDoG(templates.Template):
**config.get("effect", {}), **config.get("effect", {}),
) )
# config for splits (output_filename, split_ratio etc) # config for splits
self.splits = ["train", "validation", "test"] self.splits = ["train", "validation", "test"]
self.split_indexes = [0, 0, 0] self.split_ratio = split_ratio
self.split_ratio = [sum(split_ratio[: i + 1]) for i in range(0, len(split_ratio))] self.split_indexes = np.random.choice(3, size=10000, p=split_ratio)
def generate(self): def generate(self):
landscape = np.random.rand() < self.landscape landscape = np.random.rand() < self.landscape
...@@ -88,19 +88,11 @@ class SynthDoG(templates.Template): ...@@ -88,19 +88,11 @@ class SynthDoG(templates.Template):
roi = data["roi"] roi = data["roi"]
# split # split
output_dirpath = os.path.join(root, "train") split = self.split_indexes[idx % len(self.split_indexes)]
file_idx = idx output_dirpath = os.path.join(root, self.splits[split])
split_prob = np.random.rand()
for _idx, (split, ratio) in enumerate(zip(self.splits, self.split_ratio)):
if split_prob < ratio:
output_dirpath = os.path.join(root, split)
file_idx = self.split_indexes[_idx]
self.split_indexes[_idx] += 1
break
# save image # save image
image_filename = f"image_{file_idx}.jpg" image_filename = f"image_{idx}.jpg"
image_filepath = os.path.join(output_dirpath, image_filename) image_filepath = os.path.join(output_dirpath, image_filename)
os.makedirs(os.path.dirname(image_filepath), exist_ok=True) os.makedirs(os.path.dirname(image_filepath), exist_ok=True)
image = Image.fromarray(image[..., :3].astype(np.uint8)) image = Image.fromarray(image[..., :3].astype(np.uint8))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment