file_client.py 3.54 KB
Newer Older
mashun1's avatar
mashun1 committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
# Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/fileio/file_client.py  # noqa: E501
import numpy as np
import os.path as osp
from abc import ABCMeta, abstractmethod


class BaseStorageBackend(metaclass=ABCMeta):
    """Interface that every storage backend has to implement.

    Concrete backends must override both ``get()`` (return the data addressed
    by ``filepath``; its format is backend-specific) and ``get_text()``
    (return that data as text). The class cannot be instantiated directly.
    """

    @abstractmethod
    def get(self, filepath):
        """Return the data stored at ``filepath``."""

    @abstractmethod
    def get_text(self, filepath):
        """Return the data stored at ``filepath`` as text."""

class Hdf5Backend(BaseStorageBackend):
    """Storage backend that reads frame clips and event voxels from HDF5 files.

    For every entry of ``h5_paths`` the file ``osp.join(path, h5_clip)`` is
    opened read-only with ``h5py`` and registered under the matching entry of
    ``client_keys``. :meth:`get` reads from the files registered under the
    keys ``'LR'`` and ``'HR'``, so ``client_keys`` must contain both of them
    for :meth:`get` to work. Call :meth:`close` to release the file handles.

    Args:
        h5_paths (str | os.PathLike | list | tuple): Directory, or directories,
            holding the HDF5 file(s).
        client_keys (str | list | tuple): Name under which each opened file is
            registered; must have the same length as ``h5_paths``.
        h5_clip (str): File name joined onto every entry of ``h5_paths``.
        **kwargs: Ignored; accepted so extra config keys do not break
            construction.

    Raises:
        ImportError: If ``h5py`` is not installed.
        TypeError: If ``h5_paths`` is neither a path nor a list/tuple of paths.
        ValueError: If ``client_keys`` and ``h5_paths`` differ in length.
    """

    def __init__(self, h5_paths, client_keys='default', h5_clip='default', **kwargs):
        try:
            import h5py
        except ImportError as err:
            raise ImportError('Please install h5py to enable Hdf5Backend.') from err
        import os

        if isinstance(client_keys, str):
            client_keys = [client_keys]
        else:
            client_keys = list(client_keys)

        if isinstance(h5_paths, (list, tuple)):
            self.h5_paths = [str(v) for v in h5_paths]
        elif isinstance(h5_paths, (str, os.PathLike)):
            # A single location: wrap it so it pairs with a single client key.
            self.h5_paths = [str(h5_paths)]
        else:
            raise TypeError('h5_paths should be a path or a list/tuple of paths, '
                            f'but received {type(h5_paths).__name__}.')

        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if len(client_keys) != len(self.h5_paths):
            raise ValueError('client_keys and h5_paths should have the same length, '
                             f'but received {len(client_keys)} and {len(self.h5_paths)}.')

        self._client = {}
        try:
            for client, path in zip(client_keys, self.h5_paths):
                self._client[client] = h5py.File(osp.join(path, h5_clip), 'r')
        except Exception:
            # Do not leak the files that were already opened before the failure.
            self.close()
            raise

    def close(self):
        """Close every HDF5 file opened by this backend.

        The handle registry is emptied afterwards, so calling this again is a
        no-op.
        """
        for handle in self._client.values():
            handle.close()
        self._client = {}

    def get(self, filepath):
        """Read a clip of LR/HR frames plus its forward/backward event voxels.

        Args:
            filepath (Sequence[int]): Frame indices of the clip (a neighbor
                list). Each index is formatted as a zero-padded six-digit key,
                e.g. ``images/000012``.

        Returns:
            tuple[list, list, list]: ``(img_lrs, img_hrs, event_lqs)``.
            ``img_lrs`` and ``img_hrs`` hold one float32 array per index,
            divided by 255. ``event_lqs`` holds the ``len(filepath) - 1``
            forward voxels followed by the ``len(filepath) - 1`` backward
            voxels (float32, not rescaled).

        Raises:
            ValueError: If ``filepath`` is empty.
            KeyError: If no ``'LR'`` or ``'HR'`` client was registered, or a
                requested dataset is missing.
        """
        if len(filepath) == 0:
            raise ValueError('filepath must contain at least one frame index.')

        file_lr = self._client['LR']
        file_hr = self._client['HR']

        img_lrs = [file_lr[f'images/{idx:06d}'][:].astype(np.float32) / 255. for idx in filepath]
        img_hrs = [file_hr[f'images/{idx:06d}'][:].astype(np.float32) / 255. for idx in filepath]

        # Voxels are read for every index except the last one (presumably one
        # voxel grid per pair of consecutive frames -- confirm with the writer).
        voxel_indices = filepath[:-1]
        voxels_f = [file_lr[f'voxels_f/{idx:06d}'][:].astype(np.float32) for idx in voxel_indices]
        voxels_b = [file_lr[f'voxels_b/{idx:06d}'][:].astype(np.float32) for idx in voxel_indices]

        # Forward voxels first, then backward voxels.
        event_lqs = voxels_f + voxels_b

        return img_lrs, img_hrs, event_lqs

    def get_text(self, filepath):
        """Not supported: HDF5 clips are not text."""
        raise NotImplementedError('Hdf5Backend does not support get_text().')


class FileClient:
    """Facade that forwards read requests to a named storage backend.

    The backend class is looked up by name in ``_backends`` and instantiated
    with every extra keyword argument passed to the constructor. There is no
    public API for registering further backends; new ones must be added to
    ``_backends``.

    Note:
        Only ``'hdf5'`` is registered. The default ``backend='disk'`` is kept
        for signature compatibility, but it is not registered. Callers must
        therefore pass ``backend='hdf5'`` explicitly, otherwise ``ValueError``
        is raised.

    Attributes:
        backend (str): Name of the selected backend.
        client (:obj:`BaseStorageBackend`): The backend instance serving
            requests.
    """

    _backends = {
        'hdf5': Hdf5Backend,
    }

    def __init__(self, backend='disk', **kwargs):
        """Instantiate the backend registered under ``backend``.

        Args:
            backend (str): Key into ``_backends``.
            **kwargs: Forwarded unchanged to the backend constructor.

        Raises:
            ValueError: If ``backend`` is not a registered backend name.
        """
        if backend not in self._backends:
            raise ValueError(f'Backend {backend} is not supported. Currently supported ones'
                             f' are {list(self._backends.keys())}')
        self.backend = backend
        self.client = self._backends[backend](**kwargs)

    def get(self, filepath):
        """Return ``self.client.get(filepath)`` (format depends on the backend)."""
        return self.client.get(filepath)

    def get_text(self, filepath):
        """Return ``self.client.get_text(filepath)``."""
        return self.client.get_text(filepath)