"""DataLoader utils."""
from gluoncv.data.batchify import Pad
from mxnet import nd

import dgl


def dgl_mp_batchify_fn(data):
    """Collate a list of samples into one batch for multi-process loading.

    Parameters
    ----------
    data : sequence
        The samples to collate. If the first sample is a ``tuple``, the
        samples are treated as records: they are transposed with ``zip`` and
        each field is collated recursively by this same function.

    Returns
    -------
    list, object or None
        * Tuple samples: a list holding one collated result per field.
        * Otherwise, the first non-``None`` sample that is a
          ``dgl.DGLGraph`` or an ``nd.NDArray`` decides the result:

          - ``DGLGraph``: a list of every ``DGLGraph`` in ``data``. Any other
            items are dropped and no padding is applied.
          - ``NDArray``: the non-``None`` samples passed through gluoncv
            ``Pad(axis=(1, 2), num_shards=1, ret_length=False)``. The result
            is presumably a single padded batch array; confirm against the
            gluoncv version in use.
        * ``None`` if no sample is a ``DGLGraph`` or ``NDArray``, for example
          a field whose values are all ``None``.

    Raises
    ------
    IndexError
        If ``data`` is empty, because ``data[0]`` is inspected first.
    """
    if isinstance(data[0], tuple):
        # Transpose records into per-field columns and collate each column.
        return [dgl_mp_batchify_fn(field) for field in zip(*data)]

    for sample in data:
        if sample is None:
            continue
        if isinstance(sample, dgl.DGLGraph):
            # Graphs are returned as a plain list; they are not padded.
            return [item for item in data if isinstance(item, dgl.DGLGraph)]
        if isinstance(sample, nd.NDArray):
            pad = Pad(axis=(1, 2), num_shards=1, ret_length=False)
            return pad([item for item in data if item is not None])

    # No graph or NDArray was found (e.g. an all-None field): nothing to batch.
    return None