# data_parallel.pyi (980 bytes); last committed by Mandeep Singh Baines.
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

from typing import Any, Optional, TypeVar
from .common_types import _devices_t, _device_t
from ..modules import Module
from ... import device, Tensor

# Covariant type parameter standing for the value produced by calling the
# wrapped module (it is the declared return type of forward/__call__ below).
T_co = TypeVar('T_co', covariant=True)


# Stub for ``DataParallel``: a ``Module`` that wraps another ``Module``.
# It is generic in ``T_co``: the constructor takes a ``Module[T_co]`` and both
# ``forward`` and ``__call__`` are declared to return ``T_co``.
class DataParallel(Module[T_co]):
    # The wrapped module. NOTE(review): annotated as a bare ``Module`` rather
    # than ``Module[T_co]`` as in ``__init__``; left as-is for compatibility.
    module: Module = ...
    # Devices the wrapper is configured with (``_devices_t`` from common_types).
    device_ids: _devices_t = ...
    # Integer dimension; presumably the axis inputs are split/gathered along --
    # confirm against the implementation.
    dim: int = ...
    # Single device (``_device_t``) designated as the output device.
    output_device: _device_t = ...
    # A ``device`` object; presumably the source device -- confirm.
    src_device_obj: device = ...

    # Only ``module`` is required; the remaining arguments all have defaults.
    def __init__(
        self,
        module: Module[T_co],
        device_ids: Optional[_devices_t] = ...,
        output_device: Optional[_device_t] = ...,
        dim: int = ...,
    ) -> None: ...

    # Accepts arbitrary positional/keyword inputs; typed to yield ``T_co``.
    def forward(self, *inputs: Any, **kwargs: Any) -> T_co: ...

    # Re-declared so that calling an instance directly is also typed as ``T_co``.
    def __call__(self, *inputs: Any, **kwargs: Any) -> T_co: ...


# Function-form stub; presumably the functional counterpart of ``DataParallel``
# (confirm against the implementation). ``module`` is an unparameterized
# ``Module`` and the result is annotated as ``Tensor``. Everything after
# ``inputs`` has a default. Note that ``Optional[Any]`` on ``module_kwargs``
# is equivalent to plain ``Any``.
def data_parallel(
    module: Module,
    inputs: Any,
    device_ids: Optional[_devices_t] = ...,
    output_device: Optional[_device_t] = ...,
    dim: int = ...,
    module_kwargs: Optional[Any] = ...,
) -> Tensor: ...