Commit c8e188f0 authored by mibaumgartner

improve toy dataset

parent 131a40e9
@@ -86,7 +86,7 @@ Warning:
2. When running a training inside the container it is necessary to [increase the shared memory](https://stackoverflow.com/questions/30210362/how-to-increase-the-size-of-the-dev-shm-in-docker-container).
I tested the following configuration on my local workstation:
```bash
-docker run --gpus all -v ${det_data}:/opt/data -v ${det_models}:/opt/models -it nndetection:0.1 --shm-size=24gb /bin/bash
+docker run --gpus all -v ${det_data}:/opt/data -v ${det_models}:/opt/models -it --shm-size=24gb nndetection:0.1 /bin/bash
```
</details>
@@ -118,8 +118,17 @@ Some of the labels were corrected in datasets which we converted and can be down
The `Reproducing Experiments` section has an overview of multiple guides which explain the preparation of the datasets.
## Toy Dataset
-Running `nndet_example` will automatically generate an example dataset with 3D squares and squares with holes which can be used to test the installation or experiment with prototype code.
-The problem is very easy and the final results should be near perfect.
+Running `nndet_example` will automatically generate an example dataset with 3D squares and squares with holes which can be used to test the installation or experiment with prototype code (it is still necessary to run the other nndet commands to process/train/predict the dataset).
+```bash
+# create data to test installation/environment (10 train 10 test)
+nndet_example
+# create full dataset for prototyping (1000 train 1000 test)
+nndet_example --full [--num_processes]
+```
+The full problem is very easy and the final results should be near perfect.
+After running the generation script follow the `Planning`, `Training` and `Inference` instructions below to construct the whole nnDetection pipeline.
## Reproducing Experiments
......
@@ -16,7 +16,10 @@ limitations under the License.
 import os
 import random
+import argparse
 from pathlib import Path
+from multiprocessing import Pool
+from itertools import repeat
 import numpy as np
 import SimpleITK as sitk
@@ -41,11 +44,12 @@ dim = 3
 image_size = [256, 256, 256]
 object_size = [16, 32]
 object_width = 4
-num_images_tr = 10
-num_images_ts = 10
 def generate_image(image_dir, label_dir, idx):
+    random.seed(idx)
+    np.random.seed(idx)
+    logger.info(f"Generating case_{idx}")
     selected_size = np.random.randint(object_size[0], object_size[1])
     selected_class = np.random.randint(0, 2)
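Seeding both generators with `idx` makes each case depend only on its index, so the result no longer depends on generation order or on which worker process handles the case. A minimal sketch of that property follows. It is not the repository code: it uses the standard-library `random.Random` instead of NumPy's global state to stay dependency-free, and `draw_case` is a hypothetical stand-in for `generate_image` that only returns the sampled parameters.

```python
import random


def draw_case(idx, object_size=(16, 32)):
    """Hypothetical stand-in for generate_image: sample parameters for one case."""
    rng = random.Random(idx)  # per-case generator seeded with the case index
    # randrange excludes the upper bound, like np.random.randint
    selected_size = rng.randrange(object_size[0], object_size[1])
    selected_class = rng.randrange(0, 2)
    return selected_size, selected_class


# The same index yields the same case regardless of the order of the calls.
in_order = [draw_case(i) for i in range(5)]
shuffled = [draw_case(i) for i in (4, 2, 0, 3, 1)]
assert in_order == [shuffled[2], shuffled[4], shuffled[1], shuffled[3], shuffled[0]]
```

Because the seed is the index, reusing indices across splits would reproduce the same cases, which is presumably why the test split in this commit draws its indices from `range(num_images_tr, num_images_tr + num_images_ts)`.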
@@ -100,8 +104,26 @@ def main():
     Generate an example dataset for nnDetection to test the installation or
     experiment with ideas.
     """
-    random.seed(0)
-    np.random.seed(0)
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        '--full',
+        help="Increase size of dataset. "
+             "Default sizes train/test 10/10 and full 1000/1000.",
+        action='store_true',
+    )
+    parser.add_argument(
+        '--num_processes',
+        help="Use multiprocessing to create dataset.",
+        type=int,
+        default=0,
+    )
+    args = parser.parse_args()
+    full = args.full
+    num_processes = args.num_processes
+    num_images_tr = 1000 if full else 10
+    num_images_ts = 1000 if full else 10
     meta = {
         "task": f"Task000D{dim}_Example",
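The two new options can be checked without touching the file system by passing an explicit argument list to `parse_args`. The sketch below copies the parser definition from this hunk; `build_parser` and `dataset_sizes` are hypothetical helper names introduced for illustration and do not exist in the repository.

```python
import argparse


def build_parser():
    """Parser with the same two options as the hunk above."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--full',
        help="Increase size of dataset. "
             "Default sizes train/test 10/10 and full 1000/1000.",
        action='store_true',
    )
    parser.add_argument(
        '--num_processes',
        help="Use multiprocessing to create dataset.",
        type=int,
        default=0,
    )
    return parser


def dataset_sizes(argv):
    """Hypothetical helper: return (num_train, num_test, num_processes)."""
    args = build_parser().parse_args(argv)
    num_images = 1000 if args.full else 10
    return num_images, num_images, args.num_processes


print(dataset_sizes([]))                                  # (10, 10, 0)
print(dataset_sizes(['--full', '--num_processes', '4']))  # (1000, 1000, 4)
```

Since `--num_processes` is declared with `type=int` and no `nargs`, it expects a value (for example `nndet_example --full --num_processes 4`); argparse rejects a bare `--num_processes`.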
@@ -128,11 +150,40 @@ def main():
     labels_ts_dir = raw_splitted_dir / "labelsTs"
     labels_ts_dir.mkdir(parents=True, exist_ok=True)
-    for idx in range(num_images_tr):
-        generate_image(images_tr_dir, labels_tr_dir, idx)
-    for idx in range(num_images_ts):
-        generate_image(images_ts_dir, labels_ts_dir, idx)
+    if num_processes == 0:
+        for idx in range(num_images_tr):
+            generate_image(
+                images_tr_dir,
+                labels_tr_dir,
+                idx,
+            )
+        for idx in range(num_images_tr, num_images_tr + num_images_ts):
+            generate_image(
+                images_ts_dir,
+                labels_ts_dir,
+                idx,
+            )
+    else:
+        logger.info("Using multiprocessing to create example dataset.")
+        with Pool(processes=num_processes) as p:
+            p.starmap(
+                generate_image,
+                zip(
+                    repeat(images_tr_dir),
+                    repeat(labels_tr_dir),
+                    range(num_images_tr),
+                )
+            )
+        with Pool(processes=num_processes) as p:
+            p.starmap(
+                generate_image,
+                zip(
+                    repeat(images_ts_dir),
+                    repeat(labels_ts_dir),
+                    range(num_images_tr, num_images_tr + num_images_ts),
+                )
+            )
 if __name__ == '__main__':
......
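`Pool.starmap` unpacks each tuple from its iterable into positional arguments, and `zip` stops at its shortest input, so pairing the infinite `repeat(...)` iterators with a finite `range` yields exactly one job per index. A small sketch of the job lists this pattern produces is shown below; `build_jobs`, `describe`, `run` and the directory strings are hypothetical and chosen only for illustration.

```python
from itertools import repeat
from multiprocessing import Pool


def build_jobs(image_dir, label_dir, indices):
    """Return the argument tuples that starmap would unpack, one per index."""
    return list(zip(repeat(image_dir), repeat(label_dir), indices))


def describe(image_dir, label_dir, idx):
    """Toy worker standing in for generate_image."""
    return f"{image_dir}/case_{idx} {label_dir}/case_{idx}"


def run(jobs, num_processes=2):
    """Dispatch the jobs to a process pool, mirroring the pattern in the diff."""
    with Pool(processes=num_processes) as p:
        return p.starmap(describe, jobs)


num_images_tr, num_images_ts = 3, 2
train_jobs = build_jobs("imagesTr", "labelsTr", range(num_images_tr))
test_jobs = build_jobs(
    "imagesTs", "labelsTs", range(num_images_tr, num_images_tr + num_images_ts)
)
print(train_jobs)  # [('imagesTr', 'labelsTr', 0), ('imagesTr', 'labelsTr', 1), ('imagesTr', 'labelsTr', 2)]
print(test_jobs)   # [('imagesTs', 'labelsTs', 3), ('imagesTs', 'labelsTs', 4)]
```

When calling `run(...)` from a script, keep the call under an `if __name__ == '__main__':` guard, as the original file does, so that worker processes can import the module safely on platforms that spawn rather than fork.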