# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

from ... import Tensor, device, dtype
from .. import Parameter
from typing import Union, Tuple, Any, Callable, Iterator, Set, Optional, overload, TypeVar, Mapping, Dict, Generic, NamedTuple
from collections import OrderedDict
from ...utils.hooks import RemovableHandle

# Gradient value type used in hook annotations: either a single tensor or a tuple of tensors.
_grad_t = Union[Tuple[Tensor, ...], Tensor]
# See https://mypy.readthedocs.io/en/latest/generics.html#generic-methods-and-generic-self for the use
# of `T` to annotate `self`. Many methods of `Module` return `self` and we want those return values to be
# the type of the subclass, not the looser type of `Module`. Binding `T` to `Module` (as the mypy docs
# recommend) lets the checker reject `self` arguments that are not modules.
T = TypeVar('T', bound='Module')
# We parameterize modules by the return type of their `forward` (and therefore `__call__`) method. This
# allows type inference to infer that the return value of calling a module in the canonical way (via
# `__call__`) is the same as that of the submodule's custom `forward` function. Submodules that wish to
# opt into this functionality should be defined as, e.g.,
# `class ReturnsTwoTensors(Module[Tuple[Tensor, Tensor]]): ...`
T_co = TypeVar('T_co', covariant=True)


class Module(Generic[T_co]):
    """Type stub for ``Module``, the class that holds parameters, buffers and submodules.

    ``Module`` is generic in ``T_co``, the return type of ``forward`` and therefore of
    ``__call__``. Methods that return ``self`` annotate ``self`` with the ``T`` TypeVar so
    that the result keeps the caller's concrete subclass type.
    """

    def __init__(self) -> None: ...

    def forward(self, *input: Any, **kwargs: Any) -> T_co: ...

    def __call__(self, *input: Any, **kwargs: Any) -> T_co: ...

    def register_buffer(self, name: str, tensor: Optional[Tensor], persistent: bool = True) -> None: ...

    def register_parameter(self, name: str, param: Optional[Parameter]) -> None: ...

    def add_module(self, name: str, module: 'Module') -> None: ...

    def apply(self: T, fn: Callable[['Module'], None]) -> T: ...

    def cuda(self: T, device: Optional[Union[int, str, device]] = ...) -> T: ...

    def cpu(self: T) -> T: ...

    def type(self: T, dst_type: Union[dtype, str]) -> T: ...

    def float(self: T) -> T: ...

    def double(self: T) -> T: ...

    def half(self: T) -> T: ...

    # `device` also accepts a string (e.g. "cuda:0"), consistent with the `cuda` signature.
    @overload
    def to(self: T, device: Optional[Union[int, str, device]] = ..., dtype: Optional[Union[dtype, str]] = ...,
           non_blocking: bool = ...) -> T: ...

    @overload
    def to(self: T, dtype: Union[dtype, str], non_blocking: bool = ...) -> T: ...

    @overload
    def to(self: T, tensor: Tensor, non_blocking: bool = ...) -> T: ...

    # The hook is called with the module and two gradient values (each a tensor or a tuple of
    # tensors) and returns either None or a Tensor.
    def register_backward_hook(self, hook: Callable[
        ['Module', _grad_t, _grad_t], Union[None, Tensor]]) -> RemovableHandle: ...

    # The hook takes a module as a first argument and variadic arguments after that, but there is no way
    # to express that with `Callable`, so `...` is used for the parameter list.
    def register_forward_pre_hook(self, hook: Callable[..., None]) -> RemovableHandle: ...

    def register_forward_hook(self, hook: Callable[..., None]) -> RemovableHandle: ...

    def __getattr__(self, name: str) -> Union[Tensor, 'Module']: ...

    def __setattr__(self, name: str, value: Union[Tensor, 'Module']) -> None: ...

    def __setstate__(self, state: Dict[str, Any]) -> None: ...

    # The user can pass an optional arbitrary mapping object to `state_dict`, in which case `state_dict`
    # returns back that same object. But if they pass nothing, an `OrderedDict` is created and returned.
    @overload
    def state_dict(self, destination: Mapping[str, Tensor], prefix: str = ..., keep_vars: bool = ...) -> Mapping[str, Tensor]: ...

    @overload
    def state_dict(self, prefix: str = ..., keep_vars: bool = ...) -> OrderedDict[str, Tensor]: ...

    def load_state_dict(self, state_dict: Union[Dict[str, Tensor], OrderedDict[str, Tensor]], strict: bool = ...) -> NamedTuple: ...

    def parameters(self, recurse: bool = ...) -> Iterator[Parameter]: ...

    def named_parameters(self, prefix: str = ..., recurse: bool = ...) -> Iterator[Tuple[str, Parameter]]: ...

    def buffers(self, recurse: bool = ...) -> Iterator[Tensor]: ...

    def named_buffers(self, prefix: str = ..., recurse: bool = ...) -> Iterator[Tuple[str, Tensor]]: ...

    def children(self) -> Iterator['Module']: ...

    def named_children(self) -> Iterator[Tuple[str, 'Module']]: ...

    def modules(self) -> Iterator['Module']: ...

    def named_modules(self, memo: Optional[Set['Module']] = ..., prefix: str = ...) -> Iterator[
        Tuple[str, 'Module']]: ...

    def train(self: T, mode: bool = ...) -> T: ...

    def eval(self: T) -> T: ...

    def zero_grad(self) -> None: ...

    def share_memory(self: T) -> T: ...

    def extra_repr(self) -> str: ...

    # Added for torchgpipe.
    training: bool

    # Added by auto_wrap.py.
    wrapper_config: dict