test_fileio.py 7.28 KB
Newer Older
1
# Copyright (c) OpenMMLab. All rights reserved.
2
3
import os
import os.path as osp
4
import sys
5
import tempfile
6
from unittest.mock import MagicMock, patch
7
8
9

import pytest

Kai Chen's avatar
Kai Chen committed
10
import mmcv
11
12
13
14
from mmcv.fileio.file_client import HTTPBackend, PetrelBackend

# Replace the optional petrel SDK with mocks so the petrel (s3://) tests below
# can run without it installed. NOTE(review): presumably PetrelBackend imports
# ``petrel_client.client`` lazily when it is constructed (the import above
# happens before these stubs) — confirm in mmcv.fileio.file_client.
sys.modules['petrel_client'] = MagicMock()
sys.modules['petrel_client.client'] = MagicMock()
Kai Chen's avatar
Kai Chen committed
15

16

17
def _test_handler(file_format, test_obj, str_checker, mode='r+'):
    """Round-trip ``test_obj`` through the handler for ``file_format``.

    Exercises ``mmcv.dump``/``mmcv.load`` against a string, a local file
    path, a mocked petrel (``s3://``) path, a file-like object, and a path
    whose extension selects the format automatically.

    Args:
        file_format (str): Format name passed to ``mmcv.dump``/``mmcv.load``
            and used as the file extension in the format-inference check.
        test_obj: Object to serialize; it must compare equal to the object
            loaded back from each dump.
        str_checker (callable): Called with the result of dumping
            ``test_obj`` to a string; expected to assert on its content.
        mode (str): Mode used to open the file-like object. A mode containing
            ``'b'`` also makes the petrel check patch ``PetrelBackend.put``
            instead of ``PetrelBackend.put_text``.
    """
    # dump to a string
    dump_str = mmcv.dump(test_obj, file_format=file_format)
    str_checker(dump_str)

    # Keep every on-disk artifact in a private directory. Concurrent runs then
    # cannot collide on a shared name in the system temp dir, and the files
    # are removed even when an assertion below fails.
    with tempfile.TemporaryDirectory() as tmp_dir:
        # load/dump with filenames from disk
        tmp_filename = osp.join(tmp_dir, 'mmcv_test_dump')
        mmcv.dump(test_obj, tmp_filename, file_format=file_format)
        assert osp.isfile(tmp_filename)
        load_obj = mmcv.load(tmp_filename, file_format=file_format)
        assert load_obj == test_obj

        # load/dump with filename from petrel (the backend write is mocked)
        method = 'put' if 'b' in mode else 'put_text'
        with patch.object(
                PetrelBackend, method, return_value=None) as mock_method:
            filename = 's3://path/of/your/file'
            mmcv.dump(test_obj, filename, file_format=file_format)
        mock_method.assert_called()

        # load/dump with a file-like object
        with tempfile.NamedTemporaryFile(
                mode, dir=tmp_dir, delete=False) as f:
            tmp_filename = f.name
            mmcv.dump(test_obj, f, file_format=file_format)
        assert osp.isfile(tmp_filename)
        with open(tmp_filename, mode) as f:
            load_obj = mmcv.load(f, file_format=file_format)
        assert load_obj == test_obj

        # automatically infer the file format from the given filename
        tmp_filename = osp.join(tmp_dir, 'mmcv_test_dump.' + file_format)
        mmcv.dump(test_obj, tmp_filename)
        assert osp.isfile(tmp_filename)
        load_obj = mmcv.load(tmp_filename)
        assert load_obj == test_obj


# Shared sample object for the json/yaml/pickle handler tests: a list mixing
# a dict, an int and a str.
obj_for_test = [{'a': 'abc', 'b': 1}, 2, 'c']


def test_json():
    """Round-trip the sample object through the json handler.

    The dumped string may list the dict keys in either order.
    """
    accepted = (
        '[{"a": "abc", "b": 1}, 2, "c"]',
        '[{"b": 1, "a": "abc"}, 2, "c"]',
    )

    def json_checker(dump_str):
        assert dump_str in accepted

    _test_handler('json', obj_for_test, json_checker)
68
69
70
71
72
73


def test_yaml():
    """Round-trip the sample object through the yaml handler.

    The dump may use flow or block style for the inner dict, with the keys
    in either order.
    """
    flow_style = [
        '- {a: abc, b: 1}\n- 2\n- c\n',
        '- {b: 1, a: abc}\n- 2\n- c\n',
    ]
    block_style = [
        '- a: abc\n  b: 1\n- 2\n- c\n',
        '- b: 1\n  a: abc\n- 2\n- c\n',
    ]
    accepted = flow_style + block_style

    def yaml_checker(dump_str):
        assert dump_str in accepted

    _test_handler('yaml', obj_for_test, yaml_checker)
79
80
81
82
83
84
85
86


def test_pickle():
    """Round-trip the sample object through the pickle handler.

    Pickle output is bytes, so the file-like checks use a binary mode.
    """
    import pickle

    def pickle_checker(dump_str):
        restored = pickle.loads(dump_str)
        assert restored == obj_for_test

    _test_handler('pickle', obj_for_test, pickle_checker, mode='rb+')
88
89
90
91
92
93
94
95
96
97


