Unverified Commit b0624680 authored by tripleMu, committed by GitHub

Add type hints for mmcv/parallel (#2031)



* Add typehints

* Fix

* Fix

* Update mmcv/parallel/distributed_deprecated.py
Co-authored-by: Mashiro <57566630+HAOCHENYE@users.noreply.github.com>

* Fix

* add type hints to scatter

* fix ScatterInputs

* Update mmcv/parallel/_functions.py
Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>

* Fix

* refine type hints

* minor fix
Co-authored-by: Mashiro <57566630+HAOCHENYE@users.noreply.github.com>
Co-authored-by: HAOCHENYE <21724054@zju.edu.cn>
Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
Co-authored-by: zhouzaida <zhouzaida@163.com>
parent 9110df94
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional, Union
import torch
from torch import Tensor
from torch.nn.parallel._functions import _get_stream
def scatter(input, devices, streams=None):
def scatter(input: Union[List, Tensor],
devices: List,
streams: Optional[List] = None) -> Union[List, Tensor]:
"""Scatters tensor across multiple GPUs."""
if streams is None:
streams = [None] * len(devices)
@@ -15,7 +20,7 @@ def scatter(input, devices, streams=None):
[streams[i // chunk_size]]) for i in range(len(input))
]
return outputs
elif isinstance(input, torch.Tensor):
elif isinstance(input, Tensor):
output = input.contiguous()
# TODO: copy to a pinned buffer first (if copying from CPU)
stream = streams[0] if output.numel() > 0 else None
@@ -28,14 +33,15 @@ def scatter(input, devices, streams=None):
raise Exception(f'Unknown type {type(input)}.')
def synchronize_stream(output, devices, streams):
def synchronize_stream(output: Union[List, Tensor], devices: List,
streams: List) -> None:
if isinstance(output, list):
chunk_size = len(output) // len(devices)
for i in range(len(devices)):
for j in range(chunk_size):
synchronize_stream(output[i * chunk_size + j], [devices[i]],
[streams[i]])
elif isinstance(output, torch.Tensor):
elif isinstance(output, Tensor):
if output.numel() != 0:
with torch.cuda.device(devices[0]):
main_stream = torch.cuda.current_stream()
@@ -45,14 +51,14 @@ def synchronize_stream(output, devices, streams):
raise Exception(f'Unknown type {type(output)}.')
def get_input_device(input):
def get_input_device(input: Union[List, Tensor]) -> int:
if isinstance(input, list):
for item in input:
input_device = get_input_device(item)
if input_device != -1:
return input_device
return -1
elif isinstance(input, torch.Tensor):
elif isinstance(input, Tensor):
return input.get_device() if input.is_cuda else -1
else:
raise Exception(f'Unknown type {type(input)}.')
@@ -61,7 +67,7 @@ def get_input_device(input):
class Scatter:
@staticmethod
def forward(target_gpus, input):
def forward(target_gpus: List[int], input: Union[List, Tensor]) -> tuple:
input_device = get_input_device(input)
streams = None
if input_device == -1 and target_gpus != [-1]:
......
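The typed `get_input_device` above walks arbitrarily nested lists and reports the CUDA device index of the first GPU tensor it finds, or -1 when everything is on the CPU; `Scatter.forward` uses that to decide whether background copy streams are needed. A small standalone sketch of the same recursion (helper renamed to `input_device` to keep it clearly illustrative):

```python
# Standalone sketch of the Union[List, Tensor] -> int contract formalised above.
# Only depends on torch; runs on a CPU-only machine.
from typing import List, Union

import torch
from torch import Tensor


def input_device(input: Union[List, Tensor]) -> int:
    """Return the device index of the first CUDA tensor found, else -1."""
    if isinstance(input, list):
        for item in input:
            device = input_device(item)
            if device != -1:
                return device
        return -1
    elif isinstance(input, Tensor):
        return input.get_device() if input.is_cuda else -1
    else:
        raise TypeError(f'Unknown type {type(input)}.')


print(input_device([torch.zeros(2), [torch.ones(3)]]))  # -1 on a CPU-only run
```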
@@ -8,7 +8,7 @@ from torch.utils.data.dataloader import default_collate
from .data_container import DataContainer
def collate(batch, samples_per_gpu=1):
def collate(batch: Sequence, samples_per_gpu: int = 1):
"""Puts each data field into a tensor/DataContainer with outer dimension
batch size.
......
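`collate` accepts a `Sequence` of samples and groups them `samples_per_gpu` at a time, leaving `DataContainer` fields out of the default stacking path. A hedged usage sketch, assuming mmcv is installed:

```python
# Hedged sketch: plain tensor fields go through default_collate and are stacked,
# while cpu_only DataContainer fields stay wrapped and grouped per GPU.
import torch
from mmcv.parallel import DataContainer, collate

batch = [
    dict(img=torch.rand(3, 4, 4), meta=DataContainer(torch.tensor([i]), cpu_only=True))
    for i in range(4)
]
result = collate(batch, samples_per_gpu=2)
print(result['img'].shape)            # torch.Size([4, 3, 4, 4])
print(type(result['meta']).__name__)  # DataContainer, holding 2 groups of 2 samples
```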
# Copyright (c) OpenMMLab. All rights reserved.
import functools
from typing import Callable, Type, Union
import numpy as np
import torch
def assert_tensor_type(func):
def assert_tensor_type(func: Callable) -> Callable:
@functools.wraps(func)
def wrapper(*args, **kwargs):
@@ -35,11 +37,11 @@ class DataContainer:
"""
def __init__(self,
data,
stack=False,
padding_value=0,
cpu_only=False,
pad_dims=2):
data: Union[torch.Tensor, np.ndarray],
stack: bool = False,
padding_value: int = 0,
cpu_only: bool = False,
pad_dims: int = 2):
self._data = data
self._cpu_only = cpu_only
self._stack = stack
@@ -47,43 +49,43 @@ class DataContainer:
assert pad_dims in [None, 1, 2, 3]
self._pad_dims = pad_dims
def __repr__(self):
def __repr__(self) -> str:
return f'{self.__class__.__name__}({repr(self.data)})'
def __len__(self):
def __len__(self) -> int:
return len(self._data)
@property
def data(self):
def data(self) -> Union[torch.Tensor, np.ndarray]:
return self._data
@property
def datatype(self):
def datatype(self) -> Union[Type, str]:
if isinstance(self.data, torch.Tensor):
return self.data.type()
else:
return type(self.data)
@property
def cpu_only(self):
def cpu_only(self) -> bool:
return self._cpu_only
@property
def stack(self):
def stack(self) -> bool:
return self._stack
@property
def padding_value(self):
def padding_value(self) -> int:
return self._padding_value
@property
def pad_dims(self):
def pad_dims(self) -> int:
return self._pad_dims
@assert_tensor_type
def size(self, *args, **kwargs):
def size(self, *args, **kwargs) -> torch.Size:
return self.data.size(*args, **kwargs)
@assert_tensor_type
def dim(self):
def dim(self) -> int:
return self.data.dim()
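The newly typed properties are thin read-only accessors over the wrapped data, and `size()`/`dim()` are guarded by `assert_tensor_type` so they only work for tensor payloads. A short hedged sketch, assuming mmcv is installed:

```python
# Hedged sketch of the typed DataContainer accessors.
import torch
from mmcv.parallel import DataContainer

dc = DataContainer(torch.zeros(2, 3), stack=True)
print(dc.datatype)                                           # 'torch.FloatTensor'
print(dc.stack, dc.cpu_only, dc.padding_value, dc.pad_dims)  # True False 0 2
print(dc.size(), dc.dim())                                   # torch.Size([2, 3]) 2
```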
# Copyright (c) OpenMMLab. All rights reserved.
from itertools import chain
from typing import List, Tuple
from torch.nn.parallel import DataParallel
from .scatter_gather import scatter_kwargs
from .scatter_gather import ScatterInputs, scatter_kwargs
class MMDataParallel(DataParallel):
@@ -31,7 +32,7 @@ class MMDataParallel(DataParallel):
dim (int): Dimension used to scatter the data. Defaults to 0.
"""
def __init__(self, *args, dim=0, **kwargs):
def __init__(self, *args, dim: int = 0, **kwargs):
super().__init__(*args, dim=dim, **kwargs)
self.dim = dim
@@ -49,7 +50,8 @@ class MMDataParallel(DataParallel):
else:
return super().forward(*inputs, **kwargs)
def scatter(self, inputs, kwargs, device_ids):
def scatter(self, inputs: ScatterInputs, kwargs: ScatterInputs,
device_ids: List[int]) -> Tuple[tuple, tuple]:
return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)
def train_step(self, *inputs, **kwargs):
......
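`MMDataParallel.scatter` now advertises the same `ScatterInputs -> Tuple[tuple, tuple]` contract as `scatter_kwargs`, so DataContainer-aware scattering happens before the wrapped module's `forward()`. A hedged single-GPU usage sketch (assumes mmcv is installed; the GPU branch needs one visible CUDA device):

```python
# Hedged sketch of wrapping a plain module with MMDataParallel.
import torch
import torch.nn as nn
from mmcv.parallel import MMDataParallel

model = nn.Linear(4, 2)
if torch.cuda.is_available():
    wrapped = MMDataParallel(model.cuda(), device_ids=[0])
    out = wrapped(torch.rand(3, 4))
    print(out.shape)  # torch.Size([3, 2])
else:
    print('skipped: no CUDA device visible')
```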
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Tuple
import torch
from torch.nn.parallel.distributed import (DistributedDataParallel,
_find_tensors)
from mmcv import print_log
from mmcv.utils import TORCH_VERSION, digit_version
from .scatter_gather import scatter_kwargs
from .scatter_gather import ScatterInputs, scatter_kwargs
class MMDistributedDataParallel(DistributedDataParallel):
@@ -18,12 +20,14 @@ class MMDistributedDataParallel(DistributedDataParallel):
- It implements two APIs ``train_step()`` and ``val_step()``.
"""
def to_kwargs(self, inputs, kwargs, device_id):
def to_kwargs(self, inputs: ScatterInputs, kwargs: ScatterInputs,
device_id: int) -> Tuple[tuple, tuple]:
# Use `self.to_kwargs` instead of `self.scatter` in pytorch1.8
# to move all tensors to device_id
return scatter_kwargs(inputs, kwargs, [device_id], dim=self.dim)
def scatter(self, inputs, kwargs, device_ids):
def scatter(self, inputs: ScatterInputs, kwargs: ScatterInputs,
device_ids: List[int]) -> Tuple[tuple, tuple]:
return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)
def train_step(self, *inputs, **kwargs):
......
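`MMDistributedDataParallel` keeps PyTorch's DDP behaviour and adds the `train_step()`/`val_step()` entry points used by mmcv runners; `to_kwargs` and `scatter` now share the `ScatterInputs` alias. A hedged sketch of the usual per-process wrap, intended to be started by a launcher such as `torchrun` (guarded so it does nothing when run standalone):

```python
# Hedged sketch: one process per GPU, launched e.g. via `torchrun --nproc_per_node=N`.
import os

import torch
import torch.distributed as dist
import torch.nn as nn
from mmcv.parallel import MMDistributedDataParallel


def build_ddp_model() -> nn.Module:
    local_rank = int(os.environ.get('LOCAL_RANK', 0))
    dist.init_process_group(backend='nccl')
    torch.cuda.set_device(local_rank)
    model = nn.Linear(4, 2).cuda(local_rank)
    # an mmcv runner would later call ddp_model.train_step()/val_step()
    return MMDistributedDataParallel(
        model, device_ids=[local_rank], broadcast_buffers=False)


if __name__ == '__main__' and 'RANK' in os.environ:
    ddp_model = build_ddp_model()
```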
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Sequence, Tuple
import torch
import torch.distributed as dist
import torch.nn as nn
@@ -7,17 +9,17 @@ from torch._utils import (_flatten_dense_tensors, _take_tensors,
from mmcv.utils import TORCH_VERSION, digit_version
from .registry import MODULE_WRAPPERS
from .scatter_gather import scatter_kwargs
from .scatter_gather import ScatterInputs, scatter_kwargs
@MODULE_WRAPPERS.register_module()
class MMDistributedDataParallel(nn.Module):
def __init__(self,
module,
dim=0,
broadcast_buffers=True,
bucket_cap_mb=25):
module: nn.Module,
dim: int = 0,
broadcast_buffers: bool = True,
bucket_cap_mb: int = 25):
super().__init__()
self.module = module
self.dim = dim
@@ -26,7 +28,8 @@ class MMDistributedDataParallel(nn.Module):
self.broadcast_bucket_size = bucket_cap_mb * 1024 * 1024
self._sync_params()
def _dist_broadcast_coalesced(self, tensors, buffer_size):
def _dist_broadcast_coalesced(self, tensors: Sequence[torch.Tensor],
buffer_size: int) -> None:
for tensors in _take_tensors(tensors, buffer_size):
flat_tensors = _flatten_dense_tensors(tensors)
dist.broadcast(flat_tensors, 0)
@@ -34,7 +37,7 @@ class MMDistributedDataParallel(nn.Module):
tensors, _unflatten_dense_tensors(flat_tensors, tensors)):
tensor.copy_(synced)
def _sync_params(self):
def _sync_params(self) -> None:
module_states = list(self.module.state_dict().values())
if len(module_states) > 0:
self._dist_broadcast_coalesced(module_states,
@@ -49,7 +52,8 @@ class MMDistributedDataParallel(nn.Module):
self._dist_broadcast_coalesced(buffers,
self.broadcast_bucket_size)
def scatter(self, inputs, kwargs, device_ids):
def scatter(self, inputs: ScatterInputs, kwargs: ScatterInputs,
device_ids: List[int]) -> Tuple[tuple, tuple]:
return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)
def forward(self, *inputs, **kwargs):
......
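The deprecated wrapper synchronises parameters and buffers itself via `_dist_broadcast_coalesced`: tensors are packed into buckets of `bucket_cap_mb`, flattened, broadcast from rank 0, and copied back. A standalone single-process sketch of that flatten/broadcast/unflatten round trip (the actual `dist.broadcast` is left as a comment so the snippet runs without a process group):

```python
# Sketch of the coalescing pattern used by _dist_broadcast_coalesced above.
import torch
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

tensors = [torch.randn(2, 2), torch.randn(3)]
flat = _flatten_dense_tensors(tensors)  # one contiguous 1-D buffer
# dist.broadcast(flat, 0) would run here in the multi-process version
for tensor, synced in zip(tensors, _unflatten_dense_tensors(flat, tensors)):
    tensor.copy_(synced)                # write the synced values back in place
print(flat.numel())                     # 7  (4 + 3 elements coalesced)
```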
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from typing import List, Tuple, Union
from torch import Tensor
from torch.nn.parallel._functions import Scatter as OrigScatter
from ._functions import Scatter
from .data_container import DataContainer
ScatterInputs = Union[Tensor, DataContainer, tuple, list, dict]
def scatter(inputs, target_gpus, dim=0):
def scatter(inputs: ScatterInputs,
target_gpus: List[int],
dim: int = 0) -> list:
"""Scatter inputs to target gpus.
The only difference from original :func:`scatter` is to add support for
@@ -14,7 +20,7 @@ def scatter(inputs, target_gpus, dim=0):
"""
def scatter_map(obj):
if isinstance(obj, torch.Tensor):
if isinstance(obj, Tensor):
if target_gpus != [-1]:
return OrigScatter.apply(target_gpus, None, dim, obj)
else:
@@ -33,7 +39,7 @@ def scatter(inputs, target_gpus, dim=0):
if isinstance(obj, dict) and len(obj) > 0:
out = list(map(type(obj), zip(*map(scatter_map, obj.items()))))
return out
return [obj for targets in target_gpus]
return [obj for _ in target_gpus]
# After scatter_map is called, a scatter_map cell will exist. This cell
# has a reference to the actual function scatter_map, which has references
@@ -43,17 +49,22 @@ def scatter(inputs, target_gpus, dim=0):
try:
return scatter_map(inputs)
finally:
scatter_map = None
scatter_map = None # type: ignore
def scatter_kwargs(inputs, kwargs, target_gpus, dim=0):
def scatter_kwargs(inputs: ScatterInputs,
kwargs: ScatterInputs,
target_gpus: List[int],
dim: int = 0) -> Tuple[tuple, tuple]:
"""Scatter with support for kwargs dictionary."""
inputs = scatter(inputs, target_gpus, dim) if inputs else []
kwargs = scatter(kwargs, target_gpus, dim) if kwargs else []
if len(inputs) < len(kwargs):
inputs.extend([() for _ in range(len(kwargs) - len(inputs))])
length = len(kwargs) - len(inputs)
inputs.extend([() for _ in range(length)]) # type: ignore
elif len(kwargs) < len(inputs):
kwargs.extend([{} for _ in range(len(inputs) - len(kwargs))])
length = len(inputs) - len(kwargs)
kwargs.extend([{} for _ in range(length)]) # type: ignore
inputs = tuple(inputs)
kwargs = tuple(kwargs)
return inputs, kwargs
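`ScatterInputs` names the union of containers that `scatter`/`scatter_kwargs` already understood (tensors, DataContainers, tuples, lists, dicts). A hedged sketch of the CPU path (`target_gpus == [-1]`), where a `cpu_only` DataContainer is simply unwrapped and `scatter_kwargs` pads inputs and kwargs to the same length; assumes mmcv is installed:

```python
# Hedged sketch of scatter_kwargs on the CPU path.
from mmcv.parallel import DataContainer, scatter_kwargs

img_metas = DataContainer([dict(filename='a.jpg')], cpu_only=True)
inputs, kwargs = scatter_kwargs((), dict(img_metas=img_metas), target_gpus=[-1])
print(inputs)  # ((),)
print(kwargs)  # ({'img_metas': {'filename': 'a.jpg'}},)
```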
# Copyright (c) OpenMMLab. All rights reserved.
from torch import nn
from .registry import MODULE_WRAPPERS
def is_module_wrapper(module):
def is_module_wrapper(module: nn.Module) -> bool:
"""Check if a module is a module wrapper.
The following 3 modules in MMCV (and their subclasses) are regarded as
......
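`is_module_wrapper` is what lets runner and checkpoint code unwrap any registered wrapper (DataParallel, DistributedDataParallel and the MM variants) uniformly. A hedged sketch, assuming mmcv is installed:

```python
# Hedged sketch of the typical unwrap pattern built on is_module_wrapper.
import torch.nn as nn
from mmcv.parallel import MMDataParallel, is_module_wrapper

model = nn.Linear(4, 2)
wrapped = MMDataParallel(model)
print(is_module_wrapper(model), is_module_wrapper(wrapped))  # False True

# e.g. when saving a checkpoint, reach through the wrapper:
target = wrapped.module if is_module_wrapper(wrapped) else wrapped
assert target is model
```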