# wrapper.py

from torch.utils.data import DataLoader
from functools import wraps

class _WrappedCollate:
    """Picklable collate callable computing ``output_wrapper(*collate_fn(batch))``.

    A function nested inside another function cannot be pickled. DataLoader
    worker processes launched with the ``spawn`` start method must pickle
    ``collate_fn``. An instance of this module-level class pickles whenever
    the two callables it holds do.
    """

    def __init__(self, collate_fn, output_wrapper):
        # Copy __name__/__qualname__/__doc__/__module__ and set __wrapped__
        # from the original collate function, keeping its introspection data.
        wraps(collate_fn)(self)
        # Assign after wraps(): wraps() also merges collate_fn.__dict__ into
        # this instance's __dict__, which must not clobber these attributes.
        self.collate_fn = collate_fn
        self.output_wrapper = output_wrapper

    def __call__(self, input_):
        output = self.collate_fn(input_)
        # The collated batch is unpacked into positional arguments.
        return self.output_wrapper(*output)


def wrap_output(dataloader, output_wrapper):
    """Post-process every batch produced by ``dataloader`` with ``output_wrapper``.

    The loader's current ``collate_fn`` is replaced in place. The replacement
    runs the original collate function and then returns
    ``output_wrapper(*batch)``, so the collated batch is unpacked into
    positional arguments.

    Args:
        dataloader: Object with a ``collate_fn`` attribute, typically a
            ``torch.utils.data.DataLoader``. It is mutated in place.
        output_wrapper: Callable that receives the collated batch's elements
            as positional arguments. Its return value becomes the batch.

    Returns:
        The same ``dataloader`` object, for chaining.

    Notes:
        The replacement is picklable, so it works with multi-worker loading
        under the ``spawn`` start method. In that setting, ``output_wrapper``
        and the original ``collate_fn`` must themselves be picklable
        (module-level functions or classes, not lambdas or closures).
        NOTE(review): if the collated batch is a dict, ``*`` unpacks its
        keys rather than its values. Confirm that callers always collate to
        a tuple or list.
    """
    dataloader.collate_fn = _WrappedCollate(dataloader.collate_fn, output_wrapper)
    return dataloader