Unverified Commit b0624680 authored by tripleMu, committed by GitHub

Add type hints for mmcv/parallel (#2031)



* Add typehints

* Fix

* Fix

* Update mmcv/parallel/distributed_deprecated.py
Co-authored-by: Mashiro <57566630+HAOCHENYE@users.noreply.github.com>

* Fix

* add type hints to scatter

add type hints to scatter

* fix ScatterInputs

* Update mmcv/parallel/_functions.py
Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>

* Fix

* refine type hints

* minor fix
Co-authored-by: Mashiro <57566630+HAOCHENYE@users.noreply.github.com>
Co-authored-by: HAOCHENYE <21724054@zju.edu.cn>
Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
Co-authored-by: zhouzaida <zhouzaida@163.com>
parent 9110df94
 # Copyright (c) OpenMMLab. All rights reserved.
+from typing import List, Optional, Union
+
 import torch
+from torch import Tensor
 from torch.nn.parallel._functions import _get_stream
 
 
-def scatter(input, devices, streams=None):
+def scatter(input: Union[List, Tensor],
+            devices: List,
+            streams: Optional[List] = None) -> Union[List, Tensor]:
     """Scatters tensor across multiple GPUs."""
     if streams is None:
         streams = [None] * len(devices)
@@ -15,7 +20,7 @@ def scatter(input, devices, streams=None):
                     [streams[i // chunk_size]]) for i in range(len(input))
         ]
         return outputs
-    elif isinstance(input, torch.Tensor):
+    elif isinstance(input, Tensor):
         output = input.contiguous()
         # TODO: copy to a pinned buffer first (if copying from CPU)
         stream = streams[0] if output.numel() > 0 else None
@@ -28,14 +33,15 @@ def scatter(input, devices, streams=None):
         raise Exception(f'Unknown type {type(input)}.')
 
 
-def synchronize_stream(output, devices, streams):
+def synchronize_stream(output: Union[List, Tensor], devices: List,
+                       streams: List) -> None:
     if isinstance(output, list):
         chunk_size = len(output) // len(devices)
         for i in range(len(devices)):
             for j in range(chunk_size):
                 synchronize_stream(output[i * chunk_size + j], [devices[i]],
                                    [streams[i]])
-    elif isinstance(output, torch.Tensor):
+    elif isinstance(output, Tensor):
         if output.numel() != 0:
             with torch.cuda.device(devices[0]):
                 main_stream = torch.cuda.current_stream()
@@ -45,14 +51,14 @@ def synchronize_stream(output, devices, streams):
         raise Exception(f'Unknown type {type(output)}.')
 
 
-def get_input_device(input):
+def get_input_device(input: Union[List, Tensor]) -> int:
     if isinstance(input, list):
         for item in input:
             input_device = get_input_device(item)
             if input_device != -1:
                 return input_device
         return -1
-    elif isinstance(input, torch.Tensor):
+    elif isinstance(input, Tensor):
         return input.get_device() if input.is_cuda else -1
     else:
         raise Exception(f'Unknown type {type(input)}.')
@@ -61,7 +67,7 @@ def get_input_device(input):
 
 class Scatter:
 
     @staticmethod
-    def forward(target_gpus, input):
+    def forward(target_gpus: List[int], input: Union[List, Tensor]) -> tuple:
         input_device = get_input_device(input)
         streams = None
         if input_device == -1 and target_gpus != [-1]:
...
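For context, a minimal sketch of how the annotated helpers in mmcv/parallel/_functions.py behave on a CPU-only machine: with target_gpus=[-1] nothing is copied to a GPU, so the typed signatures can be exercised without CUDA. This assumes an mmcv build containing this change; the tensor shape is arbitrary.

import torch

from mmcv.parallel._functions import Scatter, get_input_device

x = torch.randn(2, 3)
print(get_input_device(x))      # -1 means the tensor lives on CPU
out = Scatter.forward([-1], x)  # CPU path: no stream is created, no device copy
print(type(out), out[0].shape)  # <class 'tuple'> torch.Size([2, 3])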
@@ -8,7 +8,7 @@ from torch.utils.data.dataloader import default_collate
 
 from .data_container import DataContainer
 
 
-def collate(batch, samples_per_gpu=1):
+def collate(batch: Sequence, samples_per_gpu: int = 1):
     """Puts each data field into a tensor/DataContainer with outer dimension
     batch size.
...
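A small usage sketch of the annotated collate (assuming mmcv with this change is installed; the img/meta field names are invented for illustration): plain tensor fields go through default_collate, while cpu_only DataContainer fields are grouped into per-GPU lists.

import torch

from mmcv.parallel import DataContainer, collate

batch = [
    dict(
        img=torch.ones(3, 4, 4),
        meta=DataContainer(dict(idx=i), cpu_only=True)) for i in range(4)
]
result = collate(batch, samples_per_gpu=2)
print(result['img'].shape)  # torch.Size([4, 3, 4, 4]), stacked by default_collate
print(result['meta'].data)  # [[{'idx': 0}, {'idx': 1}], [{'idx': 2}, {'idx': 3}]]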
 # Copyright (c) OpenMMLab. All rights reserved.
 import functools
+from typing import Callable, Type, Union
 
+import numpy as np
 import torch
 
 
-def assert_tensor_type(func):
+def assert_tensor_type(func: Callable) -> Callable:
 
     @functools.wraps(func)
     def wrapper(*args, **kwargs):
@@ -35,11 +37,11 @@ class DataContainer:
     """
 
     def __init__(self,
-                 data,
-                 stack=False,
-                 padding_value=0,
-                 cpu_only=False,
-                 pad_dims=2):
+                 data: Union[torch.Tensor, np.ndarray],
+                 stack: bool = False,
+                 padding_value: int = 0,
+                 cpu_only: bool = False,
+                 pad_dims: int = 2):
         self._data = data
         self._cpu_only = cpu_only
         self._stack = stack
@@ -47,43 +49,43 @@ class DataContainer:
         assert pad_dims in [None, 1, 2, 3]
         self._pad_dims = pad_dims
 
-    def __repr__(self):
+    def __repr__(self) -> str:
         return f'{self.__class__.__name__}({repr(self.data)})'
 
-    def __len__(self):
+    def __len__(self) -> int:
         return len(self._data)
 
     @property
-    def data(self):
+    def data(self) -> Union[torch.Tensor, np.ndarray]:
         return self._data
 
     @property
-    def datatype(self):
+    def datatype(self) -> Union[Type, str]:
         if isinstance(self.data, torch.Tensor):
             return self.data.type()
         else:
             return type(self.data)
 
     @property
-    def cpu_only(self):
+    def cpu_only(self) -> bool:
         return self._cpu_only
 
     @property
-    def stack(self):
+    def stack(self) -> bool:
         return self._stack
 
     @property
-    def padding_value(self):
+    def padding_value(self) -> int:
         return self._padding_value
 
     @property
-    def pad_dims(self):
+    def pad_dims(self) -> int:
         return self._pad_dims
 
     @assert_tensor_type
-    def size(self, *args, **kwargs):
+    def size(self, *args, **kwargs) -> torch.Size:
         return self.data.size(*args, **kwargs)
 
     @assert_tensor_type
-    def dim(self):
+    def dim(self) -> int:
         return self.data.dim()
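A brief sketch of the DataContainer attributes whose return types are annotated above (assuming mmcv is installed; shapes and values are arbitrary). assert_tensor_type still restricts size() and dim() to tensor payloads.

import numpy as np
import torch

from mmcv.parallel import DataContainer

img = DataContainer(torch.zeros(3, 32, 32), stack=True)
print(img.datatype, img.stack, img.pad_dims)  # torch.FloatTensor True 2
print(img.size(), img.dim())                  # torch.Size([3, 32, 32]) 3

meta = DataContainer(np.arange(3), cpu_only=True)
print(meta.datatype)  # <class 'numpy.ndarray'>
# meta.size() would raise, since assert_tensor_type only accepts torch.Tensor data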
 # Copyright (c) OpenMMLab. All rights reserved.
 from itertools import chain
+from typing import List, Tuple
 
 from torch.nn.parallel import DataParallel
 
-from .scatter_gather import scatter_kwargs
+from .scatter_gather import ScatterInputs, scatter_kwargs
 
 
 class MMDataParallel(DataParallel):
@@ -31,7 +32,7 @@ class MMDataParallel(DataParallel):
         dim (int): Dimension used to scatter the data. Defaults to 0.
     """
 
-    def __init__(self, *args, dim=0, **kwargs):
+    def __init__(self, *args, dim: int = 0, **kwargs):
         super().__init__(*args, dim=dim, **kwargs)
         self.dim = dim
@@ -49,7 +50,8 @@ class MMDataParallel(DataParallel):
         else:
             return super().forward(*inputs, **kwargs)
 
-    def scatter(self, inputs, kwargs, device_ids):
+    def scatter(self, inputs: ScatterInputs, kwargs: ScatterInputs,
+                device_ids: List[int]) -> Tuple[tuple, tuple]:
         return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)
 
     def train_step(self, *inputs, **kwargs):
...
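As a usage note, MMDataParallel also covers CPU inference: when no GPU is visible, device_ids is empty and forward() routes through the scatter() method typed above with device_ids=[-1]. A hedged sketch, assuming a CPU-only environment with mmcv installed:

import torch
import torch.nn as nn

from mmcv.parallel import MMDataParallel

model = MMDataParallel(nn.Linear(4, 2))  # no visible GPUs -> device_ids == []
out = model(torch.randn(3, 4))           # CPU branch calls scatter(..., [-1])
print(out.shape)                         # torch.Size([3, 2])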
 # Copyright (c) OpenMMLab. All rights reserved.
+from typing import List, Tuple
+
 import torch
 from torch.nn.parallel.distributed import (DistributedDataParallel,
                                            _find_tensors)
 
 from mmcv import print_log
 from mmcv.utils import TORCH_VERSION, digit_version
-from .scatter_gather import scatter_kwargs
+from .scatter_gather import ScatterInputs, scatter_kwargs
 
 
 class MMDistributedDataParallel(DistributedDataParallel):
@@ -18,12 +20,14 @@ class MMDistributedDataParallel(DistributedDataParallel):
     - It implement two APIs ``train_step()`` and ``val_step()``.
     """
 
-    def to_kwargs(self, inputs, kwargs, device_id):
+    def to_kwargs(self, inputs: ScatterInputs, kwargs: ScatterInputs,
+                  device_id: int) -> Tuple[tuple, tuple]:
         # Use `self.to_kwargs` instead of `self.scatter` in pytorch1.8
         # to move all tensors to device_id
         return scatter_kwargs(inputs, kwargs, [device_id], dim=self.dim)
 
-    def scatter(self, inputs, kwargs, device_ids):
+    def scatter(self, inputs: ScatterInputs, kwargs: ScatterInputs,
+                device_ids: List[int]) -> Tuple[tuple, tuple]:
         return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)
 
     def train_step(self, *inputs, **kwargs):
...
 # Copyright (c) OpenMMLab. All rights reserved.
+from typing import List, Sequence, Tuple
+
 import torch
 import torch.distributed as dist
 import torch.nn as nn
@@ -7,17 +9,17 @@ from torch._utils import (_flatten_dense_tensors, _take_tensors,
 
 from mmcv.utils import TORCH_VERSION, digit_version
 from .registry import MODULE_WRAPPERS
-from .scatter_gather import scatter_kwargs
+from .scatter_gather import ScatterInputs, scatter_kwargs
 
 
 @MODULE_WRAPPERS.register_module()
 class MMDistributedDataParallel(nn.Module):
 
     def __init__(self,
-                 module,
-                 dim=0,
-                 broadcast_buffers=True,
-                 bucket_cap_mb=25):
+                 module: nn.Module,
+                 dim: int = 0,
+                 broadcast_buffers: bool = True,
+                 bucket_cap_mb: int = 25):
         super().__init__()
         self.module = module
         self.dim = dim
@@ -26,7 +28,8 @@ class MMDistributedDataParallel(nn.Module):
         self.broadcast_bucket_size = bucket_cap_mb * 1024 * 1024
         self._sync_params()
 
-    def _dist_broadcast_coalesced(self, tensors, buffer_size):
+    def _dist_broadcast_coalesced(self, tensors: Sequence[torch.Tensor],
+                                  buffer_size: int) -> None:
         for tensors in _take_tensors(tensors, buffer_size):
             flat_tensors = _flatten_dense_tensors(tensors)
             dist.broadcast(flat_tensors, 0)
@@ -34,7 +37,7 @@ class MMDistributedDataParallel(nn.Module):
                     tensors, _unflatten_dense_tensors(flat_tensors, tensors)):
                 tensor.copy_(synced)
 
-    def _sync_params(self):
+    def _sync_params(self) -> None:
         module_states = list(self.module.state_dict().values())
         if len(module_states) > 0:
             self._dist_broadcast_coalesced(module_states,
@@ -49,7 +52,8 @@ class MMDistributedDataParallel(nn.Module):
             self._dist_broadcast_coalesced(buffers,
                                            self.broadcast_bucket_size)
 
-    def scatter(self, inputs, kwargs, device_ids):
+    def scatter(self, inputs: ScatterInputs, kwargs: ScatterInputs,
+                device_ids: List[int]) -> Tuple[tuple, tuple]:
         return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)
 
     def forward(self, *inputs, **kwargs):
...
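The broadcast helper typed above coalesces parameters before sending them: tensors are flattened into one buffer, broadcast from rank 0, then unflattened and copied back in place. The collective call needs an initialized process group, but the flatten/unflatten round trip it relies on can be sketched locally with the same torch utilities (no mmcv or process group required):

import torch
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

tensors = [torch.randn(2, 2), torch.randn(3)]
flat = _flatten_dense_tensors(tensors)              # one contiguous buffer
restored = _unflatten_dense_tensors(flat, tensors)  # reshaped like the originals
for src, dst in zip(tensors, restored):
    assert torch.equal(src, dst)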
 # Copyright (c) OpenMMLab. All rights reserved.
-import torch
+from typing import List, Tuple, Union
+
+from torch import Tensor
 from torch.nn.parallel._functions import Scatter as OrigScatter
 
 from ._functions import Scatter
 from .data_container import DataContainer
 
+ScatterInputs = Union[Tensor, DataContainer, tuple, list, dict]
+
 
-def scatter(inputs, target_gpus, dim=0):
+def scatter(inputs: ScatterInputs,
+            target_gpus: List[int],
+            dim: int = 0) -> list:
     """Scatter inputs to target gpus.
 
     The only difference from original :func:`scatter` is to add support for
@@ -14,7 +20,7 @@ def scatter(inputs, target_gpus, dim=0):
     """
 
     def scatter_map(obj):
-        if isinstance(obj, torch.Tensor):
+        if isinstance(obj, Tensor):
             if target_gpus != [-1]:
                 return OrigScatter.apply(target_gpus, None, dim, obj)
             else:
@@ -33,7 +39,7 @@ def scatter(inputs, target_gpus, dim=0):
         if isinstance(obj, dict) and len(obj) > 0:
             out = list(map(type(obj), zip(*map(scatter_map, obj.items()))))
             return out
-        return [obj for targets in target_gpus]
+        return [obj for _ in target_gpus]
 
     # After scatter_map is called, a scatter_map cell will exist. This cell
     # has a reference to the actual function scatter_map, which has references
@@ -43,17 +49,22 @@ def scatter(inputs, target_gpus, dim=0):
     try:
         return scatter_map(inputs)
     finally:
-        scatter_map = None
+        scatter_map = None  # type: ignore
 
 
-def scatter_kwargs(inputs, kwargs, target_gpus, dim=0):
+def scatter_kwargs(inputs: ScatterInputs,
+                   kwargs: ScatterInputs,
+                   target_gpus: List[int],
+                   dim: int = 0) -> Tuple[tuple, tuple]:
     """Scatter with support for kwargs dictionary."""
     inputs = scatter(inputs, target_gpus, dim) if inputs else []
     kwargs = scatter(kwargs, target_gpus, dim) if kwargs else []
     if len(inputs) < len(kwargs):
-        inputs.extend([() for _ in range(len(kwargs) - len(inputs))])
+        length = len(kwargs) - len(inputs)
+        inputs.extend([() for _ in range(length)])  # type: ignore
     elif len(kwargs) < len(inputs):
-        kwargs.extend([{} for _ in range(len(inputs) - len(kwargs))])
+        length = len(inputs) - len(kwargs)
+        kwargs.extend([{} for _ in range(length)])  # type: ignore
     inputs = tuple(inputs)
     kwargs = tuple(kwargs)
     return inputs, kwargs
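The new ScatterInputs alias names every container type the scatter helpers accept. A minimal sketch with target_gpus=[-1] (CPU only), assuming mmcv with this change; ScatterInputs is imported from the submodule because it may not be re-exported by mmcv.parallel:

import torch

from mmcv.parallel.scatter_gather import ScatterInputs, scatter_kwargs

inputs: ScatterInputs = (torch.randn(2, 3),)
kwargs: ScatterInputs = dict(scale=2)
scattered_inputs, scattered_kwargs = scatter_kwargs(inputs, kwargs, [-1])
print(len(scattered_inputs), scattered_kwargs)  # 1 ({'scale': 2},)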
 # Copyright (c) OpenMMLab. All rights reserved.
+from torch import nn
+
 from .registry import MODULE_WRAPPERS
 
 
-def is_module_wrapper(module):
+def is_module_wrapper(module: nn.Module) -> bool:
     """Check if a module is a module wrapper.
 
     The following 3 modules in MMCV (and their subclasses) are regarded as
...
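Finally, a quick sketch of the annotated is_module_wrapper check (assuming mmcv is installed; MMDataParallel can wrap a CPU module, so no GPU is needed):

import torch.nn as nn

from mmcv.parallel import MMDataParallel, is_module_wrapper

model = nn.Linear(2, 2)
print(is_module_wrapper(model))                  # False, a bare nn.Module
print(is_module_wrapper(MMDataParallel(model)))  # True, a registered wrapper subclass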