# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

from ...optim import Optimizer
from ... import device, Tensor
from typing import Dict, Any, Optional

class GradScaler(object):
    # Type stub for GradScaler: declares attribute and method types only.
    # Every method body is `...`, as required in a .pyi stub; the behavior
    # lives in the corresponding implementation module.

    # Private state. Each attribute is declared Optional or as a container,
    # so it may be None/empty; when it gets populated is up to the
    # implementation, which is not visible here.
    _scale: Optional[Tensor]
    # NOTE(review): the implementation may spell this attribute
    # `_growth_tracker`. Confirm the name matches grad_scaler.py; if it does
    # not, this declaration types nothing.
    _grows_tracker: Optional[Tensor]
    # Per-optimizer bookkeeping keyed by an int (presumably id(optimizer) —
    # TODO confirm). Each inner dict maps string state names to values of
    # any type.
    _per_optimizer_states: Dict[int, Dict[str, Any]]

    # Every constructor argument has a default in the runtime class, so a bare
    # `GradScaler()` must type-check. In a stub, `= ...` declares "this
    # parameter has a default" without restating the value.
    def __init__(
        self,
        init_scale: float = ...,
        growth_factor: float = ...,
        backoff_factor: float = ...,
        growth_interval: int = ...,
        enabled: bool = ...,
    ) -> None: ...

    # Private helper. Per its annotation, it returns a mapping from device
    # to Tensor.
    def _unscale_grads_(
        self,
        optimizer: Optimizer,
        inv_scale: Tensor,
        found_inf: Tensor,
        allow_fp16: bool,
    ) -> Dict[device, Tensor]: ...

    # Extra positional/keyword arguments are accepted as Any (presumably
    # forwarded to optimizer.step — confirm). The return type mirrors
    # Optimizer.step, which may return None.
    def step(
        self, optimizer: Optimizer, *args: Any, **kwargs: Any
    ) -> Optional[float]: ...

    # `new_scale` is optional and defaults to None.
    # NOTE(review): newer implementations also accept a Tensor here;
    # confirm and widen the annotation if so.
    def update(self, new_scale: Optional[float] = None) -> None: ...

    def unscale_(self, optimizer: Optimizer) -> None: ...