Unverified commit fa5e69b1, authored by Geewook Kim, committed by GitHub
Browse files

Merge pull request #95 from moonbings/update_random_split

Update random split logic for reproducibility
parents e83c78be 5762c6d3
......@@ -29,13 +29,13 @@ synthtiger -o ./outputs/SynthDoG_en -c 50 -w 4 -v template.py SynthDoG config_en
.
'quality': [50, 95],
'short_size': [720, 1024]}
Generated 1 data
Generated 2 data
Generated 3 data
Generated 1 data (task 3)
Generated 2 data (task 0)
Generated 3 data (task 1)
.
.
Generated 49 data
Generated 50 data
Generated 49 data (task 48)
Generated 50 data (task 49)
46.32 seconds elapsed
```
......@@ -44,6 +44,7 @@ Some important arguments:
- `-o` : directory path to save data.
- `-c` : number of data to generate.
- `-w` : number of workers.
- `-s` : random seed.
- `-v` : print error messages.
To generate ECJK samples:
......
......@@ -38,10 +38,10 @@ class SynthDoG(templates.Template):
**config.get("effect", {}),
)
# config for splits (output_filename, split_ratio etc)
# config for splits
self.splits = ["train", "validation", "test"]
self.split_indexes = [0, 0, 0]
self.split_ratio = [sum(split_ratio[: i + 1]) for i in range(0, len(split_ratio))]
self.split_ratio = split_ratio
self.split_indexes = np.random.choice(3, size=10000, p=split_ratio)
def generate(self):
landscape = np.random.rand() < self.landscape
......@@ -88,19 +88,11 @@ class SynthDoG(templates.Template):
roi = data["roi"]
# split
output_dirpath = os.path.join(root, "train")
file_idx = idx
split_prob = np.random.rand()
for _idx, (split, ratio) in enumerate(zip(self.splits, self.split_ratio)):
if split_prob < ratio:
output_dirpath = os.path.join(root, split)
file_idx = self.split_indexes[_idx]
self.split_indexes[_idx] += 1
break
split_idx = self.split_indexes[idx % len(self.split_indexes)]
output_dirpath = os.path.join(root, self.splits[split_idx])
# save image
image_filename = f"image_{file_idx}.jpg"
image_filename = f"image_{idx}.jpg"
image_filepath = os.path.join(output_dirpath, image_filename)
os.makedirs(os.path.dirname(image_filepath), exist_ok=True)
image = Image.fromarray(image[..., :3].astype(np.uint8))
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment