Commit a7106c6b authored by yangwendi.vendor, committed by zhouzaida

[fix]: fix type hint in transforms

parent 59eaefeb
variables:
  PYTORCH_IMAGE: registry.sensetime.com/openmmlab/pytorch18-cu102-mmengine:dev2

stages:
  - linting
  - test
  - deploy

before_script:
  - . /root/scripts/set_envs.sh
  - echo $PATH
  - gcc --version
  - nvcc --version
  - ruby --version
  - python --version
  - pip --version
  - python -c "import torch; print(torch.__version__)"

linting:
  image: $PYTORCH_IMAGE
  stage: linting
  script:
    - pre-commit run --all-files

.test_template: &test_template_def
  stage: test
  script:
    - echo "Start building..."
    - MMCV_WITH_OPS=1 pip install -e .[all] -i https://pypi.tuna.tsinghua.edu.cn/simple/
    - python -c "import mmcv; print(mmcv.__version__)"
    - echo "Start testing..."
    - coverage run --branch --source mmcv -m pytest tests/
    - coverage report -m

test:pytorch1.8-cuda10:
  image: $PYTORCH_IMAGE
  <<: *test_template_def
 # Copyright (c) OpenMMLab. All rights reserved.
 from abc import ABCMeta, abstractmethod
-from typing import Dict, Optional
+from typing import Dict, List, Optional, Tuple, Union


 class BaseTransform(metaclass=ABCMeta):

-    def __call__(self, results: Dict) -> Optional[Dict]:
+    def __call__(self,
+                 results: Dict) -> Optional[Union[Dict, Tuple[List, List]]]:

         return self.transform(results)

     @abstractmethod
-    def transform(self, results: Dict) -> Optional[Dict]:
+    def transform(self,
+                  results: Dict) -> Optional[Union[Dict, Tuple[List, List]]]:
         """The transform function. All subclass of BaseTransform should
         override this method.
......
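The widened annotation above says a concrete transform may return either a results dict or a pair of lists (the latter is what test-time-augmentation style transforms produce). Below is a minimal sketch, not part of this commit, of a subclass that satisfies the new signature; the AddConstant class is invented for illustration, and the import assumes a build that exposes mmcv.transforms.BaseTransform.

# Illustrative sketch only; AddConstant is a made-up transform, and the import
# path assumes an MMCV build that exposes mmcv.transforms.BaseTransform.
from typing import Dict, List, Optional, Tuple, Union

from mmcv.transforms import BaseTransform


class AddConstant(BaseTransform):
    """Toy transform: add a fixed offset to ``results['value']``."""

    def __init__(self, offset: int = 1) -> None:
        self.offset = offset

    def transform(
            self, results: Dict
    ) -> Optional[Union[Dict, Tuple[List, List]]]:
        # Returning a plain dict is still valid under the widened annotation.
        results['value'] = results.get('value', 0) + self.offset
        return results


if __name__ == '__main__':
    print(AddConstant(offset=2)({'value': 3}))  # -> {'value': 5}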
 # Copyright (c) OpenMMLab. All rights reserved.
 import random
 import warnings
-from typing import Iterable, List, Optional, Sequence, Tuple, Union
+from typing import Dict, Iterable, List, Optional, Sequence, Tuple, Union

 import numpy as np
@@ -800,10 +800,12 @@ class MultiScaleFlipAug(BaseTransform):
         else:
             # if ``scales`` and ``scale_factor`` both be ``None``
             if scale_factor is None:
-                self.scales = [1.]
+                self.scales = [1.]  # type: ignore
+            elif isinstance(scale_factor, list):
+                self.scales = scale_factor  # type: ignore
             else:
-                self.scales = scale_factor if isinstance(
-                    scale_factor, list) else [scale_factor]
+                self.scales = [scale_factor]  # type: ignore
             self.scale_key = 'scale_factor'

         self.allow_flip = allow_flip
@@ -816,7 +818,7 @@ class MultiScaleFlipAug(BaseTransform):
         self.resize_cfg = resize_cfg.copy()
         self.flip_cfg = flip_cfg

-    def transform(self, results: dict) -> Tuple[List, List]:
+    def transform(self, results: dict) -> Dict:
         """Apply test time augment transforms on results.

         Args:
@@ -848,12 +850,12 @@ class MultiScaleFlipAug(BaseTransform):
                 results['flip_direction'] = None

             resize_flip = Compose(_resize_flip)
-            _results = results.copy()
-            _results = resize_flip(_results)
-            packed_results = self.transforms(_results)
+            _results = resize_flip(results.copy())
+            packed_results = self.transforms(_results)  # type: ignore

-            inputs.append(packed_results['inputs'])
-            data_samples.append(packed_results['data_sample'])
+            inputs.append(packed_results['inputs'])  # type: ignore
+            data_samples.append(
+                packed_results['data_sample'])  # type: ignore
         return dict(inputs=inputs, data_sample=data_samples)

     def __repr__(self) -> str:
@@ -1312,8 +1314,7 @@ class RandomResize(BaseTransform):
         if isinstance(self.scale, tuple):
             assert self.ratio_range is not None and len(self.ratio_range) == 2
-            scale: Tuple[int, int] = self._random_sample_ratio(
-                self.scale, self.ratio_range)
+            scale = self._random_sample_ratio(self.scale, self.ratio_range)
         elif mmcv.is_list_of(self.scale, tuple):
             scale = self._random_sample(self.scale)
         else:
......
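The rewritten branches in the first hunk above all normalize scale_factor into a list of scales. Purely as an illustration (the helper below is invented, not library code), the logic amounts to:

# Hypothetical helper mirroring the branch logic above; not part of MMCV.
from typing import List, Optional, Union


def normalize_scale_factor(
        scale_factor: Optional[Union[float, List[float]]]) -> List[float]:
    if scale_factor is None:
        # neither ``scales`` nor ``scale_factor`` was given
        return [1.]
    if isinstance(scale_factor, list):
        # already a list of factors
        return scale_factor
    # a single factor is wrapped into a one-element list
    return [scale_factor]


assert normalize_scale_factor(None) == [1.]
assert normalize_scale_factor(0.5) == [0.5]
assert normalize_scale_factor([0.5, 1.0]) == [0.5, 1.0]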
 # Copyright (c) OpenMMLab. All rights reserved.
-from collections.abc import Sequence
-from typing import Any, Callable, Dict, List, Optional, Union
+from typing import Any, Callable, Dict, List, Optional, Sequence, Union

 import numpy as np
@@ -25,7 +24,7 @@ IgnoreKey = object()

 # Import nullcontext if python>=3.7, otherwise use a simple alternative
 # implementation.
 try:
-    from contextlib import nullcontext
+    from contextlib import nullcontext  # type: ignore
 except ImportError:
     from contextlib import contextmanager
@@ -55,10 +54,10 @@ class Compose(BaseTransform):
         >>> ]
     """

-    def __init__(self, transforms: Union[Transform, List[Transform]]):
+    def __init__(self, transforms: Union[Transform, Sequence[Transform]]):
         super().__init__()

-        if not isinstance(transforms, list):
+        if not isinstance(transforms, Sequence):
             transforms = [transforms]

         self.transforms: List = []
         for transform in transforms:
@@ -85,7 +84,7 @@ class Compose(BaseTransform):
             dict or None: Transformed results.
         """
         for t in self.transforms:
-            results = t(results)
+            results = t(results)  # type: ignore
             if results is None:
                 return None
         return results
@@ -331,7 +330,7 @@ class KeyMapper(BaseTransform):

         # Apply remapping
         outputs = self._map_output(outputs, self.remapping)

-        results.update(outputs)
+        results.update(outputs)  # type: ignore
         return results
@@ -445,8 +444,7 @@ class TransformBroadcaster(KeyMapper):

     def scatter_sequence(self, data: Dict) -> List[Dict]:
         """Scatter the broadcasting targets to a list of inputs of the wrapped
-        transforms.
-        """
+        transforms."""

         # infer split number from input
         seq_len = 0
@@ -458,7 +456,6 @@ class TransformBroadcaster(KeyMapper):
             keys = data.keys()
             for key in keys:
                 assert isinstance(data[key], Sequence)
-
                 if seq_len:
                     if len(data[key]) != seq_len:
@@ -472,7 +469,7 @@ class TransformBroadcaster(KeyMapper):
         assert seq_len > 0, 'Fail to get the number of broadcasting targets'

         scatters = []
-        for i in range(seq_len):
+        for i in range(seq_len):  # type: ignore
             scatter = data.copy()
             for key in keys:
                 scatter[key] = data[key][i]
@@ -494,7 +491,7 @@ class TransformBroadcaster(KeyMapper):
             # cacheable method of the transforms cache their outputs. Thus
             # the random parameters will only generated once and shared
             # by all data items.
-            ctx = cache_random_params
+            ctx = cache_random_params  # type: ignore
         else:
             ctx = nullcontext  # type: ignore
@@ -602,13 +599,13 @@ class RandomApply(BaseTransform):

     @cache_randomness
     def random_apply(self) -> bool:
-        """Return a random bool value indicating whether apply the transform.
-        """
+        """Return a random bool value indicating whether apply the
+        transform."""
         return np.random.rand() < self.prob

     def transform(self, results: Dict) -> Optional[Dict]:
         """Randomly apply the transform."""
         if self.random_apply():
-            return self.transforms(results)
+            return self.transforms(results)  # type: ignore
         else:
             return results
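With the isinstance(transforms, Sequence) check in Compose.__init__, any sequence of transforms (list, tuple, ...) is accepted as-is and only a single bare callable gets wrapped. A standalone sketch of that wrapping rule, not MMCV's actual Compose:

# Standalone sketch of the Sequence-based wrapping rule; not MMCV's Compose.
from collections.abc import Sequence
from typing import Callable, List, Union


def as_transform_list(
        transforms: Union[Callable, Sequence]) -> List[Callable]:
    # A single callable is not a Sequence, so it gets wrapped in a list;
    # lists, tuples and other sequences are kept as they are.
    if not isinstance(transforms, Sequence):
        transforms = [transforms]
    return list(transforms)


def identity(results):
    return results


assert as_transform_list(identity) == [identity]
assert as_transform_list((identity, identity)) == [identity, identity]
assert as_transform_list([identity]) == [identity]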