# dataloader.py -- last committed by Hongzhi (Steve), Chen.
"""DataLoader utils."""
import dgl
3
from gluoncv.data.batchify import Pad
4
5
from mxnet import nd

6
7
8
9
10

def dgl_mp_batchify_fn(data):
    """Collate a list of samples into one batch.

    Tuple samples are batched field by field. The samples are transposed
    with ``zip`` and each field is collated by a recursive call, so the
    result is a list with one entry per field.

    Otherwise, the batch type is chosen by the first non-``None`` sample that
    is a ``dgl.DGLGraph`` or an ``mxnet.nd.NDArray``:

    * ``DGLGraph``: the ``DGLGraph`` samples are returned as a plain list.
      Every other entry, e.g. ``None``, is dropped.
    * ``NDArray``: the non-``None`` samples are passed to gluoncv's ``Pad``
      batchify function with ``axis=(1, 2)``.

    Parameters
    ----------
    data : list
        The samples of one batch.

    Returns
    -------
    list or NDArray or None
        The collated batch. ``None`` when ``data`` is empty, when every
        sample is ``None``, or when no sample is a ``DGLGraph``/``NDArray``.
    """
    # Tuple samples: transpose into per-field groups and collate each group.
    if data and isinstance(data[0], tuple):
        return [dgl_mp_batchify_fn(field) for field in zip(*data)]

    for sample in data:
        if sample is None:
            # None entries never decide the batch type; keep looking.
            continue
        if isinstance(sample, dgl.DGLGraph):
            # Graph samples are returned as a list rather than stacked.
            return [d for d in data if isinstance(d, dgl.DGLGraph)]
        if isinstance(sample, nd.NDArray):
            # gluoncv Pad pads along axes 1 and 2 so the arrays can be batched.
            pad = Pad(axis=(1, 2), num_shards=1, ret_length=False)
            return pad([d for d in data if d is not None])
    # Empty batch, all-None batch, or no supported sample type found.
    return None