Unverified Commit 63f7796a authored by Tom Birch, committed by GitHub

Multi-process pipe (#90)

Adds support for distributing pipeline stages across multiple processes (and therefore multiple machines). A usage sketch follows below.
* Adds a style argument to the Pipe constructor, defaulting to PipelineStyle.SingleProcess but also supporting PipelineStyle.MultiProcess
* Adds support for lazy construction of modules (see lazy_construction for an example)
* Adds two implementations of inter-process communication: one based on rpc with globally visible queues, and one based on send/recv
* Copies all the relevant tests from tests/pipe to tests/pipe_process and modifies them to exercise PipelineStyle.MultiProcess
parent 49a198c9
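For orientation, here is a minimal sketch of how the new `style` argument might be used. It is an assumption-laden illustration, not the API definition: the import path of `PipelineStyle` and the continued use of the existing `balance`/`chunks` arguments are assumed, and multi-process use additionally requires each participating process to initialize its communication layer, which is omitted here.

```python
# Hypothetical usage sketch of the style argument described above.
# Assumptions: import paths, and that balance/chunks keep their existing
# meaning from the single-process Pipe API.
import torch.nn as nn
from fairscale.nn import Pipe
from fairscale.nn.pipe import PipelineStyle  # assumed location of the new enum

model = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 4))

# Default behaviour is unchanged: all partitions run in the current process.
pipe = Pipe(model, balance=[2, 1], chunks=4, style=PipelineStyle.SingleProcess)

# With MultiProcess, partitions can be spread across processes (and machines);
# each rank constructs the Pipe and runs only the partition(s) assigned to it.
pipe_mp = Pipe(model, balance=[2, 1], chunks=4, style=PipelineStyle.MultiProcess)
```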
@@ -61,7 +61,11 @@ class Task:
     """
     def __init__(
-        self, stream: AbstractStream, *, compute: Callable[[], Batch], finalize: Optional[Callable[[Batch], None]],
+        self,
+        stream: Optional[AbstractStream],
+        *,
+        compute: Callable[[], Batch],
+        finalize: Optional[Callable[[Batch], None]],
     ) -> None:
         self.stream = stream
         self._compute = compute
#!/bin/bash
# Run the multi-process pipe test suite under MPI with 1 to 5 worker processes.
set -e
for WORKERS in {1..5}; do
    mpirun -n $WORKERS python -m pytest tests/nn/pipe_process
done
@@ -34,7 +34,8 @@ from . import distributed
 from . import version
 #END
-class dtype: ...
+class dtype:
+    is_floating_point: bool
 class layout: ...
@@ -325,7 +326,7 @@ class Tensor:
     def cosh_(self) -> Tensor: ...
     def cpu(self) -> Tensor: ...
     def cross(self, other: Tensor, dim: Optional[_int]=None) -> Tensor: ...
-    def cuda(self, device: Optional[_device]=None, non_blocking: _bool=False) -> Tensor: ...
+    def cuda(self, device: Union[_device, _int, str, None]=None, non_blocking: _bool=False) -> Tensor: ...
     @overload
     def cumprod(self, dim: _int, *, dtype: Optional[_dtype]=None) -> Tensor: ...
     @overload
@@ -611,16 +612,16 @@ class Tensor:
     def neg_(self) -> Tensor: ...
     def nelement(self) -> _int: ...
     @overload
-    def new_empty(self, size: _size, *, dtype: _dtype=None, layout: _layout=strided, device: Union[_device, str, None]=None, requires_grad:_bool=False) -> Tensor: ...
+    def new_empty(self, size: _size, *, dtype: _dtype=None, layout: _layout=strided, device: Union[_device, _int, str, None]=None, requires_grad:_bool=False) -> Tensor: ...
     @overload
-    def new_empty(self, *size: _int, dtype: _dtype=None, layout: _layout=strided, device: Union[_device, str, None]=None, requires_grad:_bool=False) -> Tensor: ...
+    def new_empty(self, *size: _int, dtype: _dtype=None, layout: _layout=strided, device: Union[_device, _int, str, None]=None, requires_grad:_bool=False) -> Tensor: ...
-    def new_full(self, size: _size, fill_value: Number, *, dtype: _dtype=None, layout: _layout=strided, device: Union[_device, str, None]=None, requires_grad:_bool=False) -> Tensor: ...
+    def new_full(self, size: _size, fill_value: Number, *, dtype: _dtype=None, layout: _layout=strided, device: Union[_device, _int, str, None]=None, requires_grad:_bool=False) -> Tensor: ...
-    def new_ones(self, size: _size, dtype: Optional[_dtype]=None, device: Union[_device, str, None]=None, requires_grad: _bool=False) -> Tensor: ...
+    def new_ones(self, size: _size, dtype: Optional[_dtype]=None, device: Union[_device, _int, str, None]=None, requires_grad: _bool=False) -> Tensor: ...
-    def new_tensor(self, data: Any, dtype: Optional[_dtype]=None, device: Union[_device, str, None]=None, requires_grad: _bool=False) -> Tensor: ...
+    def new_tensor(self, data: Any, dtype: Optional[_dtype]=None, device: Union[_device, _int, str, None]=None, requires_grad: _bool=False) -> Tensor: ...
     @overload
-    def new_zeros(self, size: _size, *, dtype: _dtype=None, layout: _layout=strided, device: Union[_device, str, None]=None, requires_grad:_bool=False) -> Tensor: ...
+    def new_zeros(self, size: _size, *, dtype: _dtype=None, layout: _layout=strided, device: Union[_device, _int, str, None]=None, requires_grad:_bool=False) -> Tensor: ...
     @overload
-    def new_zeros(self, *size: _int, dtype: _dtype=None, layout: _layout=strided, device: Union[_device, str, None]=None, requires_grad:_bool=False) -> Tensor: ...
+    def new_zeros(self, *size: _int, dtype: _dtype=None, layout: _layout=strided, device: Union[_device, _int, str, None]=None, requires_grad:_bool=False) -> Tensor: ...
     def normal_(self, mean: _float=0, std: _float=1, *, generator: Generator=None) -> Tensor: ...
     def numel(self) -> _int: ...
     def numpy(self) -> Any: ...
@@ -814,7 +815,7 @@ class Tensor:
     @overload
     def to(self, dtype: _dtype, non_blocking: _bool=False, copy: _bool=False) -> Tensor: ...
     @overload
-    def to(self, device: Optional[Union[_device, str]]=None, dtype: Optional[_dtype]=None, non_blocking: _bool=False, copy: _bool=False) -> Tensor: ...
+    def to(self, device: Union[_device, _int, str, None]=None, dtype: Optional[_dtype]=None, non_blocking: _bool=False, copy: _bool=False) -> Tensor: ...
     @overload
     def to(self, other: Tensor, non_blocking: _bool=False, copy: _bool=False) -> Tensor: ...
     def to_dense(self) -> Tensor: ...
@@ -962,7 +963,7 @@ def _convolution_nogroup(input: Tensor, weight: Tensor, bias: Optional[Tensor],
 def _copy_from(self: Tensor, dst: Tensor, non_blocking: _bool=False) -> Tensor: ...
 def _ctc_loss(log_probs: Tensor, targets: Tensor, input_lengths: _size, target_lengths: _size, blank: _int=0, zero_infinity: _bool=False) -> Tuple[Tensor, Tensor]: ...
 def _cudnn_ctc_loss(log_probs: Tensor, targets: Tensor, input_lengths: _size, target_lengths: _size, blank: _int, deterministic: _bool, zero_infinity: _bool) -> Tuple[Tensor, Tensor]: ...
-def _cudnn_init_dropout_state(dropout: _float, train: _bool, dropout_seed: _int, *, dtype: _dtype=None, layout: _layout=strided, device: Union[_device, str, None]=None, requires_grad:_bool=False) -> Tensor: ...
+def _cudnn_init_dropout_state(dropout: _float, train: _bool, dropout_seed: _int, *, dtype: _dtype=None, layout: _layout=strided, device: Union[_device, _int, str, None]=None, requires_grad:_bool=False) -> Tensor: ...
 def _cudnn_rnn(input: Tensor, weight: Union[Tuple[Tensor, ...], List[Tensor]], weight_stride0: _int, weight_buf: Optional[Tensor], hx: Tensor, cx: Optional[Tensor], mode: _int, hidden_size: _int, num_layers: _int, batch_first: _bool, dropout: _float, train: _bool, bidirectional: _bool, batch_sizes: _size, dropout_state: Optional[Tensor]) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]: ...
 def _cudnn_rnn_flatten_weight(weight_arr: Union[Tuple[Tensor, ...], List[Tensor]], weight_stride0: _int, input_size: _int, mode: _int, hidden_size: _int, num_layers: _int, batch_first: _bool, bidirectional: _bool) -> Tensor: ...
 def _cufft_clear_plan_cache(device_index: _int) -> None: ...
@@ -974,13 +975,13 @@ def _dim_arange(like: Tensor, dim: _int) -> Tensor: ...
 def _dirichlet_grad(x: Tensor, alpha: Tensor, total: Tensor) -> Tensor: ...
 def _embedding_bag(weight: Tensor, indices: Tensor, offsets: Tensor, scale_grad_by_freq: _bool=False, mode: _int=0, sparse: _bool=False, per_sample_weights: Optional[Tensor]=None) -> Tuple[Tensor, Tensor, Tensor, Tensor]: ...
 @overload
-def _empty_affine_quantized(size: _size, *, scale: _float=1, zero_point: _int=0, memory_format: Optional[memory_format]=contiguous_format, dtype: _dtype=None, layout: _layout=strided, device: Union[_device, str, None]=None, requires_grad:_bool=False) -> Tensor: ...
+def _empty_affine_quantized(size: _size, *, scale: _float=1, zero_point: _int=0, memory_format: Optional[memory_format]=contiguous_format, dtype: _dtype=None, layout: _layout=strided, device: Union[_device, _int, str, None]=None, requires_grad:_bool=False) -> Tensor: ...
 @overload
-def _empty_affine_quantized(*size: _int, scale: _float=1, zero_point: _int=0, memory_format: Optional[memory_format]=contiguous_format, dtype: _dtype=None, layout: _layout=strided, device: Union[_device, str, None]=None, requires_grad:_bool=False) -> Tensor: ...
+def _empty_affine_quantized(*size: _int, scale: _float=1, zero_point: _int=0, memory_format: Optional[memory_format]=contiguous_format, dtype: _dtype=None, layout: _layout=strided, device: Union[_device, _int, str, None]=None, requires_grad:_bool=False) -> Tensor: ...
 @overload
-def _empty_per_channel_affine_quantized(size: _size, *, scales: Tensor, zero_points: Tensor, axis: _int, memory_format: Optional[memory_format]=contiguous_format, dtype: _dtype=None, layout: _layout=strided, device: Union[_device, str, None]=None, requires_grad:_bool=False) -> Tensor: ...
+def _empty_per_channel_affine_quantized(size: _size, *, scales: Tensor, zero_points: Tensor, axis: _int, memory_format: Optional[memory_format]=contiguous_format, dtype: _dtype=None, layout: _layout=strided, device: Union[_device, _int, str, None]=None, requires_grad:_bool=False) -> Tensor: ...
 @overload
-def _empty_per_channel_affine_quantized(*size: _int, scales: Tensor, zero_points: Tensor, axis: _int, memory_format: Optional[memory_format]=contiguous_format, dtype: _dtype=None, layout: _layout=strided, device: Union[_device, str, None]=None, requires_grad:_bool=False) -> Tensor: ...
+def _empty_per_channel_affine_quantized(*size: _int, scales: Tensor, zero_points: Tensor, axis: _int, memory_format: Optional[memory_format]=contiguous_format, dtype: _dtype=None, layout: _layout=strided, device: Union[_device, _int, str, None]=None, requires_grad:_bool=False) -> Tensor: ...
 def _fft_with_size(self: Tensor, signal_ndim: _int, complex_input: _bool, complex_output: _bool, inverse: _bool, checked_signal_sizes: _size, normalized: _bool, onesided: _bool, output_sizes: _size) -> Tensor: ...
 def _fused_dropout(self: Tensor, p: _float, generator: Generator=None) -> Tuple[Tensor, Tensor]: ...
 def _has_compatible_shallow_copy_type(self: Tensor, from_: Tensor) -> _bool: ...
@@ -1121,11 +1122,11 @@ def any(self: Tensor, dim: Union[str, None], keepdim: _bool=False, *, out: Optio
 @overload
 def any(self: Tensor, *, out: Optional[Tensor]=None) -> Tensor: ...
 @overload
-def arange(start: Number, end: Number, step: Number, *, out: Optional[Tensor]=None, dtype: Optional[_dtype]=None, device: Union[_device, str, None]=None, requires_grad: _bool=False) -> Tensor: ...
+def arange(start: Number, end: Number, step: Number, *, out: Optional[Tensor]=None, dtype: Optional[_dtype]=None, device: Union[_device, _int, str, None]=None, requires_grad: _bool=False) -> Tensor: ...
 @overload
-def arange(start: Number, end: Number, *, out: Optional[Tensor]=None, dtype: Optional[_dtype]=None, device: Union[_device, str, None]=None, requires_grad: _bool=False) -> Tensor: ...
+def arange(start: Number, end: Number, *, out: Optional[Tensor]=None, dtype: Optional[_dtype]=None, device: Union[_device, _int, str, None]=None, requires_grad: _bool=False) -> Tensor: ...
 @overload
-def arange(end: Number, *, out: Optional[Tensor]=None, dtype: Optional[_dtype]=None, device: Union[_device, str, None]=None, requires_grad: _bool=False) -> Tensor: ...
+def arange(end: Number, *, out: Optional[Tensor]=None, dtype: Optional[_dtype]=None, device: Union[_device, _int, str, None]=None, requires_grad: _bool=False) -> Tensor: ...
 def argmax(self: Tensor, dim: Optional[_int]=None, keepdim: _bool=False) -> Tensor: ...
 def argmin(self: Tensor, dim: Optional[_int]=None, keepdim: _bool=False) -> Tensor: ...
 @overload
@@ -1152,9 +1153,9 @@ def baddbmm(beta: Number, self: Tensor, batch1: Tensor, batch2: Tensor) -> Tenso
 @overload
 def baddbmm(beta: Number, self: Tensor, batch1: Tensor, batch2: Tensor, *, out: Tensor) -> Tensor: ...
 @overload
-def bartlett_window(window_length: _int, *, dtype: _dtype=None, layout: _layout=strided, device: Union[_device, str, None]=None, requires_grad:_bool=False) -> Tensor: ...
+def bartlett_window(window_length: _int, *, dtype: _dtype=None, layout: _layout=strided, device: Union[_device, _int, str, None]=None, requires_grad:_bool=False) -> Tensor: ...
 @overload
-def bartlett_window(window_length: _int, periodic: _bool, *, dtype: _dtype=None, layout: _layout=strided, device: Union[_device, str, None]=None, requires_grad:_bool=False) -> Tensor: ...
+def bartlett_window(window_length: _int, periodic: _bool, *, dtype: _dtype=None, layout: _layout=strided, device: Union[_device, _int, str, None]=None, requires_grad:_bool=False) -> Tensor: ...
 def batch_norm(input: Tensor, weight: Optional[Tensor], bias: Optional[Tensor], running_mean: Optional[Tensor], running_var: Optional[Tensor], training: _bool, momentum: _float, eps: _float, cudnn_enabled: _bool) -> Tensor: ...
 def batch_norm_backward_elemt(grad_out: Tensor, input: Tensor, mean: Tensor, invstd: Tensor, weight: Optional[Tensor], mean_dy: Tensor, mean_dy_xmu: Tensor) -> Tensor: ...
 def batch_norm_backward_reduce(grad_out: Tensor, input: Tensor, mean: Tensor, invstd: Tensor, weight: Optional[Tensor], input_g: _bool, weight_g: _bool, bias_g: _bool) -> Tuple[Tensor, Tensor, Tensor, Tensor]: ...
@@ -1175,9 +1176,9 @@ def bitwise_xor(self: Tensor, other: Number, *, out: Optional[Tensor]=None) -> T
 @overload
 def bitwise_xor(self: Tensor, other: Tensor, *, out: Optional[Tensor]=None) -> Tensor: ...
 @overload
-def blackman_window(window_length: _int, *, dtype: _dtype=None, layout: _layout=strided, device: Union[_device, str, None]=None, requires_grad:_bool=False) -> Tensor: ...
+def blackman_window(window_length: _int, *, dtype: _dtype=None, layout: _layout=strided, device: Union[_device, _int, str, None]=None, requires_grad:_bool=False) -> Tensor: ...
 @overload
-def blackman_window(window_length: _int, periodic: _bool, *, dtype: _dtype=None, layout: _layout=strided, device: Union[_device, str, None]=None, requires_grad:_bool=False) -> Tensor: ...
+def blackman_window(window_length: _int, periodic: _bool, *, dtype: _dtype=None, layout: _layout=strided, device: Union[_device, _int, str, None]=None, requires_grad:_bool=False) -> Tensor: ...
 def bmm(self: Tensor, mat2: Tensor, *, out: Optional[Tensor]=None) -> Tensor: ...
 def can_cast(from_: _dtype, to: _dtype) -> _bool: ...
 @overload
@@ -1251,18 +1252,18 @@ def embedding(weight: Tensor, indices: Tensor, padding_idx: _int=-1, scale_grad_
 def embedding_bag(weight: Tensor, indices: Tensor, offsets: Tensor, scale_grad_by_freq: _bool=False, mode: _int=0, sparse: _bool=False, per_sample_weights: Optional[Tensor]=None) -> Tuple[Tensor, Tensor, Tensor, Tensor]: ...
 def embedding_renorm_(self: Tensor, indices: Tensor, max_norm: _float, norm_type: _float) -> Tensor: ...
 @overload
-def empty(size: _size, *, names: Optional[List[Union[str, None]]], memory_format: Optional[memory_format]=None, out: Optional[Tensor]=None, dtype: _dtype=None, layout: _layout=strided, device: Union[_device, str, None]=None, requires_grad:_bool=False) -> Tensor: ...
+def empty(size: _size, *, names: Optional[List[Union[str, None]]], memory_format: Optional[memory_format]=None, out: Optional[Tensor]=None, dtype: _dtype=None, layout: _layout=strided, device: Union[_device, _int, str, None]=None, requires_grad:_bool=False) -> Tensor: ...
 @overload
-def empty(*size: _int, names: Optional[List[Union[str, None]]], memory_format: Optional[memory_format]=None, out: Optional[Tensor]=None, dtype: _dtype=None, layout: _layout=strided, device: Union[_device, str, None]=None, requires_grad:_bool=False) -> Tensor: ...
+def empty(*size: _int, names: Optional[List[Union[str, None]]], memory_format: Optional[memory_format]=None, out: Optional[Tensor]=None, dtype: _dtype=None, layout: _layout=strided, device: Union[_device, _int, str, None]=None, requires_grad:_bool=False) -> Tensor: ...
 @overload
-def empty(size: _size, *, memory_format: Optional[memory_format]=None, out: Optional[Tensor]=None, dtype: _dtype=None, layout: _layout=strided, device: Union[_device, str, None]=None, requires_grad:_bool=False) -> Tensor: ...
+def empty(size: _size, *, memory_format: Optional[memory_format]=None, out: Optional[Tensor]=None, dtype: _dtype=None, layout: _layout=strided, device: Union[_device, _int, str, None]=None, requires_grad:_bool=False) -> Tensor: ...
 @overload
-def empty(*size: _int, memory_format: Optional[memory_format]=None, out: Optional[Tensor]=None, dtype: _dtype=None, layout: _layout=strided, device: Union[_device, str, None]=None, requires_grad:_bool=False) -> Tensor: ...
+def empty(*size: _int, memory_format: Optional[memory_format]=None, out: Optional[Tensor]=None, dtype: _dtype=None, layout: _layout=strided, device: Union[_device, _int, str, None]=None, requires_grad:_bool=False) -> Tensor: ...
 @overload
 def empty_like(self: Tensor, *, memory_format: Optional[memory_format]=None) -> Tensor: ...
 @overload
-def empty_like(self: Tensor, *, memory_format: Optional[memory_format]=None, dtype: _dtype=None, layout: _layout=strided, device: Union[_device, str, None]=None, requires_grad:_bool=False) -> Tensor: ...
+def empty_like(self: Tensor, *, memory_format: Optional[memory_format]=None, dtype: _dtype=None, layout: _layout=strided, device: Union[_device, _int, str, None]=None, requires_grad:_bool=False) -> Tensor: ...
-def empty_strided(size: _size, stride: _size, *, dtype: _dtype=None, layout: _layout=strided, device: Union[_device, str, None]=None, requires_grad:_bool=False) -> Tensor: ...
+def empty_strided(size: _size, stride: _size, *, dtype: _dtype=None, layout: _layout=strided, device: Union[_device, _int, str, None]=None, requires_grad:_bool=False) -> Tensor: ...
 @overload
 def eq(self: Tensor, other: Number, *, out: Optional[Tensor]=None) -> Tensor: ...
 @overload
@@ -1278,9 +1279,9 @@ def exp_(self: Tensor) -> Tensor: ...
 def expm1(self: Tensor, *, out: Optional[Tensor]=None) -> Tensor: ...
 def expm1_(self: Tensor) -> Tensor: ...
 @overload
-def eye(n: _int, *, out: Optional[Tensor]=None, dtype: _dtype=None, layout: _layout=strided, device: Union[_device, str, None]=None, requires_grad:_bool=False) -> Tensor: ...
+def eye(n: _int, *, out: Optional[Tensor]=None, dtype: _dtype=None, layout: _layout=strided, device: Union[_device, _int, str, None]=None, requires_grad:_bool=False) -> Tensor: ...
 @overload
-def eye(n: _int, m: _int, *, out: Optional[Tensor]=None, dtype: _dtype=None, layout: _layout=strided, device: Union[_device, str, None]=None, requires_grad:_bool=False) -> Tensor: ...
+def eye(n: _int, m: _int, *, out: Optional[Tensor]=None, dtype: _dtype=None, layout: _layout=strided, device: Union[_device, _int, str, None]=None, requires_grad:_bool=False) -> Tensor: ...
 def fake_quantize_per_channel_affine(self: Tensor, scale: Tensor, zero_point: Tensor, axis: _int, quant_min: _int, quant_max: _int) -> Tensor: ...
 def fake_quantize_per_tensor_affine(self: Tensor, scale: _float, zero_point: _int, quant_min: _int, quant_max: _int) -> Tensor: ...
 def fbgemm_linear_fp16_weight(input: Tensor, packed_weight: Tensor, bias: Tensor) -> Tensor: ...
@@ -1323,16 +1324,16 @@ def frac_(self: Tensor) -> Tensor: ...
 def frobenius_norm(self: Tensor, *, out: Optional[Tensor]=None) -> Tensor: ...
 @overload
 def frobenius_norm(self: Tensor, dim: Union[_int, _size], keepdim: _bool=False, *, out: Optional[Tensor]=None) -> Tensor: ...
-def from_file(filename: str, shared: Optional[_bool]=None, size: Optional[_int]=0, *, dtype: _dtype=None, layout: _layout=strided, device: Union[_device, str, None]=None, requires_grad:_bool=False) -> Tensor: ...
+def from_file(filename: str, shared: Optional[_bool]=None, size: Optional[_int]=0, *, dtype: _dtype=None, layout: _layout=strided, device: Union[_device, _int, str, None]=None, requires_grad:_bool=False) -> Tensor: ...
 def from_numpy(ndarray) -> Tensor: ...
 @overload
-def full(size: _size, fill_value: Number, *, names: Optional[List[Union[str, None]]], out: Optional[Tensor]=None, dtype: _dtype=None, layout: _layout=strided, device: Union[_device, str, None]=None, requires_grad:_bool=False) -> Tensor: ...
+def full(size: _size, fill_value: Number, *, names: Optional[List[Union[str, None]]], out: Optional[Tensor]=None, dtype: _dtype=None, layout: _layout=strided, device: Union[_device, _int, str, None]=None, requires_grad:_bool=False) -> Tensor: ...
 @overload
-def full(size: _size, fill_value: Number, *, out: Optional[Tensor]=None, dtype: _dtype=None, layout: _layout=strided, device: Union[_device, str, None]=None, requires_grad:_bool=False) -> Tensor: ...
+def full(size: _size, fill_value: Number, *, out: Optional[Tensor]=None, dtype: _dtype=None, layout: _layout=strided, device: Union[_device, _int, str, None]=None, requires_grad:_bool=False) -> Tensor: ...
 @overload
 def full_like(self: Tensor, fill_value: Number, *, memory_format: Optional[memory_format]=None) -> Tensor: ...
 @overload
-def full_like(self: Tensor, fill_value: Number, *, memory_format: Optional[memory_format]=None, dtype: _dtype=None, layout: _layout=strided, device: Union[_device, str, None]=None, requires_grad:_bool=False) -> Tensor: ...
+def full_like(self: Tensor, fill_value: Number, *, memory_format: Optional[memory_format]=None, dtype: _dtype=None, layout: _layout=strided, device: Union[_device, _int, str, None]=None, requires_grad:_bool=False) -> Tensor: ...
 @overload
 def gather(self: Tensor, dim: _int, index: Tensor, *, sparse_grad: _bool=False, out: Optional[Tensor]=None) -> Tensor: ...
 @overload
@@ -1360,17 +1361,17 @@ def gt(self: Tensor, other: Number, *, out: Optional[Tensor]=None) -> Tensor: ..
 @overload
 def gt(self: Tensor, other: Tensor, *, out: Optional[Tensor]=None) -> Tensor: ...
 @overload
-def hamming_window(window_length: _int, *, dtype: _dtype=None, layout: _layout=strided, device: Union[_device, str, None]=None, requires_grad:_bool=False) -> Tensor: ...
+def hamming_window(window_length: _int, *, dtype: _dtype=None, layout: _layout=strided, device: Union[_device, _int, str, None]=None, requires_grad:_bool=False) -> Tensor: ...
 @overload
-def hamming_window(window_length: _int, periodic: _bool, *, dtype: _dtype=None, layout: _layout=strided, device: Union[_device, str, None]=None, requires_grad:_bool=False) -> Tensor: ...
+def hamming_window(window_length: _int, periodic: _bool, *, dtype: _dtype=None, layout: _layout=strided, device: Union[_device, _int, str, None]=None, requires_grad:_bool=False) -> Tensor: ...
 @overload
-def hamming_window(window_length: _int, periodic: _bool, alpha: _float, *, dtype: _dtype=None, layout: _layout=strided, device: Union[_device, str, None]=None, requires_grad:_bool=False) -> Tensor: ...
+def hamming_window(window_length: _int, periodic: _bool, alpha: _float, *, dtype: _dtype=None, layout: _layout=strided, device: Union[_device, _int, str, None]=None, requires_grad:_bool=False) -> Tensor: ...
 @overload
-def hamming_window(window_length: _int, periodic: _bool, alpha: _float, beta: _float, *, dtype: _dtype=None, layout: _layout=strided, device: Union[_device, str, None]=None, requires_grad:_bool=False) -> Tensor: ...
+def hamming_window(window_length: _int, periodic: _bool, alpha: _float, beta: _float, *, dtype: _dtype=None, layout: _layout=strided, device: Union[_device, _int, str, None]=None, requires_grad:_bool=False) -> Tensor: ...
 @overload
-def hann_window(window_length: _int, *, dtype: _dtype=None, layout: _layout=strided, device: Union[_device, str, None]=None, requires_grad:_bool=False) -> Tensor: ...
+def hann_window(window_length: _int, *, dtype: _dtype=None, layout: _layout=strided, device: Union[_device, _int, str, None]=None, requires_grad:_bool=False) -> Tensor: ...
 @overload
-def hann_window(window_length: _int, periodic: _bool, *, dtype: _dtype=None, layout: _layout=strided, device: Union[_device, str, None]=None, requires_grad:_bool=False) -> Tensor: ...
+def hann_window(window_length: _int, periodic: _bool, *, dtype: _dtype=None, layout: _layout=strided, device: Union[_device, _int, str, None]=None, requires_grad:_bool=False) -> Tensor: ...
 def hardshrink(self: Tensor, lambd: Number=0.5) -> Tensor: ...
 def histc(self: Tensor, bins: _int=100, min: Number=0, max: Number=0, *, out: Optional[Tensor]=None) -> Tensor: ...
 def hspmm(mat1: Tensor, mat2: Tensor, *, out: Optional[Tensor]=None) -> Tensor: ...
@@ -1425,7 +1426,7 @@ def lerp(self: Tensor, end: Tensor, weight: Number, *, out: Optional[Tensor]=Non
 @overload
 def lerp(self: Tensor, end: Tensor, weight: Tensor, *, out: Optional[Tensor]=None) -> Tensor: ...
 def lgamma(self: Tensor, *, out: Optional[Tensor]=None) -> Tensor: ...
-def linspace(start: Number, end: Number, steps: _int=100, *, out: Optional[Tensor]=None, dtype: _dtype=None, layout: _layout=strided, device: Union[_device, str, None]=None, requires_grad:_bool=False) -> Tensor: ...
+def linspace(start: Number, end: Number, steps: _int=100, *, out: Optional[Tensor]=None, dtype: _dtype=None, layout: _layout=strided, device: Union[_device, _int, str, None]=None, requires_grad:_bool=False) -> Tensor: ...
 def log(self: Tensor, *, out: Optional[Tensor]=None) -> Tensor: ...
 def log10(self: Tensor, *, out: Optional[Tensor]=None) -> Tensor: ...
 def log10_(self: Tensor) -> Tensor: ...
@@ -1441,7 +1442,7 @@ def log_softmax(self: Tensor, dim: Union[str, None], *, dtype: Optional[_dtype]=
 def logdet(self: Tensor) -> Tensor: ...
 def logical_not(self: Tensor, *, out: Optional[Tensor]=None) -> Tensor: ...
 def logical_xor(self: Tensor, other: Tensor, *, out: Optional[Tensor]=None) -> Tensor: ...
-def logspace(start: Number, end: Number, steps: _int=100, base: _float=10.0, *, out: Optional[Tensor]=None, dtype: _dtype=None, layout: _layout=strided, device: Union[_device, str, None]=None, requires_grad:_bool=False) -> Tensor: ...
+def logspace(start: Number, end: Number, steps: _int=100, base: _float=10.0, *, out: Optional[Tensor]=None, dtype: _dtype=None, layout: _layout=strided, device: Union[_device, _int, str, None]=None, requires_grad:_bool=False) -> Tensor: ...
 @overload
 def logsumexp(self: Tensor, dim: Union[_int, _size], keepdim: _bool=False, *, out: Optional[Tensor]=None) -> Tensor: ...
 @overload
@@ -1540,24 +1541,24 @@ def normal(mean: _float, std: Tensor, *, generator: Generator=None, out: Optiona
 @overload
 def normal(mean: Tensor, std: Tensor, *, generator: Generator=None, out: Optional[Tensor]=None) -> Tensor: ...
 @overload
-def normal(mean: _float, std: _float, size: _size, *, generator: Generator=None, out: Optional[Tensor]=None, dtype: _dtype=None, layout: _layout=strided, device: Union[_device, str, None]=None, requires_grad:_bool=False) -> Tensor: ...
+def normal(mean: _float, std: _float, size: _size, *, generator: Generator=None, out: Optional[Tensor]=None, dtype: _dtype=None, layout: _layout=strided, device: Union[_device, _int, str, None]=None, requires_grad:_bool=False) -> Tensor: ...
 @overload
 def nuclear_norm(self: Tensor, keepdim: _bool=False, *, out: Optional[Tensor]=None) -> Tensor: ...
 @overload
 def nuclear_norm(self: Tensor, dim: Union[_int, _size], keepdim: _bool=False, *, out: Optional[Tensor]=None) -> Tensor: ...
 def numel(self: Tensor) -> _int: ...
 @overload
-def ones(size: _size, *, names: Optional[List[Union[str, None]]], out: Optional[Tensor]=None, dtype: _dtype=None, layout: _layout=strided, device: Union[_device, str, None]=None, requires_grad:_bool=False) -> Tensor: ...
+def ones(size: _size, *, names: Optional[List[Union[str, None]]], out: Optional[Tensor]=None, dtype: _dtype=None, layout: _layout=strided, device: Union[_device, _int, str, None]=None, requires_grad:_bool=False) -> Tensor: ...
 @overload
-def ones(*size: _int, names: Optional[List[Union[str, None]]], out: Optional[Tensor]=None, dtype: _dtype=None, layout: _layout=strided, device: Union[_device, str, None]=None, requires_grad:_bool=False) -> Tensor: ...
+def ones(*size: _int, names: Optional[List[Union[str, None]]], out: Optional[Tensor]=None, dtype: _dtype=None, layout: _layout=strided, device: Union[_device, _int, str, None]=None, requires_grad:_bool=False) -> Tensor: ...
 @overload
-def ones(size: _size, *, out: Optional[Tensor]=None, dtype: _dtype=None, layout: _layout=strided, device: Union[_device, str, None]=None, requires_grad:_bool=False) -> Tensor: ...
+def ones(size: _size, *, out: Optional[Tensor]=None, dtype: _dtype=None, layout: _layout=strided, device: Union[_device, _int, str, None]=None, requires_grad:_bool=False) -> Tensor: ...
 @overload
-def ones(*size: _int, out: Optional[Tensor]=None, dtype: _dtype=None, layout: _layout=strided, device: Union[_device, str, None]=None, requires_grad:_bool=False) -> Tensor: ...
+def ones(*size: _int, out: Optional[Tensor]=None, dtype: _dtype=None, layout: _layout=strided, device: Union[_device, _int, str, None]=None, requires_grad:_bool=False) -> Tensor: ...
 @overload
 def ones_like(self: Tensor, *, memory_format: Optional[memory_format]=None) -> Tensor: ...
 @overload
-def ones_like(self: Tensor, *, memory_format: Optional[memory_format]=None, dtype: _dtype=None, layout: _layout=strided, device: Union[_device, str, None]=None, requires_grad:_bool=False) -> Tensor: ...
+def ones_like(self: Tensor, *, memory_format: Optional[memory_format]=None, dtype: _dtype=None, layout: _layout=strided, device: Union[_device, _int, str, None]=None, requires_grad:_bool=False) -> Tensor: ...
 def orgqr(self: Tensor, input2: Tensor, *, out: Optional[Tensor]=None) -> Tensor: ...
 def ormqr(self: Tensor, input2: Tensor, input3: Tensor, left: _bool=True, transpose: _bool=False, *, out: Optional[Tensor]=None) -> Tensor: ...
 def pairwise_distance(x1: Tensor, x2: Tensor, p: _float=2, eps: _float=1e-06, keepdim: _bool=False) -> Tensor: ...
@@ -1603,62 +1604,62 @@ def quantized_max_pool2d(self: Tensor, kernel_size: Union[_int, _size], stride:
 def quantized_rnn_relu_cell(input: Tensor, hx: Tensor, w_ih: Tensor, w_hh: Tensor, b_ih: Tensor, b_hh: Tensor, packed_ih: Tensor, packed_hh: Tensor, col_offsets_ih: Tensor, col_offsets_hh: Tensor, scale_ih: Number, scale_hh: Number, zero_point_ih: Number, zero_point_hh: Number) -> Tensor: ...
 def quantized_rnn_tanh_cell(input: Tensor, hx: Tensor, w_ih: Tensor, w_hh: Tensor, b_ih: Tensor, b_hh: Tensor, packed_ih: Tensor, packed_hh: Tensor, col_offsets_ih: Tensor, col_offsets_hh: Tensor, scale_ih: Number, scale_hh: Number, zero_point_ih: Number, zero_point_hh: Number) -> Tensor: ...
 @overload
-def rand(size: _size, *, names: Optional[List[Union[str, None]]], out: Optional[Tensor]=None, dtype: _dtype=None, layout: _layout=strided, device: Union[_device, str, None]=None, requires_grad:_bool=False) -> Tensor: ...
+def rand(size: _size, *, names: Optional[List[Union[str, None]]], out: Optional[Tensor]=None, dtype: _dtype=None, layout: _layout=strided, device: Union[_device, _int, str, None]=None, requires_grad:_bool=False) -> Tensor: ...
 @overload
-def rand(*size: _int, names: Optional[List[Union[str, None]]], out: Optional[Tensor]=None, dtype: _dtype=None, layout: _layout=strided, device: Union[_device, str, None]=None, requires_grad:_bool=False) -> Tensor: ...
+def rand(*size: _int, names: Optional[List[Union[str, None]]], out: Optional[Tensor]=None, dtype: _dtype=None, layout: _layout=strided, device: Union[_device, _int, str, None]=None, requires_grad:_bool=False) -> Tensor: ...
 @overload
-def rand(size: _size, *, generator: Generator, names: Optional[List[Union[str, None]]], out: Optional[Tensor]=None, dtype: _dtype=None, layout: _layout=strided, device: Union[_device, str, None]=None, requires_grad:_bool=False) -> Tensor: ...
+def rand(size: _size, *, generator: Generator, names: Optional[List[Union[str, None]]], out: Optional[Tensor]=None, dtype: _dtype=None, layout: _layout=strided, device: Union[_device, _int, str, None]=None, requires_grad:_bool=False) -> Tensor: ...
 @overload
-def rand(*size: _int, generator: Generator, names: Optional[List[Union[str, None]]], out: Optional[Tensor]=None, dtype: _dtype=None, layout: _layout=strided, device: Union[_device, str, None]=None, requires_grad:_bool=False) -> Tensor: ...
+def rand(*size: _int, generator: Generator, names: Optional[List[Union[str, None]]], out: Optional[Tensor]=None, dtype: _dtype=None, layout: _layout=strided, device: Union[_device, _int, str, None]=None, requires_grad:_bool=False) -> Tensor: ...
 @overload
-def rand(size: _size, *, out: Optional[Tensor]=None, dtype: _dtype=None, layout: _layout=strided, device: Union[_device, str, None]=None, requires_grad:_bool=False) -> Tensor: ...
+def rand(size: _size, *, out: Optional[Tensor]=None, dtype: _dtype=None, layout: _layout=strided, device: Union[_device, _int, str, None]=None, requires_grad:_bool=False) -> Tensor: ...
 @overload
-def rand(*size: _int, out: Optional[Tensor]=None, dtype: _dtype=None, layout: _layout=strided, device: Union[_device, str, None]=None, requires_grad:_bool=False) -> Tensor: ...
+def rand(*size: _int, out: Optional[Tensor]=None, dtype: _dtype=None, layout: _layout=strided, device: Union[_device, _int, str, None]=None, requires_grad:_bool=False) -> Tensor: ...
 @overload
-def rand(size: _size, *, generator: Generator, out: Optional[Tensor]=None, dtype: _dtype=None, layout: _layout=strided, device: Union[_device, str, None]=None, requires_grad:_bool=False) -> Tensor: ...
+def rand(size: _size, *, generator: Generator, out: Optional[Tensor]=None, dtype: _dtype=None, layout: _layout=strided, device: Union[_device, _int, str, None]=None, requires_grad:_bool=False) -> Tensor: ...
 @overload
-def rand(*size: _int, generator: Generator, out: Optional[Tensor]=None, dtype: _dtype=None, layout: _layout=strided, device: Union[_device, str, None]=None, requires_grad:_bool=False) -> Tensor: ...
+def rand(*size: _int, generator: Generator, out: Optional[Tensor]=None, dtype: _dtype=None, layout: _layout=strided, device: Union[_device, _int, str, None]=None, requires_grad:_bool=False) -> Tensor: ...
 @overload
 def rand_like(self: Tensor, *, memory_format: Optional[memory_format]=None) -> Tensor: ...
 @overload
-def rand_like(self: Tensor, *, memory_format: Optional[memory_format]=None, dtype: _dtype=None, layout: _layout=strided, device: Union[_device, str, None]=None, requires_grad:_bool=False) -> Tensor: ...
+def rand_like(self: Tensor, *, memory_format: Optional[memory_format]=None, dtype: _dtype=None, layout: _layout=strided, device: Union[_device, _int, str, None]=None, requires_grad:_bool=False) -> Tensor: ...
 @overload
-def randint(low: _int, high: _int, size: _size, *, dtype: Optional[_dtype]=None, device: Union[_device, str, None]=None, requires_grad: _bool=False) -> Tensor: ...
+def randint(low: _int, high: _int, size: _size, *, dtype: Optional[_dtype]=None, device: Union[_device, _int, str, None]=None, requires_grad: _bool=False) -> Tensor: ...
 @overload
-def randint(high: _int, size: _size, *, dtype: Optional[_dtype]=None, device: Union[_device, str, None]=None, requires_grad: _bool=False) -> Tensor: ...
+def randint(high: _int, size: _size, *, dtype: Optional[_dtype]=None, device: Union[_device, _int, str, None]=None, requires_grad: _bool=False) -> Tensor: ...
 @overload
 def randint_like(self: Tensor, high: _int, *, memory_format: Optional[memory_format]=None) -> Tensor: ...
 @overload
 def randint_like(self: Tensor, low: _int, high: _int, *, memory_format: Optional[memory_format]=None) -> Tensor: ...
 @overload
-def randint_like(self: Tensor, high: _int, *, memory_format: Optional[memory_format]=None, dtype: _dtype=None, layout: _layout=strided, device: Union[_device, str, None]=None, requires_grad:_bool=False) -> Tensor: ...
+def randint_like(self: Tensor, high: _int, *, memory_format: Optional[memory_format]=None, dtype: _dtype=None, layout: _layout=strided, device: Union[_device, _int, str, None]=None, requires_grad:_bool=False) -> Tensor: ...
 @overload
-def randint_like(self: Tensor, low: _int, high: _int, *, memory_format: Optional[memory_format]=None, dtype: _dtype=None, layout: _layout=strided, device: Union[_device, str, None]=None, requires_grad:_bool=False) -> Tensor: ...
+def randint_like(self: Tensor, low: _int, high: _int, *, memory_format: Optional[memory_format]=None, dtype: _dtype=None, layout: _layout=strided, device: Union[_device, _int, str, None]=None, requires_grad:_bool=False) -> Tensor: ...
 @overload
-def randn(size: _size, *, out: Optional[Tensor]=None, dtype: _dtype=None, layout: _layout=strided, device: Union[_device, str, None]=None, requires_grad:_bool=False) -> Tensor: ...
+def randn(size: _size, *, out: Optional[Tensor]=None, dtype: _dtype=None, layout: _layout=strided, device: Union[_device, _int, str, None]=None, requires_grad:_bool=False) -> Tensor: ...
 @overload
-def randn(*size: _int, out: Optional[Tensor]=None, dtype: _dtype=None, layout: _layout=strided, device: Union[_device, str, None]=None, requires_grad:_bool=False) -> Tensor: ...
+def randn(*size: _int, out: Optional[Tensor]=None, dtype: _dtype=None, layout: _layout=strided, device: Union[_device, _int, str, None]=None, requires_grad:_bool=False) -> Tensor: ...
 @overload
-def randn(size: _size, *, generator: Generator, out: Optional[Tensor]=None, dtype: _dtype=None, layout: _layout=strided, device: Union[_device, str, None]=None, requires_grad:_bool=False) -> Tensor: ...
+def randn(size: _size, *, generator: Generator, out: Optional[Tensor]=None, dtype: _dtype=None, layout: _layout=strided, device: Union[_device, _int, str, None]=None, requires_grad:_bool=False) -> Tensor: ...
 @overload
-def randn(*size: _int, generator: Generator, out: Optional[Tensor]=None, dtype: _dtype=None, layout: _layout=strided, device: Union[_device, str, None]=None, requires_grad:_bool=False) -> Tensor: ...
+def randn(*size: _int, generator: Generator, out: Optional[Tensor]=None, dtype: _dtype=None, layout: _layout=strided, device: Union[_device, _int, str, None]=None, requires_grad:_bool=False) -> Tensor: ...
 @overload
-def randn(size: _size, *, names: Optional[List[Union[str, None]]], out: Optional[Tensor]=None, dtype: _dtype=None, layout: _layout=strided, device: Union[_device, str, None]=None, requires_grad:_bool=False) -> Tensor: ...
+def randn(size: _size, *, names: Optional[List[Union[str, None]]], out: Optional[Tensor]=None, dtype: _dtype=None, layout: _layout=strided, device: Union[_device, _int, str, None]=None, requires_grad:_bool=False) -> Tensor: ...
 @overload
-def randn(*size: _int, names: Optional[List[Union[str, None]]], out: Optional[Tensor]=None, dtype: _dtype=None, layout: _layout=strided, device: Union[_device, str, None]=None, requires_grad:_bool=False) -> Tensor: ...
+def randn(*size: _int, names: Optional[List[Union[str, None]]], out: Optional[Tensor]=None, dtype: _dtype=None, layout: _layout=strided, device: Union[_device, _int, str, None]=None, requires_grad:_bool=False) -> Tensor: ...
 @overload
-def randn(size: _size, *, generator: Generator, names: Optional[List[Union[str, None]]], out: Optional[Tensor]=None, dtype: _dtype=None, layout: _layout=strided, device: Union[_device, str, None]=None, requires_grad:_bool=False) -> Tensor: ...
+def randn(size: _size, *, generator: Generator, names: Optional[List[Union[str, None]]], out: Optional[Tensor]=None, dtype: _dtype=None, layout: _layout=strided, device: Union[_device, _int, str, None]=None, requires_grad:_bool=False) -> Tensor: ...
 @overload
-def randn(*size: _int, generator: Generator, names: Optional[List[Union[str, None]]], out: Optional[Tensor]=None, dtype: _dtype=None, layout: _layout=strided, device: Union[_device, str, None]=None, requires_grad:_bool=False) -> Tensor: ...
+def randn(*size: _int, generator: Generator, names: Optional[List[Union[str, None]]], out: Optional[Tensor]=None, dtype: _dtype=None, layout: _layout=strided, device: Union[_device, _int, str, None]=None, requires_grad:_bool=False) -> Tensor: ...
 @overload
 def randn_like(self: Tensor, *, memory_format: Optional[memory_format]=None) -> Tensor: ...
 @overload
-def randn_like(self: Tensor, *, memory_format: Optional[memory_format]=None, dtype: _dtype=None, layout: _layout=strided, device: Union[_device, str, None]=None, requires_grad:_bool=False) -> Tensor: ...
+def randn_like(self: Tensor, *, memory_format: Optional[memory_format]=None, dtype: _dtype=None, layout: _layout=strided, device: Union[_device, _int, str, None]=None, requires_grad:_bool=False) -> Tensor: ...
 @overload
-def randperm(n: _int, *, out: Optional[Tensor]=None, dtype: _dtype=None, layout: _layout=strided, device: Union[_device, str, None]=None, requires_grad:_bool=False) -> Tensor: ...
+def randperm(n: _int, *, out: Optional[Tensor]=None, dtype: _dtype=None, layout: _layout=strided, device: Union[_device, _int, str, None]=None, requires_grad:_bool=False) -> Tensor: ...
 @overload
-def randperm(n: _int, *, generator: Generator, out: Optional[Tensor]=None, dtype: _dtype=None, layout: _layout=strided, device: Union[_device, str, None]=None, requires_grad:_bool=False) -> Tensor: ...
+def randperm(n: _int, *, generator: Generator, out: Optional[Tensor]=None, dtype: _dtype=None, layout: _layout=strided, device: Union[_device, _int, str, None]=None, requires_grad:_bool=False) -> Tensor: ...
-def range(start: Number, end: Number, step: Number=1, *, out: Optional[Tensor]=None, dtype: Optional[_dtype]=None, device: Union[_device, str, None]=None, requires_grad: _bool=False) -> Tensor: ...
+def range(start: Number, end: Number, step: Number=1, *, out: Optional[Tensor]=None, dtype: Optional[_dtype]=None, device: Union[_device, _int, str, None]=None, requires_grad: _bool=False) -> Tensor: ...
 def real(self: Tensor, *, out: Optional[Tensor]=None) -> Tensor: ...
 def reciprocal(self: Tensor, *, out: Optional[Tensor]=None) -> Tensor: ...
 def reciprocal_(self: Tensor) -> Tensor: ...
@@ -1708,7 +1709,7 @@ def rsqrt_(self: Tensor) -> Tensor: ...
 def rsub(self: Tensor, other: Tensor, *, alpha: Number=1) -> Tensor: ...
 @overload
 def rsub(self: Tensor, other: Number, alpha: Number=1) -> Tensor: ...
-def scalar_tensor(s: Number, *, dtype: _dtype=None, layout: _layout=strided, device: Union[_device, str, None]=None, requires_grad:_bool=False) -> Tensor: ...
+def scalar_tensor(s: Number, *, dtype: _dtype=None, layout: _layout=strided, device: Union[_device, _int, str, None]=None, requires_grad:_bool=False) -> Tensor: ...
 @overload
 def scatter(self: Tensor, dim: _int, index: Tensor, src: Tensor) -> Tensor: ...
 @overload
@@ -1748,7 +1749,7 @@ def solve(self: Tensor, A: Tensor, *, out: Optional[Tensor]=None) -> Tuple[Tenso
 def sort(self: Tensor, dim: _int=-1, descending: _bool=False, *, out: Optional[Tensor]=None) -> Tuple[Tensor, Tensor]: ...
 @overload
 def sort(self: Tensor, dim: Union[str, None], descending: _bool=False, *, out: Optional[Tensor]=None) -> Tuple[Tensor, Tensor]: ...
-def sparse_coo_tensor(indices: Tensor, values: Union[Tensor,List], size: Optional[_size]=None, *, dtype: Optional[_dtype]=None, device: Union[_device, str, None]=None, requires_grad:_bool=False) -> Tensor: ...
+def sparse_coo_tensor(indices: Tensor, values: Union[Tensor,List], size: Optional[_size]=None, *, dtype: Optional[_dtype]=None, device: Union[_device, _int, str, None]=None, requires_grad:_bool=False) -> Tensor: ...
 def split_with_sizes(self: Tensor, split_sizes: _size, dim: _int=0) -> Union[Tuple[Tensor, ...], List[Tensor]]: ...
 def sqrt(self: Tensor, *, out: Optional[Tensor]=None) -> Tensor: ...
 def sqrt_(self: Tensor) -> Tensor: ...
@@ -1799,7 +1800,7 @@ def tan(self: Tensor, *, out: Optional[Tensor]=None) -> Tensor: ...
 def tan_(self: Tensor) -> Tensor: ...
 def tanh(self: Tensor, *, out: Optional[Tensor]=None) -> Tensor: ...
 def tanh_(self: Tensor) -> Tensor: ...
-def tensor(data: Any, dtype: Optional[_dtype]=None, device: Union[_device, str, None]=None, requires_grad: _bool=False) -> Tensor: ...
+def tensor(data: Any, dtype: Optional[_dtype]=None, device: Union[_device, _int, str, None]=None, requires_grad: _bool=False) -> Tensor: ...
 def threshold(self: Tensor, threshold: Number, value: Number, *, out: Optional[Tensor]=None) -> Tensor: ...
 def threshold_(self: Tensor, threshold: Number, value: Number) -> Tensor: ...
 def topk(self: Tensor, k: _int, dim: _int=-1, largest: _bool=True, sorted: _bool=True, *, out: Optional[Tensor]=None) -> Tuple[Tensor, Tensor]: ...
@@ -1814,9 +1815,9 @@ def trapz(y: Tensor, x: Tensor, *, dim: _int=-1) -> Tensor: ...
 def trapz(y: Tensor, *, dx: _float=1, dim: _int=-1) -> Tensor: ...
 def triangular_solve(self: Tensor, A: Tensor, upper: _bool=True, transpose: _bool=False, unitriangular: _bool=False, *, out: Optional[Tensor]=None) -> Tuple[Tensor, Tensor]: ...
 def tril(self: Tensor, diagonal: _int=0, *, out: Optional[Tensor]=None) -> Tensor: ...
-def tril_indices(row: _int, col: _int, offset: _int=0, *, dtype: _dtype=None, layout: _layout=strided, device: Union[_device, str, None]=None, requires_grad:_bool=False) -> Tensor: ...
+def tril_indices(row: _int, col: _int, offset: _int=0, *, dtype: _dtype=None, layout: _layout=strided, device: Union[_device, _int, str, None]=None, requires_grad:_bool=False) -> Tensor: ...
 def triu(self: Tensor, diagonal: _int=0, *, out: Optional[Tensor]=None) -> Tensor: ...
-def triu_indices(row: _int, col: _int, offset: _int=0, *, dtype: _dtype=None, layout: _layout=strided, device: Union[_device, str, None]=None, requires_grad:_bool=False) -> Tensor: ...
+def triu_indices(row: _int, col: _int, offset: _int=0, *, dtype: _dtype=None, layout: _layout=strided, device: Union[_device, _int, str, None]=None, requires_grad:_bool=False) -> Tensor: ...
 def trunc(self: Tensor, *, out: Optional[Tensor]=None) -> Tensor: ...
 def trunc_(self: Tensor) -> Tensor: ...
 @overload
@@ -1843,17 +1844,17 @@ def where(condition: Tensor, self: Tensor, other: Tensor) -> Tensor: ...
 def where(condition: Tensor) -> Union[Tuple[Tensor, ...], List[Tensor]]: ...
 def zero_(self: Tensor) -> Tensor: ...
 @overload
-def zeros(size: _size, *, names: Optional[List[Union[str, None]]], out: Optional[Tensor]=None, dtype: _dtype=None, layout: _layout=strided, device: Union[_device, str, None]=None, requires_grad:_bool=False) -> Tensor: ...
+def zeros(size: _size, *, names: Optional[List[Union[str, None]]], out: Optional[Tensor]=None, dtype: _dtype=None, layout: _layout=strided, device: Union[_device, _int, str, None]=None, requires_grad:_bool=False) -> Tensor: ...
 @overload
-def zeros(*size: _int, names: Optional[List[Union[str, None]]], out: Optional[Tensor]=None, dtype: _dtype=None, layout: _layout=strided, device: Union[_device, str, None]=None, requires_grad:_bool=False) -> Tensor: ...
+def zeros(*size: _int, names: Optional[List[Union[str, None]]], out: Optional[Tensor]=None, dtype: _dtype=None, layout: _layout=strided, device: Union[_device, _int, str, None]=None, requires_grad:_bool=False) -> Tensor: ...
 @overload
-def zeros(size: _size, *, out: Optional[Tensor]=None, dtype: _dtype=None, layout: _layout=strided, device: Union[_device, str, None]=None, requires_grad:_bool=False) -> Tensor: ...
+def zeros(size: _size, *, out: Optional[Tensor]=None, dtype: _dtype=None, layout: _layout=strided, device: Union[_device, _int, str, None]=None, requires_grad:_bool=False) -> Tensor: ...
 @overload
-def zeros(*size: _int, out: Optional[Tensor]=None, dtype: _dtype=None, layout: _layout=strided, device: Union[_device, str, None]=None, requires_grad:_bool=False) -> Tensor: ...
+def zeros(*size: _int, out: Optional[Tensor]=None, dtype: _dtype=None, layout: _layout=strided, device: Union[_device, _int, str, None]=None, requires_grad:_bool=False) -> Tensor: ...
 @overload
 def zeros_like(self: Tensor, *, memory_format: Optional[memory_format]=None) -> Tensor: ...
 @overload
-def zeros_like(self: Tensor, *, memory_format: Optional[memory_format]=None, dtype: _dtype=None, layout: _layout=strided, device: Union[_device, str, None]=None, requires_grad:_bool=False) -> Tensor: ...
+def zeros_like(self: Tensor, *, memory_format: Optional[memory_format]=None, dtype: _dtype=None, layout: _layout=strided, device: Union[_device, _int, str, None]=None, requires_grad:_bool=False) -> Tensor: ...
class DoubleStorage(Storage): ... class DoubleStorage(Storage): ...
class FloatStorage(Storage): ... class FloatStorage(Storage): ...
......
...@@ -69,8 +69,8 @@ class stream: ...@@ -69,8 +69,8 @@ class stream:
def __enter__(self) -> None: ... def __enter__(self) -> None: ...
def __exit__(self, *args: Any) -> None: ... def __exit__(self, *args: Any) -> None: ...
def current_stream(device: Optional[_device_t]) -> Stream: ... def current_stream(device: Optional[_device_t] = None) -> Stream: ...
def default_stream(device: Optional[_device_t]) -> Stream: ... def default_stream(device: Optional[_device_t] = None) -> Stream: ...
#END #END
# #
default_generators: Tuple[Any] default_generators: Tuple[Any]
...@@ -4,8 +4,13 @@ from typing import Any, List, Union, Optional ...@@ -4,8 +4,13 @@ from typing import Any, List, Union, Optional
from torch import Tensor from torch import Tensor
import datetime import datetime
from . import rpc as rpc
class Backend: ... class Backend: ...
class ProcessGroup: ...
class ProcessGroup:
def size(self) -> int: ...
def rank(self) -> int: ...
class ReduceOp: class ReduceOp:
SUM: ReduceOp SUM: ReduceOp
...@@ -29,5 +34,12 @@ def new_group(ranks: List[int], timeout: datetime.timedelta = datetime.timedelta ...@@ -29,5 +34,12 @@ def new_group(ranks: List[int], timeout: datetime.timedelta = datetime.timedelta
def all_reduce(tensor: Tensor, op: ReduceOp = ReduceOp.SUM, group:Optional[ProcessGroup] = None, async_op: bool = False): ... def all_reduce(tensor: Tensor, op: ReduceOp = ReduceOp.SUM, group:Optional[ProcessGroup] = None, async_op: bool = False): ...
def all_gather(tensor_list: List[Tensor], tensor: Tensor, group:Optional[ProcessGroup] = None, async_op: bool = False): ... def all_gather(tensor_list: List[Tensor], tensor: Tensor, group:Optional[ProcessGroup] = None, async_op: bool = False): ...
def send(tensor: Tensor, dst: int, group: Optional[ProcessGroup] = None, tag: Optional[int] = None) -> None: ...
def isend(tensor: Tensor, dst: int, group: Optional[ProcessGroup] = None, tag: Optional[int] = None) -> None: ...
def recv(tensor: Tensor, src: Optional[int] = None, group: Optional[ProcessGroup] = None, tag: Optional[int] = None) -> int: ...
def irecv(tensor: Tensor, src: Optional[int] = None, group: Optional[ProcessGroup] = None, tag: Optional[int] = None) -> int: ...
class group(object): class group(object):
WORLD: Any WORLD: Any
class RRef: ...
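# Illustrative sketch, not part of this diff: the point-to-point stubs added above
# (send/recv and their non-blocking variants) are the primitives a send/recv based
# pipe transport builds on. The tensor shape and peer ranks are assumptions for the
# example; a process group must already be initialized on every participating rank.
import torch
import torch.distributed as dist

def example_point_to_point(rank: int) -> None:
    t = torch.zeros(4)
    if rank == 0:
        dist.send(t + 1, dst=1)   # blocking send of ones to rank 1
    elif rank == 1:
        dist.recv(t, src=0)       # blocking receive from rank 0
        assert torch.equal(t, torch.ones(4))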
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
from typing import Union, Callable, Optional
class RRef:
...
class WorkerInfo:
...
def rpc_async(
to: Union[str, WorkerInfo],
func: Callable,
args: Optional[tuple] = None,
kwargs: Optional[dict] = None,
timeout=-1.0,
) -> None:
...
def rpc_sync(
to: Union[str, WorkerInfo],
func: Callable,
args: Optional[tuple] = None,
kwargs: Optional[dict] = None,
timeout=-1.0,
) -> None:
...
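# Illustrative sketch, not part of this diff: how the rpc stubs above are typically
# used. The worker name "Test1" matches the names registered by dist_init() in the
# test helpers below; the add() helper is an assumption for the example, and
# rpc.init_rpc() must already have been called on every participating process.
import torch
import torch.distributed.rpc as rpc

def add(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    return a + b

def example_rpc_calls() -> None:
    # Synchronous call: blocks until the remote result is available.
    result = rpc.rpc_sync("Test1", add, args=(torch.ones(2), torch.ones(2)))
    # Asynchronous call: returns a future that can be waited on later.
    future = rpc.rpc_async("Test1", add, args=(torch.ones(2), torch.ones(2)))
    assert torch.equal(result, future.wait())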
...@@ -33,7 +33,7 @@ class Module(Generic[T_co]): ...@@ -33,7 +33,7 @@ class Module(Generic[T_co]):
def apply(self: T, fn: Callable[['Module'], None]) -> T: ... def apply(self: T, fn: Callable[['Module'], None]) -> T: ...
def cuda(self: T, device: Optional[Union[int, device]] = ...) -> T: ... def cuda(self: T, device: Optional[Union[int, str, device]] = ...) -> T: ...
def cpu(self: T) -> T: ... def cpu(self: T) -> T: ...
......
...@@ -19,14 +19,20 @@ ...@@ -19,14 +19,20 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import functools
import inspect
import os import os
import random import random
import numpy import numpy
from packaging import version
import pytest
import torch import torch
import torch.distributed as dist import torch.distributed as dist
from torch.distributed import rpc
import torch.multiprocessing as mp import torch.multiprocessing as mp
from fairscale.nn.model_parallel import initialize_model_parallel
from fairscale.nn.model_parallel.random import model_parallel_cuda_manual_seed from fairscale.nn.model_parallel.random import model_parallel_cuda_manual_seed
...@@ -39,7 +45,7 @@ class IdentityLayer(torch.nn.Module): ...@@ -39,7 +45,7 @@ class IdentityLayer(torch.nn.Module):
return self.weight return self.weight
def set_random_seed(seed): def set_random_seed(seed: int) -> None:
"""Set random seed for reproducability.""" """Set random seed for reproducability."""
random.seed(seed) random.seed(seed)
numpy.random.seed(seed) numpy.random.seed(seed)
...@@ -47,11 +53,40 @@ def set_random_seed(seed): ...@@ -47,11 +53,40 @@ def set_random_seed(seed):
model_parallel_cuda_manual_seed(seed) model_parallel_cuda_manual_seed(seed)
def dist_init(rank, world_size): def dist_init(rank, world_size, hostname=None):
os.environ["MASTER_ADDR"] = "localhost" if hostname is None:
os.environ["MASTER_PORT"] = "29501" hostname = "localhost"
dist.init_process_group(backend="nccl", rank=rank, world_size=world_size) print(f"dist init r={rank}, world={world_size}, host={hostname}")
torch.cuda.set_device(rank) os.environ["MASTER_ADDR"] = hostname
os.environ["MASTER_PORT"] = "10638"
os.environ["WORLD_SIZE"] = str(world_size)
os.environ["RANK"] = str(rank)
if version.parse(torch.__version__).release >= (1, 6, 0):
init_method = f"tcp://{os.environ['MASTER_ADDR']}:{os.environ['MASTER_PORT']}"
torch.distributed.init_process_group(backend="nccl", rank=rank, world_size=world_size, init_method=init_method)
os.environ["MASTER_ADDR"] = hostname
os.environ["MASTER_PORT"] = "10639"
init_method = f"tcp://{os.environ['MASTER_ADDR']}:{os.environ['MASTER_PORT']}"
rpc.init_rpc(
f"Test{rank}",
rank=rank,
world_size=world_size,
backend=rpc.BackendType.TENSORPIPE,
rpc_backend_options=rpc.TensorPipeRpcBackendOptions(init_method=init_method),
)
else:
if world_size > 1:
rpc.init_rpc(f"Test{rank}", rank=rank, world_size=world_size)
else:
torch.distributed.init_process_group(backend="nccl", rank=rank, world_size=world_size)
if torch.cuda.is_available() and torch.cuda.device_count():
torch.cuda.set_device(rank % torch.cuda.device_count())
def get_worker_map():
return {rank: f"Test{rank}" for rank in range(dist.get_world_size())}
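# Hedged usage sketch, not part of this diff: how dist_init() and get_worker_map()
# feed the new MultiProcess pipe style. The toy model, balance=[1, 1] and
# world_size=2 are assumptions for illustration; run_test_pipe and the
# skip-connection tests later in this diff show the real usage.
from fairscale.nn.pipe import Pipe

def example_two_stage_pipe(rank, world_size=2):
    dist_init(rank, world_size)  # sets up the process group plus the RPC layer
    model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU())
    pipe = Pipe(
        model,
        balance=[1, 1],
        style=Pipe.MultiProcess,
        worker_map=get_worker_map(),  # {rank: "Test<rank>"} names registered above
        chunks=2,
    )
    return pipe(torch.rand(4, 8))  # each rank only executes its own stage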
def get_world_sizes(): def get_world_sizes():
...@@ -59,6 +94,54 @@ def get_world_sizes(): ...@@ -59,6 +94,54 @@ def get_world_sizes():
return [x for x in [1, 2, 4, 8] if x <= limit] return [x for x in [1, 2, 4, 8] if x <= limit]
def spawn_for_all_world_sizes(test_func, world_sizes=get_world_sizes()): def spawn_for_all_world_sizes(test_func, world_sizes=get_world_sizes(), args=[]):
for world_size in world_sizes: for world_size in world_sizes:
mp.spawn(test_func, args=(world_size,), nprocs=world_size, join=True) mp.spawn(test_func, args=(world_size, *args), nprocs=world_size, join=True)
def helper(rank, world_size, func, args):
dist_init(rank, world_size)
initialize_model_parallel(1, world_size)
func(*args)
def torch_spawn(world_sizes=None):
if world_sizes is None:
world_sizes = get_world_sizes()
def fixer(func):
name = func.__name__
parameters = inspect.signature(func).parameters
if name.startswith("test"):
raise ValueError(
f"Tests marked with @torch_spawn (i.e. '{name}') should not have names beginning in 'test' as they will"
" be picked up by pytest without running the spawn wrapper"
)
@functools.wraps(func)
def replacement(*args, **kwargs):
assert args == tuple()
args = tuple(
kwargs[p] for p in parameters if p != "rank"
) # converting named parameters to positional parameters to pass to `spawn`
if "OMPI_COMM_WORLD_RANK" in os.environ:
torch.distributed.init_process_group("mpi")
world_size = torch.distributed.get_world_size()
initialize_model_parallel(1, world_size)
torch.cuda.set_device(torch.distributed.get_rank() % torch.cuda.device_count())
if world_size in world_sizes:
func(*args)
else:
pytest.skip(f"requested world size doesn't match current world size")
else:
spawn_for_all_world_sizes(helper, world_sizes, (func, args))
caller_module = inspect.getmodule(inspect.currentframe().f_back)
setattr(caller_module, f"test_{name}", replacement)
return func
return fixer
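# Hedged sketch, not from this diff: what a test written against the torch_spawn
# decorator above looks like. The function name and assertion are assumptions.
# Note the name must not start with "test": the decorator itself registers a
# "test_<name>" wrapper in the calling module and pytest collects that instead.
@torch_spawn([2])
def two_rank_smoke():  # collected by pytest as test_two_rank_smoke
    # Inside the spawned worker, helper() has already run dist_init() and
    # initialize_model_parallel(1, world_size).
    assert torch.distributed.get_world_size() == 2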
...@@ -125,10 +125,11 @@ def test_adjacency(monkeypatch): ...@@ -125,10 +125,11 @@ def test_adjacency(monkeypatch):
for group in new_groups: for group in new_groups:
buckets[len(group)].append(group) buckets[len(group)].append(group)
assert sorted(list(buckets.keys())) == [model_parallel_size, data_parallel_size] assert sorted(list(buckets.keys())) == [model_parallel_size, pipeline_length, data_parallel_size]
assert len(buckets[model_parallel_size]) == pipeline_length * data_parallel_size assert len(buckets[model_parallel_size]) == pipeline_length * data_parallel_size
assert len(buckets[data_parallel_size]) == model_parallel_size * pipeline_length assert len(buckets[data_parallel_size]) == model_parallel_size * pipeline_length
assert len(buckets[pipeline_length]) == model_parallel_size * data_parallel_size
# Check that model_parallel groups are contiguous # Check that model_parallel groups are contiguous
for group in buckets[model_parallel_size]: for group in buckets[model_parallel_size]:
......
...@@ -19,15 +19,25 @@ ...@@ -19,15 +19,25 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import os
import pytest
import torch import torch
from torch import nn from torch import nn
from torch.distributed import rpc
import torch.nn.init as init import torch.nn.init as init
from torch.nn.parameter import Parameter from torch.nn.parameter import Parameter
from fairscale.nn.model_parallel import initialize as mpu from fairscale.nn.model_parallel import initialize as mpu
from fairscale.nn.model_parallel import layers from fairscale.nn.model_parallel import layers
from fairscale.nn.pipe import Pipe from fairscale.nn.pipe import Pipe
from tests.nn.model_parallel.commons import dist_init, get_world_sizes, set_random_seed, spawn_for_all_world_sizes from tests.nn.model_parallel.commons import (
dist_init,
get_world_sizes,
set_random_seed,
spawn_for_all_world_sizes,
torch_spawn,
)
def run_test_parallel_embedding(rank, model_parallel_size): def run_test_parallel_embedding(rank, model_parallel_size):
...@@ -297,33 +307,43 @@ def run_test_row_parallel_linear(rank, model_parallel_size): ...@@ -297,33 +307,43 @@ def run_test_row_parallel_linear(rank, model_parallel_size):
print(" >> passed the test :-)") print(" >> passed the test :-)")
def run_test_pipe(rank, model_parallel_size): def run_test_pipe(rank, world_size, skip_dist_init=False):
pipe_world_size = 2 pipe_world_size = 2
dist_init(rank, model_parallel_size)
mpu.initialize_model_parallel(model_parallel_size) if world_size == 1:
return
if not skip_dist_init:
dist_init(rank, world_size)
else:
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "29502"
rpc.init_rpc(f"Test{rank}", rank=rank, world_size=world_size)
mpu.initialize_model_parallel(world_size / pipe_world_size, pipe_world_size)
model_parallel_size = mpu.get_model_parallel_world_size()
if torch.distributed.get_rank() == 0: if torch.distributed.get_rank() == 0:
print( print(
"> testing Sequential + Pipe with model parallel size: {}, pipe: {}".format( "> testing Sequential + Pipe with model parallel size: {}, pipe: {}".format(
model_parallel_size, pipe_world_size model_parallel_size, pipe_world_size
) )
) )
model_parallel_size = mpu.get_model_parallel_world_size() chunk_size = 4
chunk_size = 8
seed = 12345 seed = 12345
set_random_seed(seed) set_random_seed(seed)
input_size_coeff = 13 input_size_coeff = 3
input_size = input_size_coeff * model_parallel_size input_size = input_size_coeff * model_parallel_size
output_size_coeff = 17 output_size_coeff = 7
output_size = output_size_coeff * model_parallel_size output_size = output_size_coeff * model_parallel_size
batch_size = 7 * chunk_size batch_size = 3 * chunk_size
target = torch.rand((batch_size, input_size), requires_grad=True).cuda()
print(f"target = {target}")
identity = IdentityLayer2D(batch_size, input_size).cuda() identity = IdentityLayer2D(batch_size, input_size).cuda()
pipeline_devices = mpu.get_pipeline_parallel_group() pipeline_devices = mpu.get_pipeline_parallel_group()
if pipe_world_size == 2 and len(pipeline_devices) == 1:
pipeline_devices.append(pipeline_devices[0] + model_parallel_size)
set_random_seed(seed) set_random_seed(seed)
model = nn.Sequential( model = nn.Sequential(
...@@ -331,33 +351,196 @@ def run_test_pipe(rank, model_parallel_size): ...@@ -331,33 +351,196 @@ def run_test_pipe(rank, model_parallel_size):
nn.ReLU(), nn.ReLU(),
layers.RowParallelLinear(output_size, input_size, keep_master_weight_for_test=True, bias=False).cuda(), layers.RowParallelLinear(output_size, input_size, keep_master_weight_for_test=True, bias=False).cuda(),
) )
set_random_seed(seed) set_random_seed(seed)
reference = nn.Sequential(
reference = [
nn.Linear(input_size, output_size, bias=False).cuda(), nn.Linear(input_size, output_size, bias=False).cuda(),
nn.ReLU(), nn.ReLU(),
nn.Linear(output_size, input_size, bias=False).cuda(), nn.Linear(output_size, input_size, bias=False).cuda(),
) ]
reference[0].weight.data = model[0].master_weight.cuda() print(f"setup {reference[0].weight.size()}, {model[0].weight.size()}, {(input_size, output_size)}")
reference[-1].weight.data = model[-1].master_weight.cuda() print(f"setup {reference[2].weight.size()}, {(output_size, input_size)}")
reference[0].weight = Parameter(model[0].get_master_weight().clone()).cuda()
reference[2].weight = Parameter(model[2].get_master_weight().clone()).cuda()
reference = nn.Sequential(*reference)
def grad_graph(depth, grad):
result = depth * " " + str(grad)
if grad:
for x in grad.next_functions:
result += "\n" + grad_graph(depth + 1, x[0])
return result
def check_weights(x, y, key: str, index=None):
for i in [2, 0]:
if index is not None and i != index:
continue
left = x[i].get_master_weight()
right = y[i].weight.data
if not torch.allclose(left, right, atol=1.0e-6) or index is not None:
print(f"check_weights {key}-{i}: left = {left}, \nright = {right}")
if not torch.equal(left, right):
print(f"check_weights NOT_EQUAL {key}-{i}: left = {left}, \nright = {right}")
assert torch.allclose(left, right, atol=1.0e-6)
def dump_opt_params(opt):
for i, group in enumerate(opt.param_groups):
for j, p in enumerate(group["params"]):
print(f"{torch.distributed.get_rank()}:param {(i,j)} = {p}")
print(f"{torch.distributed.get_rank()}:param.grad {(i,j)} = {p.grad}")
def forward_model(model_, target, step=False):
optimizer = torch.optim.SGD(model_.parameters(), lr=0.01, momentum=0.9)
optimizer.zero_grad()
model_.zero_grad()
output = model_(identity())
loss = nn.MSELoss()
model_.zero_grad()
if step:
loss(output, target).backward()
saved_weight_0 = model_[0].weight.data.clone()
saved_weight_2 = model_[2].weight.data.clone()
dump_opt_params(optimizer)
optimizer.step()
assert not torch.allclose(saved_weight_0, model_[0].weight.data, atol=1.0e-6)
assert not torch.allclose(saved_weight_2, model_[2].weight.data, atol=1.0e-6)
return output
output = forward_model(model, target)
reference_output = forward_model(reference, target)
loss_weight = torch.randn([batch_size, output_size]).cuda() error = reference_output.sub(output).max()
output = model(identity()) torch.distributed.barrier()
reference_output = reference(identity()) assert error < 1.0e-6
output = forward_model(model, target)
error = reference_output.sub(output).max()
torch.distributed.barrier()
assert error < 1.0e-6
output = forward_model(model, target)
error = reference_output.sub(output).max() error = reference_output.sub(output).max()
torch.distributed.barrier() torch.distributed.barrier()
assert error < 1.0e-6 assert error < 1.0e-6
check_weights(model, reference, "before")
saved_weight_0 = model[0].weight.data.clone()
saved_weight_2 = model[2].weight.data.clone()
output = forward_model(model, target, step=True)
error = reference_output.sub(output).max()
assert error < 1.0e-6
model[0].weight.data = saved_weight_0
model[2].weight.data = saved_weight_2
worker_map = {i: f"Test{i}" for i in range(torch.distributed.get_world_size())}
if pipe_world_size == 2: if pipe_world_size == 2:
pipe_model = Pipe(model, [2, 1], devices=pipeline_devices, chunks=chunk_size) print(f"actually doing pipe stuff now")
assert torch.equal(saved_weight_0, model[0].weight.data)
assert torch.equal(saved_weight_2, model[2].weight.data)
pipe_model = Pipe(
model,
[2, 1],
style=Pipe.MultiProcess,
group=pipeline_devices,
worker_map=worker_map,
input_device=torch.cuda.current_device(),
chunks=chunk_size,
pipelined_backward=True,
).cuda()
torch.distributed.barrier() torch.distributed.barrier()
pipe_rank = torch.distributed.get_rank(group=mpu.get_pipeline_parallel_group())
print(f"pipe rank is {pipe_rank}")
if pipe_rank == 0:
assert torch.equal(saved_weight_0, pipe_model[0].weight.data)
else:
if not torch.equal(saved_weight_2, pipe_model[0].weight.data):
print(f"ne {pipe_rank}: left\n{saved_weight_2}\nright:\n{pipe_model[0].weight.data}")
assert torch.equal(saved_weight_2, pipe_model[0].weight.data)
optimizer = torch.optim.SGD(pipe_model.parameters(), lr=0.01, momentum=0.9)
optimizer.zero_grad()
if pipe_rank == 0:
assert torch.equal(saved_weight_0, pipe_model[0].weight.data)
print(f"runner {rank}:\n{pipe_model[0].weight.data}")
else:
assert torch.equal(saved_weight_2, pipe_model[0].weight.data)
print(f"runner {rank}:\n{pipe_model[0].weight.data}")
if torch.distributed.get_rank(mpu.get_pipeline_parallel_group()) == 1:
check_weights(model, reference, "pre-pipe", index=2)
else:
check_weights(model, reference, "pre-pipe", index=0)
pipe_output = pipe_model(identity()) pipe_output = pipe_model(identity())
print(f"exited pipe for {rank}")
forward_model(reference, target, step=True)
print(f"pipe_output {rank} = {pipe_output}")
print(f"reference_output {rank} = {reference_output}")
error = reference_output.sub(pipe_output.cuda()).max()
torch.distributed.barrier() torch.distributed.barrier()
assert error < 1.0e-6
if torch.distributed.get_rank(mpu.get_pipeline_parallel_group()) == 1:
error = reference_output.sub(pipe_output.cuda()).max()
if error >= 1.0e-6:
print(f"error bad {error}")
assert error < 1.0e-6
loss = nn.MSELoss()
failed = False
pipe_output.retain_grad()
with torch.autograd.profiler.profile() as prof:
try:
loss(pipe_output, target).backward()
except Exception as e:
failed = True
print(f"got {e} while doing backward, deadlock?")
if failed:
raise RuntimeError("failed somehow")
dump_opt_params(optimizer)
optimizer.step()
print(f"calling check_weights on master")
check_weights(model, reference, "pipe", index=2)
print(f"waiting for barrier on master, pid={os.getpid()}")
else:
print(f"calling backwards on slave, pid={os.getpid()}")
failed = False
with torch.autograd.profiler.profile() as prof:
try:
pipe_model.back_helper(pipe_output)
except Exception as e:
failed = True
print(f"got {e} while doing backward, deadlock?")
if failed:
raise RuntimeError("failed somehow")
dump_opt_params(optimizer)
print(f"calling step on slave")
optimizer.step()
print(f"calling check_weights on slave")
check_weights(model, reference, "pipe", index=0)
print(f"waiting for barrier on slave")
pipe_model.zero_grad()
torch.distributed.barrier()
pipe_output = pipe_model(identity())
updated_ref_output = forward_model(reference, target)
if torch.distributed.get_rank(mpu.get_pipeline_parallel_group()) == 1:
error = updated_ref_output.sub(pipe_output.cuda()).max()
print(f"outputs are ref:\n{updated_ref_output}\npipe:\n{pipe_output}")
assert error < 1.0e-6
torch.distributed.barrier()
print(f"finished waiting for barrier on, pid={os.getpid()}")
print(f"really exited pipe for {rank}")
rpc.shutdown()
torch.distributed.destroy_process_group()
torch.backends.cudnn.deterministic = True torch.backends.cudnn.deterministic = True
...@@ -376,11 +559,29 @@ def test_column_parallel(): ...@@ -376,11 +559,29 @@ def test_column_parallel():
spawn_for_all_world_sizes(run_test_column_parallel_linear) spawn_for_all_world_sizes(run_test_column_parallel_linear)
@pytest.mark.skipif("OMPI_COMM_WORLD_RANK" not in os.environ, reason="only works on mpi")
def test_row_parallel(): def test_row_parallel():
spawn_for_all_world_sizes(run_test_row_parallel_linear) spawn_for_all_world_sizes(run_test_row_parallel_linear)
def test_pipe(): @torch_spawn([2])
@pytest.mark.skipif("OMPI_COMM_WORLD_RANK" not in os.environ, reason="only works on mpi")
@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
def mpi_pipe():
mpu.destroy_model_parallel()
run_test_pipe(torch.distributed.get_rank(), torch.distributed.get_world_size(), skip_dist_init=True)
@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
def test_pipe_layer():
world_sizes = [x for x in get_world_sizes() if x <= torch.cuda.device_count() / 2]
spawn_for_all_world_sizes(run_test_pipe, args=[False])
@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
@pytest.mark.skip(reason="potential deadlock in nccl with multiple processes using the same gpu")
def test_eight_pipe_layer():
world_sizes = [x for x in get_world_sizes() if x <= torch.cuda.device_count() / 2] world_sizes = [x for x in get_world_sizes() if x <= torch.cuda.device_count() / 2]
spawn_for_all_world_sizes(run_test_pipe, world_sizes) spawn_for_all_world_sizes(run_test_pipe, [8])
...@@ -27,7 +27,7 @@ from fairscale.nn.pipe.stream import default_stream ...@@ -27,7 +27,7 @@ from fairscale.nn.pipe.stream import default_stream
@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required") @pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
def test_copy_returns_on_next_device(): def test_copy_returns_on_next_device():
portal = Portal(torch.rand(1), tensor_life=1) portal = Portal(torch.rand(1), tensor_life=1, index=0)
prev_stream = default_stream(torch.device("cpu")) prev_stream = default_stream(torch.device("cpu"))
next_stream = default_stream(torch.device("cuda")) next_stream = default_stream(torch.device("cuda"))
...@@ -52,7 +52,7 @@ def test_blue_orange(): ...@@ -52,7 +52,7 @@ def test_blue_orange():
# tensor1 ------------ Join -- Fork --- Mul --- Add -- output # tensor1 ------------ Join -- Fork --- Mul --- Add -- output
# #
main = tensor1 main = tensor1
portal = Portal(tensor2, tensor_life=2) portal = Portal(tensor2, tensor_life=2, index=0)
phony = portal.blue() phony = portal.blue()
main = join(main, phony) main = join(main, phony)
main, phony = fork(main) main, phony = fork(main)
...@@ -78,7 +78,7 @@ def test_blue_orange_not_requires_grad(): ...@@ -78,7 +78,7 @@ def test_blue_orange_not_requires_grad():
# tensor1 ------------ Join -- Fork --- Mul --- Add -- output # tensor1 ------------ Join -- Fork --- Mul --- Add -- output
# #
main = tensor1 main = tensor1
portal = Portal(tensor2, tensor_life=2) portal = Portal(tensor2, tensor_life=2, index=0)
phony = portal.blue() phony = portal.blue()
main = join(main, phony) main = join(main, phony)
main, phony = fork(main) main, phony = fork(main)
...@@ -93,7 +93,7 @@ def test_blue_orange_not_requires_grad(): ...@@ -93,7 +93,7 @@ def test_blue_orange_not_requires_grad():
def test_use_grad(): def test_use_grad():
tensor = torch.rand(1, requires_grad=True) tensor = torch.rand(1, requires_grad=True)
portal = Portal(tensor, tensor_life=1) portal = Portal(tensor, tensor_life=1, index=0)
portal.put_grad(tensor) portal.put_grad(tensor)
assert portal.use_grad() is tensor assert portal.use_grad() is tensor
...@@ -111,7 +111,7 @@ class TestTensorLife: ...@@ -111,7 +111,7 @@ class TestTensorLife:
def new_portal(tensor_life): def new_portal(tensor_life):
nonlocal portal nonlocal portal
tensor = torch.rand(1, requires_grad=True) tensor = torch.rand(1, requires_grad=True)
portal = Portal(tensor, tensor_life) portal = Portal(tensor, tensor_life, 0)
return portal, tensor return portal, tensor
yield new_portal yield new_portal
......
...@@ -72,9 +72,9 @@ def test_default_skip_tracker_by_data_parallel(): ...@@ -72,9 +72,9 @@ def test_default_skip_tracker_by_data_parallel():
def test_reuse_portal(): def test_reuse_portal():
skip_layout = SkipLayout(num_partitions=2, skip_routes={(None, "test"): (0, 1)}) skip_layout = SkipLayout(num_partitions=2, skip_routes={(None, "test"): (0, 1)})
skip_tracker = SkipTrackerThroughPotals(skip_layout) skip_tracker = SkipTrackerThroughPotals(skip_layout, 0)
batch = Batch(torch.tensor([1.0])) batch = Batch(torch.tensor([1.0]), 0)
a = torch.tensor([2.0]) a = torch.tensor([2.0])
b = torch.tensor([2.0]) b = torch.tensor([2.0])
...@@ -87,9 +87,9 @@ def test_reuse_portal(): ...@@ -87,9 +87,9 @@ def test_reuse_portal():
def test_no_copy_no_portal(): def test_no_copy_no_portal():
skip_layout = SkipLayout(num_partitions=2, skip_routes={(None, "copy"): (0, 1), (None, "not_copy"): (0, 0)}) skip_layout = SkipLayout(num_partitions=2, skip_routes={(None, "copy"): (0, 1), (None, "not_copy"): (0, 0)})
skip_tracker = SkipTrackerThroughPotals(skip_layout) skip_tracker = SkipTrackerThroughPotals(skip_layout, 0)
batch = Batch(torch.tensor([1.0])) batch = Batch(torch.tensor([1.0]), 0)
a = torch.tensor([2.0]) a = torch.tensor([2.0])
b = torch.tensor([2.0]) b = torch.tensor([2.0])
...@@ -104,9 +104,9 @@ def test_no_copy_no_portal(): ...@@ -104,9 +104,9 @@ def test_no_copy_no_portal():
def test_tensor_life_without_checkpointing(): def test_tensor_life_without_checkpointing():
skip_layout = SkipLayout(num_partitions=2, skip_routes={(None, "test"): (0, 1)}) skip_layout = SkipLayout(num_partitions=2, skip_routes={(None, "test"): (0, 1)})
skip_tracker = SkipTrackerThroughPotals(skip_layout) skip_tracker = SkipTrackerThroughPotals(skip_layout, 0)
batch = Batch(torch.tensor([1.0])) batch = Batch(torch.tensor([1.0]), 0)
tensor = torch.tensor([2.0]) tensor = torch.tensor([2.0])
skip_tracker.save(batch, None, "test", tensor) skip_tracker.save(batch, None, "test", tensor)
...@@ -118,9 +118,9 @@ def test_tensor_life_without_checkpointing(): ...@@ -118,9 +118,9 @@ def test_tensor_life_without_checkpointing():
def test_tensor_life_with_checkpointing(): def test_tensor_life_with_checkpointing():
skip_layout = SkipLayout(num_partitions=2, skip_routes={(None, "test"): (0, 1)}) skip_layout = SkipLayout(num_partitions=2, skip_routes={(None, "test"): (0, 1)})
skip_tracker = SkipTrackerThroughPotals(skip_layout) skip_tracker = SkipTrackerThroughPotals(skip_layout, 0)
batch = Batch(torch.tensor([1.0])) batch = Batch(torch.tensor([1.0]), 0)
tensor = torch.tensor([2.0]) tensor = torch.tensor([2.0])
with enable_checkpointing(): with enable_checkpointing():
......
...@@ -24,7 +24,7 @@ import torch ...@@ -24,7 +24,7 @@ import torch
from torch import nn from torch import nn
import torch.cuda import torch.cuda
from fairscale.nn.pipe.checkpoint import Checkpointing, checkpoint, is_checkpointing, is_recomputing from fairscale.nn.pipe.checkpoint import Checkpointing, Function, TensorOrTensors, is_checkpointing, is_recomputing
from fairscale.nn.pipe.dependency import fork, join from fairscale.nn.pipe.dependency import fork, join
from fairscale.nn.pipe.microbatch import Batch from fairscale.nn.pipe.microbatch import Batch
...@@ -33,6 +33,20 @@ if torch.cuda.is_available(): ...@@ -33,6 +33,20 @@ if torch.cuda.is_available():
devices.append("cuda") devices.append("cuda")
def make_checkpoint(function: Function, input: TensorOrTensors, index: int) -> TensorOrTensors:
"""Makes a checkpoint with a simple interface like
:func:`torch.utils.checkpoint.checkpoint`. It's only used to test or debug
:class:`Checkpoint` and :class:`Recompute` without boilerplate.
"""
batch = Batch(input, index)
chk = Checkpointing(function, batch)
batch = chk.checkpoint()
chk.recompute(batch)
return batch.tensor_or_tensors
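# Minimal sketch of make_checkpoint() in isolation; the doubling function and the
# gradient check are assumptions for illustration, and the tests below exercise
# the helper more thoroughly.
def _example_make_checkpoint():
    x = torch.rand(3, requires_grad=True)
    y = make_checkpoint(lambda t: t * 2, x, 0)  # forward pass under checkpointing
    y.sum().backward()                          # triggers the scheduled recompute
    assert torch.allclose(x.grad, torch.full_like(x, 2.0))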
@pytest.mark.parametrize("device", devices) @pytest.mark.parametrize("device", devices)
def test_serial_checkpoints(device): def test_serial_checkpoints(device):
# Copied from https://github.com/pytorch/pytorch/pull/18568. # Copied from https://github.com/pytorch/pytorch/pull/18568.
...@@ -57,12 +71,12 @@ def test_serial_checkpoints(device): ...@@ -57,12 +71,12 @@ def test_serial_checkpoints(device):
# Increase the next function sequence number. # Increase the next function sequence number.
_ = a + 1 + 2 + 3 + 4 + 5 _ = a + 1 + 2 + 3 + 4 + 5
a = checkpoint(partial(Log.apply, "a"), a) a = make_checkpoint(partial(Log.apply, "a"), a, 0)
a, phony = fork(a) a, phony = fork(a)
b = join(b, phony) b = join(b, phony)
b = checkpoint(partial(Log.apply, "b"), b) b = make_checkpoint(partial(Log.apply, "b"), b, 0)
c = torch.cat((a, b)) c = torch.cat((a, b))
...@@ -79,7 +93,7 @@ def test_serial_checkpoints(device): ...@@ -79,7 +93,7 @@ def test_serial_checkpoints(device):
def test_not_requires_grad(): def test_not_requires_grad():
x = Batch(torch.rand(1, requires_grad=False)) x = Batch(torch.rand(1, requires_grad=False), 0)
assert not x[0].requires_grad assert not x[0].requires_grad
def f(x): def f(x):
...@@ -102,7 +116,7 @@ def test_not_requires_grad_with_parameter(): ...@@ -102,7 +116,7 @@ def test_not_requires_grad_with_parameter():
def f(x): def f(x):
return x * a return x * a
y = checkpoint(f, x) y = make_checkpoint(f, x, 0)
y.backward() y.backward()
assert a.grad is not None assert a.grad is not None
...@@ -119,7 +133,7 @@ def test_random_in_checkpoint(device): ...@@ -119,7 +133,7 @@ def test_random_in_checkpoint(device):
torch.manual_seed(0) torch.manual_seed(0)
chk_x = torch.randn(3, 3, device=device, requires_grad=True) chk_x = torch.randn(3, 3, device=device, requires_grad=True)
chk_y = checkpoint(dropout, chk_x) chk_y = make_checkpoint(dropout, chk_x, 0)
chk_y.norm().backward() chk_y.norm().backward()
assert torch.allclose(x.grad, chk_x.grad) assert torch.allclose(x.grad, chk_x.grad)
...@@ -136,7 +150,7 @@ def test_detect_checkpointing_recomputing(): ...@@ -136,7 +150,7 @@ def test_detect_checkpointing_recomputing():
model = Detect() model = Detect()
input = torch.rand(1, requires_grad=True) input = torch.rand(1, requires_grad=True)
output = checkpoint(model, input) output = make_checkpoint(model, input, 0)
output.backward() output.backward()
assert logs == [(True, False), (False, True)] assert logs == [(True, False), (False, True)]
...@@ -167,5 +181,5 @@ def test_non_grad_output(): ...@@ -167,5 +181,5 @@ def test_non_grad_output():
model = ForkNonGrad() model = ForkNonGrad()
input = torch.rand(1, requires_grad=True) input = torch.rand(1, requires_grad=True)
output = checkpoint(model, input) output = make_checkpoint(model, input, 0)
output[0].backward() output[0].backward()
...@@ -26,7 +26,7 @@ from fairscale.nn.pipe.microbatch import Batch, check, gather, scatter ...@@ -26,7 +26,7 @@ from fairscale.nn.pipe.microbatch import Batch, check, gather, scatter
def test_batch_atomic(): def test_batch_atomic():
x = torch.tensor(42) x = torch.tensor(42)
b = Batch(x) b = Batch(x, 0)
assert b.atomic assert b.atomic
...@@ -41,7 +41,7 @@ def test_batch_atomic(): ...@@ -41,7 +41,7 @@ def test_batch_atomic():
def test_batch_non_atomic(): def test_batch_non_atomic():
x, y = torch.tensor(42), torch.tensor(21) x, y = torch.tensor(42), torch.tensor(21)
b = Batch((x, y)) b = Batch((x, y), 0)
assert not b.atomic assert not b.atomic
...@@ -56,8 +56,8 @@ def test_batch_non_atomic(): ...@@ -56,8 +56,8 @@ def test_batch_non_atomic():
def test_batch_call(): def test_batch_call():
a = Batch(torch.tensor(42)) a = Batch(torch.tensor(42), 0)
b = Batch((torch.tensor(42), torch.tensor(21))) b = Batch((torch.tensor(42), torch.tensor(21)), 0)
def f(x): def f(x):
return x return x
...@@ -67,8 +67,8 @@ def test_batch_call(): ...@@ -67,8 +67,8 @@ def test_batch_call():
def test_batch_setitem_by_index(): def test_batch_setitem_by_index():
a = Batch(torch.tensor(42)) a = Batch(torch.tensor(42), 0)
b = Batch((torch.tensor(42), torch.tensor(21))) b = Batch((torch.tensor(42), torch.tensor(21)), 0)
a[0] = torch.tensor(0) a[0] = torch.tensor(0)
b[0] = torch.tensor(0) b[0] = torch.tensor(0)
...@@ -83,8 +83,8 @@ def test_batch_setitem_by_index(): ...@@ -83,8 +83,8 @@ def test_batch_setitem_by_index():
def test_batch_setitem_by_slice(): def test_batch_setitem_by_slice():
a = Batch(torch.tensor(42)) a = Batch(torch.tensor(42), 0)
b = Batch((torch.tensor(42), torch.tensor(21))) b = Batch((torch.tensor(42), torch.tensor(21)), 0)
a[:] = (torch.tensor(0),) a[:] = (torch.tensor(0),)
b[:] = (torch.tensor(0),) b[:] = (torch.tensor(0),)
...@@ -115,7 +115,7 @@ def test_gather_tensors(): ...@@ -115,7 +115,7 @@ def test_gather_tensors():
a = torch.zeros(1, 1) a = torch.zeros(1, 1)
b = torch.zeros(1, 1) b = torch.zeros(1, 1)
ab = gather([Batch(a), Batch(b)]) ab = gather([Batch(a, 0), Batch(b, 0)])
assert ab.size() == (2, 1) assert ab.size() == (2, 1)
...@@ -124,7 +124,7 @@ def test_gather_tuples(): ...@@ -124,7 +124,7 @@ def test_gather_tuples():
a = (torch.zeros(1, 1), torch.zeros(2, 2)) a = (torch.zeros(1, 1), torch.zeros(2, 2))
b = (torch.zeros(1, 1), torch.zeros(2, 2)) b = (torch.zeros(1, 1), torch.zeros(2, 2))
ab = gather([Batch(a), Batch(b)]) ab = gather([Batch(a, 0), Batch(b, 0)])
assert isinstance(ab, tuple) assert isinstance(ab, tuple)
assert ab[0].size() == (2, 1) assert ab[0].size() == (2, 1)
......
...@@ -44,7 +44,7 @@ def test_join_running_workers(): ...@@ -44,7 +44,7 @@ def test_join_running_workers():
nonlocal count nonlocal count
time.sleep(0.1) time.sleep(0.1)
count += 1 count += 1
return Batch(()) return Batch((), 0)
with spawn_workers([fake_device() for _ in range(10)]) as (in_queues, out_queues): with spawn_workers([fake_device() for _ in range(10)]) as (in_queues, out_queues):
...@@ -70,7 +70,7 @@ def test_join_running_workers_with_exception(): ...@@ -70,7 +70,7 @@ def test_join_running_workers_with_exception():
nonlocal count nonlocal count
time.sleep(0.1) time.sleep(0.1)
count += 1 count += 1
return Batch(()) return Batch((), 0)
with pytest.raises(ExpectedException): with pytest.raises(ExpectedException):
with spawn_workers([fake_device() for _ in range(10)]) as (in_queues, out_queues): with spawn_workers([fake_device() for _ in range(10)]) as (in_queues, out_queues):
...@@ -96,7 +96,7 @@ def test_compute_multithreading(): ...@@ -96,7 +96,7 @@ def test_compute_multithreading():
def log_thread_id(): def log_thread_id():
thread_id = threading.current_thread().ident thread_id = threading.current_thread().ident
thread_ids.add(thread_id) thread_ids.add(thread_id)
return Batch(()) return Batch((), 0)
with spawn_workers([fake_device() for _ in range(2)]) as (in_queues, out_queues): with spawn_workers([fake_device() for _ in range(2)]) as (in_queues, out_queues):
for i in range(2): for i in range(2):
...@@ -112,7 +112,7 @@ def test_compute_success(): ...@@ -112,7 +112,7 @@ def test_compute_success():
"""Task.compute returns (True, (task, batch)) on success.""" """Task.compute returns (True, (task, batch)) on success."""
def _42(): def _42():
return Batch(torch.tensor(42)) return Batch(torch.tensor(42), 0)
with spawn_workers([torch.device("cpu")]) as (in_queues, out_queues): with spawn_workers([torch.device("cpu")]) as (in_queues, out_queues):
t = Task(CPUStream, compute=_42, finalize=None) t = Task(CPUStream, compute=_42, finalize=None)
...@@ -145,7 +145,7 @@ def test_compute_exception(): ...@@ -145,7 +145,7 @@ def test_compute_exception():
def test_grad_mode(grad_mode): def test_grad_mode(grad_mode):
def detect_grad_enabled(): def detect_grad_enabled():
x = torch.rand(1, requires_grad=torch.is_grad_enabled()) x = torch.rand(1, requires_grad=torch.is_grad_enabled())
return Batch(x) return Batch(x, 0)
with torch.set_grad_enabled(grad_mode): with torch.set_grad_enabled(grad_mode):
with spawn_workers([torch.device("cpu")]) as (in_queues, out_queues): with spawn_workers([torch.device("cpu")]) as (in_queues, out_queues):
......
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
# Copyright 2019 Kakao Brain
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# tests/__init__.py lets pytest import the application without custom sys.path or PYTHONPATH.
# See also: https://docs.pytest.org/en/latest/goodpractices.html
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
# Copyright 2019 Kakao Brain
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import os
import pytest
import torch
from fairscale.nn.model_parallel import destroy_model_parallel
@pytest.fixture(autouse=True)
def manual_seed_zero():
torch.manual_seed(0)
def cuda_sleep_impl(seconds, cycles_per_ms):
torch.cuda._sleep(int(seconds * cycles_per_ms * 1000))
@pytest.fixture(scope="session")
def cuda_sleep():
# Warm-up CUDA.
torch.empty(1, device="cuda")
# From test/test_cuda.py in PyTorch.
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start.record()
torch.cuda._sleep(1000000)
end.record()
end.synchronize()
cycles_per_ms = 1000000 / start.elapsed_time(end)
return functools.partial(cuda_sleep_impl, cycles_per_ms=cycles_per_ms)
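# Usage note (assumption, not from this diff): a test that takes the cuda_sleep
# fixture can call cuda_sleep(0.5) to enqueue roughly half a second of busy-wait
# on the current CUDA stream, giving the pipe tests a predictable GPU-side delay.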
def pytest_report_header():
return f"torch: {torch.__version__}"
def pytest_runtest_setup(item):
print(f"setup mpi function called")
def pytest_runtest_teardown(item):
if "OMPI_COMM_WORLD_RANK" in os.environ:
destroy_model_parallel()
if torch.distributed.is_initialized():
torch.distributed.destroy_process_group()
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
# Copyright 2019 Kakao Brain
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
# Copyright 2019 Kakao Brain
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import pytest
import torch
from torch import nn
from fairscale.nn.pipe import Pipe
from fairscale.nn.pipe.skip import pop, skippable, stash
from fairscale.nn.pipe.skip.portal import PortalBlue, PortalCopy, PortalOrange
from tests.nn.model_parallel.commons import get_worker_map, torch_spawn
@torch_spawn([3])
@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
@pytest.mark.parametrize("balance", [[3], [1, 2], [2, 1], [1, 1, 1]], ids=["3", "1:2", "2:1", "1:1:1"])
@pytest.mark.parametrize("checkpoint", ["never", "always", "except_last"])
@pytest.mark.skipif("OMPI_COMM_WORLD_RANK" in os.environ, reason="broken on mpi")
def x1to3(balance, checkpoint):
torch.manual_seed(0)
@skippable(stash=["1to3"])
class Layer1(nn.Module):
def __init__(self):
super().__init__()
self.conv = nn.Conv2d(3, 3, 1)
def forward(self, input):
yield stash("1to3", input)
output = self.conv(input)
return output
class Layer2(nn.Module):
def __init__(self):
super().__init__()
self.conv = nn.Conv2d(3, 3, 1)
def forward(self, input):
output = self.conv(input)
return output
@skippable(pop=["1to3"])
class Layer3(nn.Module):
def __init__(self):
super().__init__()
self.conv = nn.Conv2d(3, 3, 1)
def forward(self, input):
skip_1to3 = yield pop("1to3")
output = self.conv(input) + skip_1to3
return output
model = nn.Sequential(Layer1(), Layer2(), Layer3())
model = Pipe(
model,
balance,
chunks=3,
checkpoint=checkpoint,
input_device=torch.cuda.current_device(),
style=Pipe.MultiProcess,
worker_map=get_worker_map(),
pipelined_backward=False,
).cuda()
input = torch.rand(30, 3, 224, 224, requires_grad=True).cuda()
input.retain_grad()
output = model(input)
if model.group.rank() == len(balance) - 1:
loss = output.mean()
loss.backward()
elif model.group.rank() < len(balance) - 1:
model.back_helper(output)
if model.group.rank() == len(balance) - 1:
# TODO(tom) the single-process test uses 2e-1 but for some reason
# multi-process is more noisy; need to investigate why
assert torch.allclose(output.norm(), torch.tensor(1039.0).cuda(), atol=4e-1)
if model.group.rank() == 0:
assert torch.allclose(input.grad.norm(), torch.tensor(0.0004533053).cuda())
torch.distributed.barrier()
@torch_spawn([2])
@pytest.mark.skipif("OMPI_COMM_WORLD_RANK" in os.environ, reason="broken on mpi")
@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
def none_skip():
@skippable(stash=["none"])
class Stash(nn.Module):
def forward(self, input):
yield stash("none", None)
return input
@skippable(pop=["none"])
class Pop(nn.Module):
def forward(self, input):
none = yield pop("none")
assert none is None
return input
model = nn.Sequential(Stash(), Pop())
model = Pipe(
model,
[1, 1],
style=Pipe.MultiProcess,
worker_map=get_worker_map(),
input_device=torch.cuda.current_device(),
chunks=5,
).cuda()
input = torch.rand(10, requires_grad=True).cuda()
input.retain_grad()
output = model(input)
def assert_grad_fn_is_not_portal(grad_fn, visited=set()):
if grad_fn in visited or grad_fn is None:
return
assert not isinstance(grad_fn, PortalBlue._backward_cls)
assert not isinstance(grad_fn, PortalCopy._backward_cls)
assert not isinstance(grad_fn, PortalOrange._backward_cls)
visited.add(grad_fn)
for next_grad_fn, _ in grad_fn.next_functions:
assert_grad_fn_is_not_portal(next_grad_fn, visited)
if model.group.rank() == 1:
assert_grad_fn_is_not_portal(output.grad_fn)
output.sum().backward()
else:
model.back_helper(output)
assert input.grad.mean().item() == 1
@torch_spawn([2])
def lazy_skippable_error():
"""Using skippable layers in combination with lazy construction is currently
not supported, check that it raises an Exception"""
@skippable(stash=["1to3"])
class Layer1(nn.Linear):
pass
@skippable(pop=["1to3"])
class Layer3(nn.Linear):
pass
model = [lambda: Layer1(10, 10), lambda: nn.Linear(10, 10), lambda: Layer3(10, 10)]
with pytest.raises(ValueError, match="Can't use Skippable layers with multi-process pipe and lazy construction"):
Pipe(
model, [2, 1], style=Pipe.MultiProcess, worker_map=get_worker_map(),
)
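# Hedged sketch, based on the lazy-construction support described in this commit:
# without skippable layers a list of module factories is accepted, so each process
# only materializes the layers of its own stage. The layer sizes and balance are
# assumptions for illustration.
#
#   lazy_model = [lambda: nn.Linear(10, 10), lambda: nn.Linear(10, 10)]
#   Pipe(lazy_model, [1, 1], style=Pipe.MultiProcess, worker_map=get_worker_map())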
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
# Copyright 2019 Kakao Brain
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import pytest
import torch
from torch import nn
from fairscale.nn.pipe import Pipe, is_checkpointing, is_recomputing
from fairscale.nn.pipe.skip import pop, skippable, stash
from fairscale.nn.pipe.skip.tracker import current_skip_tracker
from tests.nn.model_parallel.commons import get_worker_map, torch_spawn
@skippable(stash=["skip"])
class Stash(nn.Module):
def forward(self, input):
yield stash("skip", input)
return input
@skippable(pop=["skip"])
class Pop(nn.Module):
def forward(self, input):
skip = yield pop("skip")
return input + skip
@torch_spawn([2])
@pytest.mark.parametrize("train", [True, False], ids=["train", "eval"])
@pytest.mark.parametrize("checkpoint", ["always", "except_last", "never"])
@pytest.mark.skipif("OMPI_COMM_WORLD_RANK" in os.environ, reason="broken on mpi")
@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
def delete_portal_tensor(train, checkpoint):
# Without checkpointing:
# +- Stash --+ +--- Pop ----+ - - - layers
# | 2,blue,1 |--| 1,orange,0 | - - - tensor_life and portal function
# +----------+ +------------+
#
# With checkpointing:
# +- Stash --+ +--- Pop ----+ +--- Pop'----+ +- Stash'--+
# | 3,blue,2 |--| 2,orange,1 |--| 1,orange,0 |--| 1,blue,0 |
# +----------+ +------------+ +------------+ +----------+
def portal_tensor_life_is(tensor_life, skip_tracker=None):
if skip_tracker is None:
skip_tracker = current_skip_tracker()
# Get the current portal.
portal = list(skip_tracker.portals.values())[0]
if tensor_life == 0:
return portal.tensor_life == 0 and portal.tensor is None
else:
return portal.tensor_life == tensor_life and portal.tensor is not None
# Check the portal tensor after 'Stash'.
stash_ = Stash()
@stash_.register_forward_hook
def check_portal_tensor_after_stash(*_):
if is_checkpointing():
assert portal_tensor_life_is(2)
elif is_recomputing():
assert portal_tensor_life_is(0)
else:
assert portal_tensor_life_is(1)
pop_ = Pop()
@pop_.register_forward_hook
def check_portal_tensor_after_pop(*_):
if is_checkpointing():
assert portal_tensor_life_is(1)
elif is_recomputing():
assert portal_tensor_life_is(0)
else:
assert portal_tensor_life_is(0)
class NoPortalTensorAtBackward(nn.Module):
class F(torch.autograd.Function):
@staticmethod
def forward(ctx, input):
ctx.skip_tracker = current_skip_tracker()
return input.detach()
@staticmethod
def backward(ctx, grad):
assert portal_tensor_life_is(0, skip_tracker=ctx.skip_tracker)
return grad
def forward(self, input):
return self.F.apply(input)
model = nn.Sequential(NoPortalTensorAtBackward(), stash_, pop_)
model = Pipe(
model, balance=[2, 1], style=Pipe.MultiProcess, worker_map=get_worker_map(), chunks=2, checkpoint=checkpoint,
)
input = torch.rand(10, requires_grad=True)
if train:
model.train()
output = model(input)
if model.group.rank() == 1:
output.norm().backward()
else:
model.back_helper(output)
else:
model.eval()
with torch.no_grad():
model(input)
torch.distributed.barrier()