# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

from typing import Any, Callable, Union, Tuple, Sequence, Optional
from .. import Tensor
from .grad_mode import no_grad as no_grad, enable_grad as enable_grad, \
    set_grad_enabled as set_grad_enabled
from .profiler import record_function as record_function

# This is defined in CPP in PyTorch source
class ImperativeEngine:
    def queue_callback(self, callback: Callable[..., None]): ...
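
# Usage sketch (illustrative, not part of the stub): queue_callback registers a
# callable that the engine invokes once the current backward pass completes. It
# is only meaningful while backward is running, e.g. from a tensor hook:
#
#   def hook(grad: Tensor) -> Tensor:
#       Variable._execution_engine.queue_callback(lambda: print("backward done"))
#       return grad
#
#   t.register_hook(hook)  # 't' is a hypothetical tensor with requires_grad=True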

# TODO make Variable and Function more precise
class Variable:
    _execution_engine: ImperativeEngine

class Function:
    @staticmethod
    def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any: ...
    @staticmethod
    def backward(ctx: Any, *grad_outputs: Any) -> Any: ...

    # MODIFIED BY TORCHGPIPE
    @staticmethod
    def apply(*args: Any, **kwargs: Any) -> Any: ...
    # END
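
# Usage sketch (illustrative, not part of the stub): a custom autograd Function
# overrides the static forward/backward pair and is invoked via apply();
# forward() is never called directly.
#
#   class Square(Function):
#       @staticmethod
#       def forward(ctx: Any, x: Tensor) -> Tensor:
#           ctx.save_for_backward(x)
#           return x * x
#
#       @staticmethod
#       def backward(ctx: Any, grad_output: Tensor) -> Tensor:
#           x, = ctx.saved_tensors
#           return 2 * x * grad_output
#
#   y = Square.apply(x)  # 'x' is a hypothetical tensor with requires_grad=True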

class NestedIOFunction(Function):
    # The 'type: ignore' comments are needed because these functions are declared as '@staticmethod'
    # in the superclass (Function) but are instance methods here, which mypy reports as incompatible.
    def backward(self, *gradients: Any) -> Any: ...  # type: ignore
    def forward(self, *args: Any) -> tuple: ...  # type: ignore
    def save_for_backward(self, *args: Any) -> None: ...
    def mark_dirty(self, *args: Any, **kwargs: Any) -> None: ...
    def mark_non_differentiable(self, *args: Any, **kwargs: Any) -> None: ...
    def forward_extended(self, *input: Any) -> None: ...
    def backward_extended(self, *grad_output: Any) -> None: ...
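    # Note (hedged): NestedIOFunction is the legacy interface supporting
    # arbitrarily nested structures of tensors; subclasses implement
    # forward_extended/backward_extended instead of the flat forward/backward
    # pair used by Function.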

# 'func' accepts a vararg of tensors, which isn't expressible in the type system at the moment.
# If https://mypy.readthedocs.io/en/latest/additional_features.html?highlight=callable#extended-callable-types is accepted,
# the '...' first argument of Callable can be replaced with VarArg(Tensor).
# For now, we permit any input.
def gradcheck(func: Callable[..., Union[Tensor, Tuple[Tensor, ...]]], inputs: Union[Tensor, Tuple[Tensor, ...]], eps: float=..., atol: float=..., rtol: float=..., raise_exception: bool=..., check_sparse_nnz: bool=...) -> bool: ...
def gradgradcheck(func: Callable[..., Union[Tensor, Tuple[Tensor, ...]]], inputs: Union[Tensor, Tuple[Tensor, ...]], eps: float=..., atol: float=..., rtol: float=..., gen_non_contig_grad_outputs: bool=..., raise_exception: bool=...) -> bool: ...
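
# Usage sketch (illustrative, not part of the stub): gradcheck compares the
# analytical Jacobian against finite differences, so double-precision inputs
# are recommended for numerical stability:
#
#   x = torch.randn(4, dtype=torch.double, requires_grad=True)
#   assert gradcheck(torch.sigmoid, (x,), eps=1e-6, atol=1e-4)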

class detect_anomaly:
    def __enter__(self) -> None: ...
    def __exit__(self, *args: Any) -> bool: ...

class set_detect_anomaly:
    def __init__(self, mode: bool) -> None: ...
    def __enter__(self) -> None: ...
    def __exit__(self, *args: Any) -> bool: ...
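
# Usage sketch (illustrative, not part of the stub): both context managers make
# autograd record forward-pass tracebacks, so an error raised during backward
# points at the operation that produced the failing gradient:
#
#   with detect_anomaly():
#       loss = model(input).sum()  # 'model' and 'input' are hypothetical
#       loss.backward()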

_TensorOrTensors = Union[Tensor, Sequence[Tensor]]
def backward(tensors: _TensorOrTensors, grad_tensors: Optional[_TensorOrTensors]=..., retain_graph: Optional[bool]=..., create_graph: bool=...) -> None: ...
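
# Usage sketch (illustrative, not part of the stub): backward() accumulates
# gradients into the .grad attribute of leaf tensors; non-scalar outputs need
# an explicit grad_tensors argument:
#
#   y = x * 2  # 'x' is a hypothetical tensor with requires_grad=True
#   backward((y,), grad_tensors=(torch.ones_like(y),))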
def grad(outputs: _TensorOrTensors, inputs: _TensorOrTensors, grad_outputs: Optional[_TensorOrTensors]=..., retain_graph: Optional[bool]=..., create_graph: bool=..., only_inputs: bool=..., allow_unused: bool=...) -> Tuple[Tensor, ...]: ...
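
# Usage sketch (illustrative, not part of the stub): grad() returns the
# gradients directly instead of accumulating them into .grad, which is
# convenient for higher-order derivatives:
#
#   x = torch.ones(2, requires_grad=True)
#   y = (x * x).sum()
#   (dx,) = grad(y, (x,), create_graph=True)  # dx == 2 * x
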
def _is_checkpoint_valid() -> bool: ...