"sgl-kernel/include/utils.h" did not exist on "8d323e95e4406d5663725b177571757c1d402e1e"
datasets.py 1.84 KB
Newer Older
qianyj's avatar
qianyj committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from __future__ import absolute_import, division, print_function

import cv2
import os

import numpy as np

from torch.utils import data


class PFLDDatasets(data.Dataset):
    """ Dataset to manage the data loading, augmentation and generation. """

    def __init__(self, file_list, transforms=None, data_root="", img_size=112,
                 num_landmarks=106):
        """
        Parameters
        ----------
        file_list : str
            path to an annotation text file; each non-blank line holds an
            image path followed by whitespace-separated label values
        transforms : function
            function for data augmentation, applied to the resized image
        data_root : str
            the root path of dataset; joined in front of each image path
            when non-empty
        img_size : int
            the size of image height or width (images are resized to a
            square of this side length)
        num_landmarks : int
            number of landmark points per annotation line; the first
            ``2 * num_landmarks`` label values are landmarks and any
            remaining values are angles
        """
        # Per-sample state of the most recently loaded item (kept on the
        # instance for backward compatibility with the original attributes).
        self.line = None
        self.path = None
        self.img = None
        self.img_size = img_size
        self.land = None
        self.angle = None
        self.data_root = data_root
        self.transforms = transforms
        self.num_landmarks = num_landmarks
        with open(file_list, "r") as f:
            # Skip blank lines (e.g. a trailing newline) so that __len__ only
            # counts real samples and __getitem__ never sees an empty line.
            self.lines = [line for line in f if line.strip()]

    @staticmethod
    def _split_labels(fields, num_landmarks):
        """ Split the label tokens of one annotation line into float arrays.

        Parameters
        ----------
        fields : list of str
            tokens of an annotation line, image path first
        num_landmarks : int
            number of landmark points (two values each)

        Returns
        -------
        tuple of np.ndarray
            (landmarks, angles), both float32; angles may be empty

        Raises
        ------
        ValueError
            if the line holds fewer than ``2 * num_landmarks`` label values
        """
        end = 2 * num_landmarks + 1
        if len(fields) < end:
            raise ValueError(
                "expected at least {} landmark values after the image path, "
                "got {}".format(end - 1, len(fields) - 1))
        land = np.asarray(fields[1:end], dtype=np.float32)
        angle = np.asarray(fields[end:], dtype=np.float32)
        return land, angle

    def __getitem__(self, index):
        """ Get the data sample and labels with the index.

        Returns
        -------
        tuple
            (image, landmarks, angles); image is the array read by
            ``cv2.imread`` resized to ``img_size`` x ``img_size`` and then
            passed through ``transforms`` when one is given

        Raises
        ------
        IOError
            if the image file cannot be read
        ValueError
            if the annotation line has too few landmark values
        """
        self.line = self.lines[index].strip().split()
        # resolve the image path (data_root may be "" or None)
        if self.data_root:
            self.path = os.path.join(self.data_root, self.line[0])
        else:
            self.path = self.line[0]
        img = cv2.imread(self.path)
        if img is None:
            # cv2.imread signals failure by returning None instead of raising
            raise IOError("failed to read image: {}".format(self.path))
        self.img = cv2.resize(img, (self.img_size, self.img_size))
        # obtain gt labels
        self.land, self.angle = self._split_labels(self.line,
                                                   self.num_landmarks)

        # augmentation
        if self.transforms:
            self.img = self.transforms(self.img)

        return self.img, self.land, self.angle

    def __len__(self):
        """ Get the size of dataset (number of non-blank annotation lines). """
        return len(self.lines)