def test_exception():
    """Check that ``mmcv.dump`` rejects calls whose format it cannot resolve."""
    test_obj = [{'a': 'abc', 'b': 1}, 2, 'c']

    # Neither a target file nor a file_format is given.
    with pytest.raises(ValueError):
        mmcv.dump(test_obj)

    # The extension names a format with no registered handler. Use one that
    # no test registers: ``test_register_handler`` registers a 'txt' handler
    # and never unregisters it. With 'tmp.txt', this check would only raise
    # if this test happened to run first; otherwise it would write 'tmp.txt'
    # into the working directory.
    with pytest.raises(TypeError):
        mmcv.dump(test_obj, 'tmp.mmcv_unregistered_format')
Kai Chen's avatar
Kai Chen committed
98
99


100
101
def test_register_handler():
    """Register custom text handlers and route load/dump by file extension."""

    @mmcv.register_handler('txt')
    class TxtHandler1(mmcv.BaseFileHandler):
        # Plain-text handler: loads the whole file, dumps ``str(obj)``.

        def load_from_fileobj(self, file):
            return file.read()

        def dump_to_fileobj(self, obj, file):
            file.write(str(obj))

        def dump_to_str(self, obj, **kwargs):
            return str(obj)

    @mmcv.register_handler(['txt1', 'txt2'])
    class TxtHandler2(mmcv.BaseFileHandler):
        # Like TxtHandler1, but writes a leading newline when dumping to a
        # file, so the test can tell which handler produced the output.

        def load_from_fileobj(self, file):
            return file.read()

        def dump_to_fileobj(self, obj, file):
            file.write('\n')
            file.write(str(obj))

        def dump_to_str(self, obj, **kwargs):
            return str(obj)

    # A '.txt' file is read verbatim through the handler registered for 'txt'.
    content = mmcv.load(osp.join(osp.dirname(__file__), 'data/filelist.txt'))
    assert content == '1.jpg\n2.jpg\n3.jpg\n4.jpg\n5.jpg'

    # A '.txt2' file must be written by TxtHandler2, i.e. with the leading
    # newline. A private temporary directory avoids clashing with concurrent
    # runs and is removed even if reading the file back fails.
    with tempfile.TemporaryDirectory() as tmp_dir:
        tmp_filename = osp.join(tmp_dir, 'mmcv_test.txt2')
        mmcv.dump(content, tmp_filename)
        with open(tmp_filename) as f:
            written = f.read()
    assert written == '\n' + content
135
136


Kai Chen's avatar
Kai Chen committed
137
def test_list_from_file():
    """Check ``mmcv.list_from_file`` on local, HTTP and petrel sources."""
    # get list from disk: each case is (extra kwargs, expected list)
    filename = osp.join(osp.dirname(__file__), 'data/filelist.txt')
    local_cases = [
        ({}, ['1.jpg', '2.jpg', '3.jpg', '4.jpg', '5.jpg']),
        ({'prefix': 'a/'},
         ['a/1.jpg', 'a/2.jpg', 'a/3.jpg', 'a/4.jpg', 'a/5.jpg']),
        ({'offset': 2}, ['3.jpg', '4.jpg', '5.jpg']),
        ({'max_num': 2}, ['1.jpg', '2.jpg']),
        # offset + max_num past the end yields only the remaining lines
        ({'offset': 3, 'max_num': 3}, ['4.jpg', '5.jpg']),
    ]
    for extra_kwargs, expected in local_cases:
        assert mmcv.list_from_file(filename, **extra_kwargs) == expected

    # get list from http and petrel: the backend is chosen explicitly, by
    # prefix, and finally inferred from the URI alone
    remote_text = '1.jpg\n2.jpg\n3.jpg'
    remote_expected = ['1.jpg', '2.jpg', '3.jpg']
    remote_cases = [
        (HTTPBackend, 'http://path/of/your/file', 'http', 'http'),
        (PetrelBackend, 's3://path/of/your/file', 'petrel', 's3'),
    ]
    for backend_cls, uri, backend_name, uri_prefix in remote_cases:
        with patch.object(backend_cls, 'get_text', return_value=remote_text):
            for client_args in ({'backend': backend_name},
                                {'prefix': uri_prefix}):
                filelist = mmcv.list_from_file(
                    uri, file_client_args=client_args)
                assert filelist == remote_expected
            assert mmcv.list_from_file(uri) == remote_expected

Kai Chen's avatar
Kai Chen committed
177
178

def test_dict_from_file():
    """Check ``mmcv.dict_from_file`` on local, HTTP and petrel sources."""
    # Keys stay strings unless key_type converts them; a line with several
    # values ('2 dog cow') maps its key to a list.
    expected = {'1': 'cat', '2': ['dog', 'cow'], '3': 'panda'}

    # get dict from disk
    filename = osp.join(osp.dirname(__file__), 'data/mapping.txt')
    assert mmcv.dict_from_file(filename) == expected
    assert mmcv.dict_from_file(filename, key_type=int) == {
        1: 'cat',
        2: ['dog', 'cow'],
        3: 'panda'
    }

    # get dict from http and petrel: the backend is chosen explicitly, by
    # prefix, and finally inferred from the URI alone
    remote_text = '1 cat\n2 dog cow\n3 panda'
    remote_cases = [
        (HTTPBackend, 'http://path/of/your/file', 'http', 'http'),
        (PetrelBackend, 's3://path/of/your/file', 'petrel', 's3'),
    ]
    for backend_cls, uri, backend_name, uri_prefix in remote_cases:
        with patch.object(backend_cls, 'get_text', return_value=remote_text):
            for client_args in ({'backend': backend_name},
                                {'prefix': uri_prefix}):
                mapping = mmcv.dict_from_file(
                    uri, file_client_args=client_args)
                assert mapping == expected
            assert mmcv.dict_from_file(uri) == expected