io.md 6.54 KB
Newer Older
Zaida Zhou's avatar
Zaida Zhou committed
1
2
3
4
## 文件输入输出

文件输入输出模块提供了两个通用的 API 接口用于读取和保存不同格式的文件。

5
6
7
8
```{note}
在 v1.3.16 及之后的版本中,IO 模块支持从不同后端读取数据并支持将数据至不同后端。更多细节请访问 PR [#1330](https://github.com/open-mmlab/mmcv/pull/1330)。
```

Zaida Zhou's avatar
Zaida Zhou committed
9
10
11
12
### 读取和保存数据

`mmcv` 提供了一个通用的 api 用于读取和保存数据,目前支持的格式有 json、yaml 和 pickle。

13
14
#### 从硬盘读取数据或者将数据保存至硬盘

Zaida Zhou's avatar
Zaida Zhou committed
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
```python
import mmcv

# 从文件中读取数据
data = mmcv.load('test.json')
data = mmcv.load('test.yaml')
data = mmcv.load('test.pkl')
# 从文件对象中读取数据
with open('test.json', 'r') as f:
    data = mmcv.load(f, file_format='json')

# 将数据序列化为字符串
json_str = mmcv.dump(data, file_format='json')

# 将数据保存至文件 (根据文件名后缀反推文件类型)
mmcv.dump(data, 'out.pkl')

# 将数据保存至文件对象
with open('test.yaml', 'w') as f:
    data = mmcv.dump(data, f, file_format='yaml')
```

37
38
39
40
41
42
43
44
45
46
47
48
49
50
#### 从其他后端加载或者保存至其他后端

```python
import mmcv

# 从 s3 文件读取数据
data = mmcv.load('s3://bucket-name/test.json')
data = mmcv.load('s3://bucket-name/test.yaml')
data = mmcv.load('s3://bucket-name/test.pkl')

# 将数据保存至 s3 文件 (根据文件名后缀反推文件类型)
mmcv.dump(data, 's3://bucket-name/out.pkl')
```

Zaida Zhou's avatar
Zaida Zhou committed
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
我们提供了易于拓展的方式以支持更多的文件格式。我们只需要创建一个继承自 `BaseFileHandler`
文件句柄类并将其注册到 `mmcv` 中即可。句柄类至少需要重写三个方法。

```python
import mmcv

# 支持为文件句柄类注册多个文件格式
# @mmcv.register_handler(['txt', 'log'])
@mmcv.register_handler('txt')
class TxtHandler1(mmcv.BaseFileHandler):

    def load_from_fileobj(self, file):
        return file.read()

    def dump_to_fileobj(self, obj, file):
        file.write(str(obj))

    def dump_to_str(self, obj, **kwargs):
        return str(obj)
```

72
`PickleHandler` 为例
Zaida Zhou's avatar
Zaida Zhou committed
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109

```python
import pickle

class PickleHandler(mmcv.BaseFileHandler):

    def load_from_fileobj(self, file, **kwargs):
        return pickle.load(file, **kwargs)

    def load_from_path(self, filepath, **kwargs):
        return super(PickleHandler, self).load_from_path(
            filepath, mode='rb', **kwargs)

    def dump_to_str(self, obj, **kwargs):
        kwargs.setdefault('protocol', 2)
        return pickle.dumps(obj, **kwargs)

    def dump_to_fileobj(self, obj, file, **kwargs):
        kwargs.setdefault('protocol', 2)
        pickle.dump(obj, file, **kwargs)

    def dump_to_path(self, obj, filepath, **kwargs):
        super(PickleHandler, self).dump_to_path(
            obj, filepath, mode='wb', **kwargs)
```

### 读取文件并返回列表或字典

例如, `a.txt` 是文本文件,一共有5行内容。

```
a
b
c
d
e
```
110
#### 从硬盘读取
Zaida Zhou's avatar
Zaida Zhou committed
111

112
使用 `list_from_file` 读取 `a.txt`
Zaida Zhou's avatar
Zaida Zhou committed
113
114
115
116
117
118
119
120
121
122
123
124

```python
>>> mmcv.list_from_file('a.txt')
['a', 'b', 'c', 'd', 'e']
>>> mmcv.list_from_file('a.txt', offset=2)
['c', 'd', 'e']
>>> mmcv.list_from_file('a.txt', max_num=2)
['a', 'b']
>>> mmcv.list_from_file('a.txt', prefix='/mnt/')
['/mnt/a', '/mnt/b', '/mnt/c', '/mnt/d', '/mnt/e']
```

