batchnorm.pyi
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

from ... import Tensor
from .. import Parameter
from .module import Module
from typing import Any, Optional


class _BatchNorm(Module):
    num_features: int = ...
    eps: float = ...
    momentum: float = ...
    affine: bool = ...
    track_running_stats: bool = ...
    weight: Parameter = ...
    bias: Parameter = ...

    # This field is used by fairscale.nn.misc.misc::patch_batchnorm
    _track_running_stats_backup: bool
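
    # Illustrative sketch (kept in comments so this stub stays valid): a
    # patch along the lines of fairscale.nn.misc.misc::patch_batchnorm
    # presumably saves the tracking flag into the backup field above so it
    # can be restored later, e.g. around checkpointed recomputation. The
    # helper below is an assumption, not fairscale's actual implementation:
    #
    #   def _disable_tracking(module: Module) -> None:
    #       for m in module.modules():
    #           if isinstance(m, _BatchNorm):
    #               m._track_running_stats_backup = m.track_running_stats
    #               m.track_running_stats = False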

#MODIFIED BY TORCHGPIPE
    running_mean: Tensor
    running_var: Tensor
    num_batches_tracked: Tensor

    def __init__(self, num_features: int, eps: float = ..., momentum: Optional[float] = ..., affine: bool = ...,
                 track_running_stats: bool = ...) -> None: ...
#END

    def reset_running_stats(self) -> None: ...

    def reset_parameters(self) -> None: ...


class BatchNorm1d(_BatchNorm): ...


class BatchNorm2d(_BatchNorm): ...


class BatchNorm3d(_BatchNorm): ...
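
# Usage sketch (comments only): how the constructors declared above are
# typically called in PyTorch; the argument values here are illustrative.
#
#   bn = BatchNorm2d(64, eps=1e-5, momentum=0.1, affine=True,
#                    track_running_stats=True)
#   bn.reset_running_stats()  # re-initializes running_mean, running_var,
#                             # and num_batches_tracked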


class SyncBatchNorm(_BatchNorm):
    # TODO: set process_group to the right type once torch.distributed is stubbed
    def __init__(self, num_features: int, eps: float = ..., momentum: float = ..., affine: bool = ...,
                 track_running_stats: bool = ..., process_group: Optional[Any] = ...) -> None: ...
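
# Hedged construction sketch for SyncBatchNorm; assumes torch.distributed has
# been initialized, and the explicit group below is illustrative, not required
# (process_group defaults to the whole world):
#
#   import torch.distributed as dist
#   pg = dist.new_group(ranks=[0, 1])
#   sbn = SyncBatchNorm(128, process_group=pg)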