generate_tusimple_dataset.py 4.49 KB
Newer Older
Sugon_ldc's avatar
Sugon_ldc committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
#   Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

import cv2
import json
import argparse
import numpy as np

# Annotation files (JSON Lines, one record per image) read from each split directory.
TRAIN_DATA_SET = [
    'label_data_0313.json',
    'label_data_0601.json',
    'label_data_0531.json',
]
TEST_DATA_SET = [
    'test_label.json',
]

# Size (rows, cols) of every label mask produced by generate_seg_label_image.
Image_Height = 720
Image_Width = 1280
# Nominal lane width in pixels; lanes are drawn with half of it as line thickness.
LANE_SEG_WIDTH = 30
# Number of mask values: 0 for background plus CLASS_NUMS - 1 (= 6) lane slots.
CLASS_NUMS = 7


def _assign_lane_slots(slopes, num_slots):
    """Map slope-sorted lanes onto ``num_slots`` fixed class slots.

    ``slopes`` must be sorted ascending. Lanes with slope <= 90 fill the lower
    half of the slots from the middle outwards (slot ``half - 1`` gets the
    lane nearest the 90-degree boundary, ``half - 2`` the next one, ...);
    lanes with slope > 90 fill the upper half the same way (slot ``half``,
    ``half + 1``, ...). With an odd ``num_slots`` the last slot stays empty.

    Returns:
        A list of length ``num_slots`` whose entries are indices into
        ``slopes`` or ``None`` for a slot with no lane.
    """
    half = num_slots // 2
    # Sorted ascending => every slope <= 90 precedes every slope > 90, so the
    # count of "<= 90" slopes is the index of the first "> 90" lane.
    split = sum(1 for s in slopes if s <= 90)
    slots = [None] * num_slots
    for k in range(half):
        left, right = split - 1 - k, split + k
        if left >= 0:
            slots[half - 1 - k] = left
        if right < len(slopes):
            slots[half + k] = right
    return slots


