adaptive.pyi 1.01 KB
Newer Older
Mandeep Singh Baines's avatar
Mandeep Singh Baines committed
1
2
3
4
5
6
7
8
9
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

from ... import Tensor
from .module import Module
from .linear import Linear
from collections import namedtuple
from typing import List, Sequence
from .container import ModuleList

10
# Two-field record (`output`, `loss`) used as the return annotation of
# AdaptiveLogSoftmaxWithLoss.forward / __call__.
_ASMoutput = namedtuple("_ASMoutput", ("output", "loss"))
Mandeep Singh Baines's avatar
Mandeep Singh Baines committed
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33


class AdaptiveLogSoftmaxWithLoss(Module):
    # Stub-only declaration for type checkers: every body and default value is
    # elided with `...`. The runtime implementation is presumably the sibling
    # `adaptive.py` module -- confirm.

    # Declared attributes. The first five share names with the constructor
    # parameters below; note that `cutoffs` is typed List[int] here while the
    # constructor accepts any Sequence[int].
    in_features: int = ...
    n_classes: int = ...
    cutoffs: List[int] = ...
    div_value: float = ...
    head_bias: bool = ...
    # Sub-modules: a single Linear `head` and a ModuleList `tail`.
    head: Linear = ...
    tail: ModuleList = ...

    # `div_value` and `head_bias` are optional (defaults elided in the stub).
    def __init__(self, in_features: int, n_classes: int, cutoffs: Sequence[int], div_value: float = ...,
                 head_bias: bool = ...) -> None: ...

    def reset_parameters(self) -> None: ...

    # Takes an input tensor and a target tensor and returns the `_ASMoutput`
    # named tuple (fields `output` and `loss`).
    def forward(self, input: Tensor, target: Tensor) -> _ASMoutput: ...  # type: ignore

    # Redeclared with the same signature as `forward` so that call-syntax use
    # of the module is typed. NOTE(review): the `# type: ignore` presumably
    # silences an incompatible-override complaint against Module -- confirm.
    def __call__(self, input: Tensor, target: Tensor) -> _ASMoutput: ...  # type: ignore

    # Returns a Tensor: per the PyTorch documentation the runtime method yields
    # a tensor of per-class log-probabilities, not a Python list of floats as
    # this stub previously declared.
    def log_prob(self, input: Tensor) -> Tensor: ...

    def predict(self, input: Tensor) -> Tensor: ...