parsing.py 3.36 KB
Newer Older
Ponku's avatar
Ponku committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import argparse
from functools import partial

import torch

from presets import StereoMatchingEvalPreset, StereoMatchingTrainPreset
from torchvision.datasets import (
    CarlaStereo,
    CREStereo,
    ETH3DStereo,
    FallingThingsStereo,
    InStereo2k,
    Kitti2012Stereo,
    Kitti2015Stereo,
    Middlebury2014Stereo,
    SceneFlowStereo,
    SintelStereo,
)

# Registry of the dataset names accepted by `make_dataset`.
#
# Each value is a `functools.partial` that pre-binds the dataset-specific
# options (variant, split, calibration, ...). `make_dataset` completes the
# call with `root=` and `transforms=`. Because the keyword arguments are only
# checked when the partial is finally called, a misspelled keyword here does
# not fail at import time -- it fails when the dataset is first built.
VALID_DATASETS = {
    "crestereo": partial(CREStereo),
    "carla-highres": partial(CarlaStereo),
    "instereo2k": partial(InStereo2k),
    "sintel": partial(SintelStereo),
    "sceneflow-monkaa": partial(SceneFlowStereo, variant="Monkaa", pass_name="both"),
    "sceneflow-flyingthings": partial(SceneFlowStereo, variant="FlyingThings3D", pass_name="both"),
    "sceneflow-driving": partial(SceneFlowStereo, variant="Driving", pass_name="both"),
    "fallingthings": partial(FallingThingsStereo, variant="both"),
    "eth3d-train": partial(ETH3DStereo, split="train"),
    "eth3d-test": partial(ETH3DStereo, split="test"),
    "kitti2015-train": partial(Kitti2015Stereo, split="train"),
    "kitti2015-test": partial(Kitti2015Stereo, split="test"),
    "kitti2012-train": partial(Kitti2012Stereo, split="train"),
    # The "-test" key must select the test split (it previously bound "train").
    "kitti2012-test": partial(Kitti2012Stereo, split="test"),
    "middlebury2014-other": partial(
        Middlebury2014Stereo, split="additional", use_ambient_views=True, calibration="both"
    ),
    "middlebury2014-train": partial(Middlebury2014Stereo, split="train", calibration="perfect"),
    # The test split is bound with `calibration=None`.
    "middlebury2014-test": partial(Middlebury2014Stereo, split="test", calibration=None),
    "middlebury2014-train-ambient": partial(
        Middlebury2014Stereo, split="train", use_ambient_views=True, calibration="perfect"
    ),
}


def make_train_transform(args: argparse.Namespace) -> torch.nn.Module:
    """Build the training transform preset from an arguments namespace.

    Every preset option is read from an attribute of ``args``. Note that
    ``args.interpolation_strategy`` feeds both the scaling and the
    spatial-shift interpolation settings.

    Args:
        args: Namespace exposing the attributes read below.

    Returns:
        A ``StereoMatchingTrainPreset`` configured from ``args``.
    """
    preset_kwargs = {
        # resizing / cropping / rescaling
        "resize_size": args.resize_size,
        "crop_size": args.crop_size,
        "rescale_prob": args.rescale_prob,
        "scaling_type": args.scaling_type,
        "scale_range": args.scale_range,
        "scale_interpolation_type": args.interpolation_strategy,
        # colour space and normalisation
        "use_grayscale": args.use_grayscale,
        "mean": args.norm_mean,
        "std": args.norm_std,
        # flipping, device placement and disparity limit
        "horizontal_flip_prob": args.flip_prob,
        "gpu_transforms": args.gpu_transforms,
        "max_disparity": args.max_disparity,
        # spatial shift
        "spatial_shift_prob": args.spatial_shift_prob,
        "spatial_shift_max_angle": args.spatial_shift_max_angle,
        "spatial_shift_max_displacement": args.spatial_shift_max_displacement,
        "spatial_shift_interpolation_type": args.interpolation_strategy,
        # photometric jitter
        "gamma_range": args.gamma_range,
        "brightness": args.brightness_range,
        "contrast": args.contrast_range,
        "saturation": args.saturation_range,
        "hue": args.hue_range,
        "asymmetric_jitter_prob": args.asymmetric_jitter_prob,
    }
    return StereoMatchingTrainPreset(**preset_kwargs)


def make_eval_transform(args: argparse.Namespace) -> torch.nn.Module:
    """Build the evaluation transform preset from an arguments namespace.

    The resize target is ``args.eval_size`` when it is set; when it is
    ``None`` the training ``args.crop_size`` is used instead.

    Args:
        args: Namespace exposing ``eval_size``, ``crop_size``, ``norm_mean``,
            ``norm_std``, ``use_grayscale`` and ``interpolation_strategy``.

    Returns:
        A ``StereoMatchingEvalPreset`` configured from ``args``.
    """
    # Fall back to the crop size when no explicit evaluation size is given.
    target_size = args.crop_size if args.eval_size is None else args.eval_size

    eval_kwargs = {
        "mean": args.norm_mean,
        "std": args.norm_std,
        "use_grayscale": args.use_grayscale,
        "resize_size": target_size,
        "interpolation_type": args.interpolation_strategy,
    }
    return StereoMatchingEvalPreset(**eval_kwargs)


def make_dataset(dataset_name: str, dataset_root: str, transforms: torch.nn.Module) -> torch.utils.data.Dataset:
    """Instantiate the dataset registered under ``dataset_name``.

    Args:
        dataset_name: Key into ``VALID_DATASETS``.
        dataset_root: Passed to the dataset constructor as ``root``.
        transforms: Passed to the dataset constructor as ``transforms``.

    Returns:
        The object produced by calling the registered dataset factory.

    Raises:
        KeyError: If ``dataset_name`` is not a key of ``VALID_DATASETS``.
    """
    dataset_factory = VALID_DATASETS[dataset_name]
    return dataset_factory(root=dataset_root, transforms=transforms)