def generate_seg_label_image(label):
    """Rasterise one TuSimple annotation record into a single-channel lane mask.

    Adapted from
    https://github.com/ZJULearning/resa/blob/main/tools/generate_seg_tusimple.py

    Args:
        label: decoded annotation record. Reads ``label['lanes']`` (one list of
            x coordinates per lane; points with x < 0 are dropped) and
            ``label['h_samples']`` (the y coordinate paired with each x).

    Returns:
        ``np.uint8`` array of shape ``(Image_Height, Image_Width)``. Background
        is 0 and the lane in slot ``k`` (0-based, ``CLASS_NUMS - 1`` slots) is
        painted with value ``k + 1``.
    """
    lanes, slopes = [], []
    for xs in label['lanes']:
        points = [(x, y) for x, y in zip(xs, label['h_samples']) if x >= 0]
        if len(points) > 1:
            lanes.append(points)
            # Orientation (degrees) of the first->last point segment, with the
            # x difference negated; used only to order lanes and pick slots.
            slopes.append(
                np.arctan2(points[-1][1] - points[0][1],
                           points[0][0] - points[-1][0]) / np.pi * 180)

    # Sort lanes by slope once; np.argsort is kept so tie order is unchanged.
    order = np.argsort(slopes)
    lanes = [lanes[i] for i in order]
    slopes = [slopes[i] for i in order]

    slots = _assign_lane_slots(slopes, CLASS_NUMS - 1)

    seg_img = np.zeros([Image_Height, Image_Width], np.uint8)
    for value, lane_idx in enumerate(slots, start=1):
        if lane_idx is None:
            continue
        coords = lanes[lane_idx]
        # Lanes with fewer than 4 visible points are not drawn.
        if len(coords) < 4:
            continue
        for p0, p1 in zip(coords, coords[1:]):
            cv2.line(seg_img, p0, p1, value, LANE_SEG_WIDTH // 2)

    return seg_img


def generate_labels(args, src_dir, label_dir, image_set, mode):
    """Render lane masks for one dataset split and write its list file.

    For every non-blank line of each annotation file in ``image_set`` (read
    from ``<args.root>/<src_dir>/``), a mask is built with
    ``generate_seg_label_image`` and saved as PNG to
    ``<args.root>/<src_dir>/<label_dir>/<dir1>/<dir2>/<name>.png``, where
    ``raw_file`` is expected to look like ``<top>/<dir1>/<dir2>/<name>.<ext>``.
    For each sample one line ``"/<src_dir>/<raw_file> /<src_dir>/<label_dir>/<dir1>/<dir2>/<name>.png"``
    is written to ``<args.root>/<src_dir>/<mode>_list.txt``.

    Args:
        args: parsed CLI namespace; only ``args.root`` is read.
        src_dir: split directory under the root (e.g. ``"train_set"``).
        label_dir: sub-directory of ``src_dir`` that receives the masks.
        image_set: annotation file names (JSON Lines) inside ``src_dir``.
        mode: prefix of the generated ``<mode>_list.txt`` file.

    Raises:
        ValueError: if a ``raw_file`` entry has fewer than four path parts.
    """
    split_root = os.path.join(args.root, src_dir)
    os.makedirs(os.path.join(split_root, label_dir), exist_ok=True)
    list_path = os.path.join(split_root, "{}_list.txt".format(mode))

    # `with` guarantees the list file is flushed and closed.
    with open(list_path, "w") as label_file:
        for json_name in image_set:
            with open(os.path.join(split_root, json_name)) as jsonfile:
                for jsonline in jsonfile:
                    if not jsonline.strip():
                        # Tolerate blank/trailing lines in the JSON-Lines file.
                        continue
                    label = json.loads(jsonline)
                    seg_img = generate_seg_label_image(label)

                    raw_file = label['raw_file']
                    parts = raw_file.split("/")
                    if len(parts) < 4:
                        raise ValueError(
                            "unexpected raw_file layout: {!r}".format(raw_file))
                    sub_dirs, img_name = parts[1:3], parts[3]
                    # Swap any extension (not only 3-letter ones) for .png.
                    seg_name = os.path.splitext(img_name)[0] + ".png"

                    seg_dir = os.path.join(split_root, label_dir, *sub_dirs)
                    os.makedirs(seg_dir, exist_ok=True)
                    cv2.imwrite(os.path.join(seg_dir, seg_name), seg_img)

                    # List entries reuse the same path parts as the file on
                    # disk, so they stay consistent even if src_dir has "/".
                    img_entry = "/".join([src_dir, raw_file])
                    seg_entry = "/".join([src_dir, label_dir, *sub_dirs, seg_name])
                    if not seg_entry.startswith('/'):
                        seg_entry = '/' + seg_entry
                    if not img_entry.startswith('/'):
                        img_entry = '/' + img_entry

                    label_file.write("{} {}\n".format(img_entry, seg_entry))


def process_tusimple_dataset(args):
    """Build masks and list files for the train split, then the test split.

    Args:
        args: parsed CLI namespace forwarded unchanged to ``generate_labels``.
    """
    # (progress message, split directory, annotation files, list-file mode)
    jobs = (
        ("generating train set...", "train_set", TRAIN_DATA_SET, "train"),
        ("generating test set...", "test_set", TEST_DATA_SET, "test"),
    )
    for message, split_dir, json_files, split_mode in jobs:
        print(message)
        generate_labels(args, split_dir, "labels", json_files, mode=split_mode)
    print("generate finish!")


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='Generate TuSimple lane segmentation masks and list files.')
    # --root is mandatory: every output path is joined onto it, so a None
    # default would only fail later inside os.path.join with a TypeError.
    parser.add_argument(
        '--root',
        type=str,
        required=True,
        help='The origin path of unzipped tusimple dataset')
    args = parser.parse_args()

    process_tusimple_dataset(args)