# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

from typing import Any, Callable, Union, Tuple, Sequence, Optional
from .. import Tensor
from .grad_mode import no_grad as no_grad, enable_grad as enable_grad, \
    set_grad_enabled as set_grad_enabled
from .profiler import record_function

# TODO make Variable and Function more precise
class Variable: ...


class Function:
    """Stub for a user-defined autograd function.

    ``forward`` and ``backward`` are declared as static methods whose first
    argument is a context object ``ctx``; all other arguments and results are
    typed as ``Any``.
    """

    @staticmethod
    def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any: ...

    @staticmethod
    def backward(ctx: Any, *grad_outputs: Any) -> Any: ...

    #MODIFIED BY TORCHGPIPE
    # Local addition to the upstream stub: declares ``apply`` as a static method
    # taking arbitrary arguments and returning ``Any``. Presumably added so that
    # ``SomeFunction.apply(...)`` calls type-check — NOTE(review): confirm intent.
    @staticmethod
    def apply(*args: Any, **kwargs: Any) -> Any: ...
    #END


class NestedIOFunction(Function):
    # The 'type: ignore' statements are needed here because these functions are declared as '@staticmethod' in the
    # superclass (Function) but are instance methods here, which mypy reports as incompatible.
    def backward(self, *gradients: Any) -> Any: ...  # type: ignore
    def forward(self, *args: Any) -> tuple: ...  # type: ignore
    def save_for_backward(self, *args: Any) -> None: ...
    def mark_dirty(self, *args: Any, **kwargs: Any) -> None: ...
    def mark_non_differentiable(self, *args: Any, **kwargs: Any) -> None: ...
    def forward_extended(self, *input: Any) -> None: ...
    def backward_extended(self, *grad_output: Any) -> None: ...


# 'func' accepts a vararg of tensors, which isn't expressible in the type system at the moment.
# If https://mypy.readthedocs.io/en/latest/additional_features.html?highlight=callable#extended-callable-types is accepted,
# the '...' first argument of Callable can be replaced with VarArg(Tensor).
# For now, we permit any input.
def gradcheck(
    func: Callable[..., Union[Tensor, Tuple[Tensor, ...]]],
    inputs: Union[Tensor, Tuple[Tensor, ...]],
    eps: float = ...,
    atol: float = ...,
    rtol: float = ...,
    raise_exception: bool = ...,
    check_sparse_nnz: bool = ...,
) -> bool: ...
def gradgradcheck(
    func: Callable[..., Union[Tensor, Tuple[Tensor, ...]]],
    inputs: Union[Tensor, Tuple[Tensor, ...]],
    eps: float = ...,
    atol: float = ...,
    rtol: float = ...,
    gen_non_contig_grad_outputs: bool = ...,
    raise_exception: bool = ...,
) -> bool: ...


class detect_anomaly:
    # Context manager taking no constructor arguments; ``__enter__`` yields None.
    # NOTE(review): mypy treats an ``__exit__`` annotated ``-> bool`` as possibly
    # suppressing exceptions; confirm whether ``Literal[False]``/``None`` is intended.
    def __enter__(self) -> None: ...
    def __exit__(self, *args: Any) -> bool: ...


class set_detect_anomaly:
    # Context manager constructed with an explicit boolean ``mode``.
    def __init__(self, mode: bool) -> None: ...
    def __enter__(self) -> None: ...
    def __exit__(self, *args: Any) -> bool: ...


# Either a single tensor or a sequence of tensors; shared by ``backward`` and ``grad``.
_TensorOrTensors = Union[Tensor, Sequence[Tensor]]


def backward(
    tensors: _TensorOrTensors,
    grad_tensors: Optional[_TensorOrTensors] = ...,
    retain_graph: Optional[bool] = ...,
    create_graph: bool = ...,
) -> None: ...


def grad(
    outputs: _TensorOrTensors,
    inputs: _TensorOrTensors,
    grad_outputs: Optional[_TensorOrTensors] = ...,
    retain_graph: Optional[bool] = ...,
    create_graph: bool = ...,
    only_inputs: bool = ...,
    allow_unused: bool = ...,
) -> Tuple[Tensor, ...]: ...


def _is_checkpoint_valid() -> bool: ...