125
同样, `b.txt` 也是文本文件,一共有3行内容
Zaida Zhou's avatar
Zaida Zhou committed
126
127
128
129
130
131
132

```
1 cat
2 dog cow
3 panda
```

133
使用 `dict_from_file` 读取 `b.txt`
Zaida Zhou's avatar
Zaida Zhou committed
134
135
136
137
138
139
140

```python
>>> mmcv.dict_from_file('b.txt')
{'1': 'cat', '2': ['dog', 'cow'], '3': 'panda'}
>>> mmcv.dict_from_file('b.txt', key_type=int)
{1: 'cat', 2: ['dog', 'cow'], 3: 'panda'}
```
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240

#### 从其他后端读取

使用 `list_from_file` 读取 `s3://bucket-name/a.txt`

```python
>>> mmcv.list_from_file('s3://bucket-name/a.txt')
['a', 'b', 'c', 'd', 'e']
>>> mmcv.list_from_file('s3://bucket-name/a.txt', offset=2)
['c', 'd', 'e']
>>> mmcv.list_from_file('s3://bucket-name/a.txt', max_num=2)
['a', 'b']
>>> mmcv.list_from_file('s3://bucket-name/a.txt', prefix='/mnt/')
['/mnt/a', '/mnt/b', '/mnt/c', '/mnt/d', '/mnt/e']
```

使用 `dict_from_file` 读取 `b.txt`

```python
>>> mmcv.dict_from_file('s3://bucket-name/b.txt')
{'1': 'cat', '2': ['dog', 'cow'], '3': 'panda'}
>>> mmcv.dict_from_file('s3://bucket-name/b.txt', key_type=int)
{1: 'cat', 2: ['dog', 'cow'], 3: 'panda'}
```

### 读取和保存权重文件

#### 从硬盘读取权重文件或者将权重文件保存至硬盘

我们可以通过下面的方式从磁盘读取权重文件或者将权重文件保存至磁盘

```python
import torch

filepath1 = '/path/of/your/checkpoint1.pth'
filepath2 = '/path/of/your/checkpoint2.pth'
# 从 filepath1 读取权重文件
checkpoint = torch.load(filepath1)
# 将权重文件保存至 filepath2
torch.save(checkpoint, filepath2)
```

MMCV 提供了很多后端,`HardDiskBackend` 是其中一个,我们可以通过它来读取或者保存权重文件。

```python
import io
from mmcv.fileio.file_client import HardDiskBackend

disk_backend = HardDiskBackend()
with io.BytesIO(disk_backend.get(filepath1)) as buffer:
    checkpoint = torch.load(buffer)
with io.BytesIO() as buffer:
    torch.save(checkpoint, f)
    disk_backend.put(f.getvalue(), filepath2)
```

如果我们想在接口中实现根据文件路径自动选择对应的后端,我们可以使用 `FileClient`
例如,我们想实现两个方法,分别是读取权重以及保存权重,它们需支持不同类型的文件路径,可以是磁盘路径,也可以是网络路径或者其他路径。

```python
from mmcv.fileio.file_client import FileClient

def load_checkpoint(path):
    file_client = FileClient.infer(uri=path)
    with io.BytesIO(file_client.get(path)) as buffer:
        checkpoint = torch.load(buffer)
    return checkpoint

def save_checkpoint(checkpoint, path):
    with io.BytesIO() as buffer:
        torch.save(checkpoint, buffer)
        file_client.put(buffer.getvalue(), path)

file_client = FileClient.infer_client(uri=filepath1)
checkpoint = load_checkpoint(filepath1)
save_checkpoint(checkpoint, filepath2)
```

#### 从网络远端读取权重文件

```{note}
目前只支持从网络远端读取权重文件,暂不支持将权重文件写入网络远端
```

```python
import io
import torch
from mmcv.fileio.file_client import HTTPBackend, FileClient

filepath = 'http://path/of/your/checkpoint.pth'
checkpoint = torch.utils.model_zoo.load_url(filepath)

http_backend = HTTPBackend()
with io.BytesIO(http_backend.get(filepath)) as buffer:
    checkpoint = torch.load(buffer)

file_client = FileClient.infer_client(uri=filepath)
with io.BytesIO(file_client.get(filepath)) as buffer:
    checkpoint = torch.load(buffer)
```