"doc/vscode:/vscode.git/clone" did not exist on "9fecdbaf276c9399bcc3574eb2dd2107e3ab9599"
Unverified commit fa5e69b1, authored by Geewook Kim, committed by GitHub
Browse files

Merge pull request #95 from moonbings/update_random_split

Update random split logic for reproducibility
parents e83c78be 5762c6d3
...@@ -29,13 +29,13 @@ synthtiger -o ./outputs/SynthDoG_en -c 50 -w 4 -v template.py SynthDoG config_en ...@@ -29,13 +29,13 @@ synthtiger -o ./outputs/SynthDoG_en -c 50 -w 4 -v template.py SynthDoG config_en
. .
'quality': [50, 95], 'quality': [50, 95],
'short_size': [720, 1024]} 'short_size': [720, 1024]}
Generated 1 data Generated 1 data (task 3)
Generated 2 data Generated 2 data (task 0)
Generated 3 data Generated 3 data (task 1)
. .
. .
Generated 49 data Generated 49 data (task 48)
Generated 50 data Generated 50 data (task 49)
46.32 seconds elapsed 46.32 seconds elapsed
``` ```
...@@ -44,6 +44,7 @@ Some important arguments: ...@@ -44,6 +44,7 @@ Some important arguments:
- `-o` : directory path to save data. - `-o` : directory path to save data.
- `-c` : number of data to generate. - `-c` : number of data to generate.
- `-w` : number of workers. - `-w` : number of workers.
- `-s` : random seed.
- `-v` : print error messages. - `-v` : print error messages.
To generate ECJK samples: To generate ECJK samples:
......
...@@ -38,10 +38,10 @@ class SynthDoG(templates.Template): ...@@ -38,10 +38,10 @@ class SynthDoG(templates.Template):
**config.get("effect", {}), **config.get("effect", {}),
) )
# config for splits (output_filename, split_ratio etc) # config for splits
self.splits = ["train", "validation", "test"] self.splits = ["train", "validation", "test"]
self.split_indexes = [0, 0, 0] self.split_ratio = split_ratio
self.split_ratio = [sum(split_ratio[: i + 1]) for i in range(0, len(split_ratio))] self.split_indexes = np.random.choice(3, size=10000, p=split_ratio)
def generate(self): def generate(self):
landscape = np.random.rand() < self.landscape landscape = np.random.rand() < self.landscape
...@@ -88,19 +88,11 @@ class SynthDoG(templates.Template): ...@@ -88,19 +88,11 @@ class SynthDoG(templates.Template):
roi = data["roi"] roi = data["roi"]
# split # split
output_dirpath = os.path.join(root, "train") split_idx = self.split_indexes[idx % len(self.split_indexes)]
file_idx = idx output_dirpath = os.path.join(root, self.splits[split_idx])
split_prob = np.random.rand()
for _idx, (split, ratio) in enumerate(zip(self.splits, self.split_ratio)):
if split_prob < ratio:
output_dirpath = os.path.join(root, split)
file_idx = self.split_indexes[_idx]
self.split_indexes[_idx] += 1
break
# save image # save image
image_filename = f"image_{file_idx}.jpg" image_filename = f"image_{idx}.jpg"
image_filepath = os.path.join(output_dirpath, image_filename) image_filepath = os.path.join(output_dirpath, image_filename)
os.makedirs(os.path.dirname(image_filepath), exist_ok=True) os.makedirs(os.path.dirname(image_filepath), exist_ok=True)
image = Image.fromarray(image[..., :3].astype(np.uint8)) image = Image.fromarray(image[..., :3].astype(np.uint8))
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment