dataloader.pyi 1.98 KB
Newer Older
Mandeep Singh Baines's avatar
Mandeep Singh Baines committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

from typing import Any, Callable, TypeVar, Generic, overload, Sequence, List, Optional
from . import Dataset, Sampler

# Covariant type variable; `DataLoader` below is generic over it.
T_co = TypeVar('T_co', covariant=True)
# Plain (invariant) type variable, used by the `_collate_fn_t` alias.
T = TypeVar('T')
# Signature accepted for the `worker_init_fn` argument: one int in, nothing returned.
_worker_init_fn_t = Callable[[int], None]

# Ideally we would parameterize `DataLoader` by the return type of `collate_fn`, but there is currently no way to have that
# type parameter set to a default value if the user doesn't pass in a custom 'collate_fn'.
# See https://github.com/python/mypy/issues/3737.
# Signature accepted for the `collate_fn` argument: a list of `T` in, `Any` out.
_collate_fn_t = Callable[[List[T]], Any]

class DataLoader(Generic[T_co]):
    # Type stub for the data loader. The class is generic over `T_co`, the
    # type parameter of the `Dataset` it wraps.

    # Attributes visible to type checkers.
    dataset: Dataset[T_co]
    batch_size: int
    num_workers: int
    pin_memory: bool
    drop_last: bool
    timeout: float

    # Overload 1: batching is described through `batch_size`, `shuffle`,
    # `sampler` and `drop_last`.
    @overload
    def __init__(
        self,
        dataset: Dataset[T_co],
        batch_size: int = ...,
        shuffle: bool = ...,
        sampler: Optional[Sampler[int]] = ...,
        num_workers: int = ...,
        collate_fn: _collate_fn_t = ...,
        pin_memory: bool = ...,
        drop_last: bool = ...,
        timeout: float = ...,
        worker_init_fn: _worker_init_fn_t = ...
    ) -> None: ...

    # Overload 2: a `batch_sampler` (a sampler of index sequences) is given
    # in place of the batch_size/shuffle/sampler/drop_last parameters.
    @overload
    def __init__(
        self,
        dataset: Dataset[T_co],
        batch_sampler: Optional[Sampler[Sequence[int]]] = ...,
        num_workers: int = ...,
        collate_fn: _collate_fn_t = ...,
        pin_memory: bool = ...,
        timeout: float = ...,
        worker_init_fn: _worker_init_fn_t = ...
    ) -> None: ...

    def __len__(self) -> int: ...

    # '_BaseDataLoaderIter' is written as a string because it is declared
    # further down and cannot be moved above this class: it refers back to
    # 'DataLoader'. The new semantic analyzer in mypy 0.720+ makes the quotes
    # unnecessary, but they are kept so older mypy versions still work.
    def __iter__(self) -> '_BaseDataLoaderIter': ...

class _BaseDataLoaderIter:
    # Type stub for the iterator type that `DataLoader.__iter__` is annotated
    # to return. It is constructed from the loader it iterates over.
    def __init__(self, loader: DataLoader) -> None: ...

    def __len__(self) -> int: ...

    # Annotated as returning this same iterator type.
    def __iter__(self) -> _BaseDataLoaderIter: ...

    # The yielded item is typed as `Any`.
    def __next__(self) -> Any: ...