import time
import dgl
import torch

from .. import utils


@utils.benchmark('time')
@utils.parametrize('batch_size', [4, 32, 256])
@utils.parametrize('feat_size', [32, 128, 256])
@utils.parametrize('readout_op', ['sum', 'max', 'min', 'mean'])
@utils.parametrize('type', ['edge', 'node'])
def track_time(batch_size, feat_size, readout_op, type):
    """Benchmark the average time of one graph readout on a batched graph.

    The first ``batch_size`` graphs of ``dgl.data.QM7bDataset`` are batched
    into a single graph, a random feature ``'h'`` with ``feat_size`` columns
    is attached to either its nodes or its edges, and the matching DGL
    readout function is timed.

    Parameters
    ----------
    batch_size : int
        Number of dataset graphs to batch together.
    feat_size : int
        Second dimension of the random feature tensor.
    readout_op : str
        Reduction passed to the readout function as ``op``.
    type : str
        ``'node'`` to time ``dgl.readout_nodes``, ``'edge'`` to time
        ``dgl.readout_edges``.

    Returns
    -------
    float
        Mean wall-clock seconds per readout call over the timed iterations.

    Raises
    ------
    ValueError
        If ``type`` is neither ``'node'`` nor ``'edge'``.
    """
    n_warmup = 3
    n_iters = 10

    device = utils.get_bench_device()
    ds = dgl.data.QM7bDataset()
    # prepare graph
    graphs = ds[0:batch_size][0]
    g = dgl.batch(graphs).to(device)

    if type == 'node':
        g.ndata['h'] = torch.randn((g.num_nodes(), feat_size), device=device)
        readout = dgl.readout_nodes
    elif type == 'edge':
        g.edata['h'] = torch.randn((g.num_edges(), feat_size), device=device)
        readout = dgl.readout_edges
    else:
        raise ValueError("Unknown type")

    # CUDA kernels are launched asynchronously, so the clock must only be
    # read after the device has finished all queued work.
    on_cuda = torch.device(device).type == 'cuda'

    def _sync():
        if on_cuda:
            torch.cuda.synchronize()

    # Warm-up: keep one-off costs (lazy initialisation, allocator growth)
    # out of the measured interval.
    for _ in range(n_warmup):
        readout(g, 'h', op=readout_op)
    _sync()

    t0 = time.perf_counter()
    for _ in range(n_iters):
        # ``op`` must be passed by keyword: the third positional parameter of
        # the readout functions is the weight feature name, not the reducer.
        readout(g, 'h', op=readout_op)
    _sync()
    t1 = time.perf_counter()

    return (t1 - t0) / n_iters
