Commit a7106c6b authored by yangwendi.vendor's avatar yangwendi.vendor Committed by zhouzaida
Browse files

[fix]:fix type hint in transforms

parent 59eaefeb
# GitLab CI pipeline: lint, test, and deploy stages.
variables:
  # Base image with PyTorch 1.8 / CUDA 10.2 and mmengine preinstalled
  # (pulled from an internal registry) — reused by all jobs below.
  PYTORCH_IMAGE: registry.sensetime.com/openmmlab/pytorch18-cu102-mmengine:dev2

stages:
  - linting
  - test
  - deploy

# Shared setup for every job: source the environment, then print toolchain
# versions so failures are easier to diagnose from the job log.
before_script:
  - . /root/scripts/set_envs.sh
  - echo $PATH
  - gcc --version
  - nvcc --version
  - ruby --version
  - python --version
  - pip --version
  - python -c "import torch; print(torch.__version__)"

# Run all configured pre-commit hooks over the whole repository.
linting:
  image: $PYTORCH_IMAGE
  stage: linting
  script:
    - pre-commit run --all-files

# Hidden job used as a YAML anchor: build mmcv with CUDA ops enabled,
# then run the pytest suite under coverage. Concrete test jobs merge
# this template and supply their own image.
.test_template: &test_template_def
  stage: test
  script:
    - echo "Start building..."
    - MMCV_WITH_OPS=1 pip install -e .[all] -i https://pypi.tuna.tsinghua.edu.cn/simple/
    - python -c "import mmcv; print(mmcv.__version__)"
    - echo "Start testing..."
    - coverage run --branch --source mmcv -m pytest tests/
    - coverage report -m

# Concrete test job: PyTorch 1.8 + CUDA 10.2 image, body from the template.
test:pytorch1.8-cuda10:
  image: $PYTORCH_IMAGE
  <<: *test_template_def
# Copyright (c) OpenMMLab. All rights reserved.
from abc import ABCMeta, abstractmethod
from typing import Dict, Optional
from typing import Dict, List, Optional, Tuple, Union
class BaseTransform(metaclass=ABCMeta):
def __call__(self, results: Dict) -> Optional[Dict]:
def __call__(self,
results: Dict) -> Optional[Union[Dict, Tuple[List, List]]]:
return self.transform(results)
@abstractmethod
def transform(self, results: Dict) -> Optional[Dict]:
def transform(self,
results: Dict) -> Optional[Union[Dict, Tuple[List, List]]]:
"""The transform function. All subclass of BaseTransform should
override this method.
......
# Copyright (c) OpenMMLab. All rights reserved.
import random
import warnings
from typing import Iterable, List, Optional, Sequence, Tuple, Union
from typing import Dict, Iterable, List, Optional, Sequence, Tuple, Union
import numpy as np
......@@ -800,10 +800,12 @@ class MultiScaleFlipAug(BaseTransform):
else:
# if ``scales`` and ``scale_factor`` both be ``None``
if scale_factor is None:
self.scales = [1.]
self.scales = [1.] # type: ignore
elif isinstance(scale_factor, list):
self.scales = scale_factor # type: ignore
else:
self.scales = scale_factor if isinstance(
scale_factor, list) else [scale_factor]
self.scales = [scale_factor] # type: ignore
self.scale_key = 'scale_factor'
self.allow_flip = allow_flip
......@@ -816,7 +818,7 @@ class MultiScaleFlipAug(BaseTransform):
self.resize_cfg = resize_cfg.copy()
self.flip_cfg = flip_cfg
def transform(self, results: dict) -> Tuple[List, List]:
def transform(self, results: dict) -> Dict:
"""Apply test time augment transforms on results.
Args:
......@@ -848,12 +850,12 @@ class MultiScaleFlipAug(BaseTransform):
results['flip_direction'] = None
resize_flip = Compose(_resize_flip)
_results = results.copy()
_results = resize_flip(_results)
packed_results = self.transforms(_results)
_results = resize_flip(results.copy())
packed_results = self.transforms(_results) # type: ignore
inputs.append(packed_results['inputs'])
data_samples.append(packed_results['data_sample'])
inputs.append(packed_results['inputs']) # type: ignore
data_samples.append(
packed_results['data_sample']) # type: ignore
return dict(inputs=inputs, data_sample=data_samples)
def __repr__(self) -> str:
......@@ -1312,8 +1314,7 @@ class RandomResize(BaseTransform):
if isinstance(self.scale, tuple):
assert self.ratio_range is not None and len(self.ratio_range) == 2
scale: Tuple[int, int] = self._random_sample_ratio(
self.scale, self.ratio_range)
scale = self._random_sample_ratio(self.scale, self.ratio_range)
elif mmcv.is_list_of(self.scale, tuple):
scale = self._random_sample(self.scale)
else:
......
# Copyright (c) OpenMMLab. All rights reserved.
from collections.abc import Sequence
from typing import Any, Callable, Dict, List, Optional, Union
from typing import Any, Callable, Dict, List, Optional, Sequence, Union
import numpy as np
......@@ -25,7 +24,7 @@ IgnoreKey = object()
# Import nullcontext if python>=3.7, otherwise use a simple alternative
# implementation.
try:
from contextlib import nullcontext
from contextlib import nullcontext # type: ignore
except ImportError:
from contextlib import contextmanager
......@@ -55,10 +54,10 @@ class Compose(BaseTransform):
>>> ]
"""
def __init__(self, transforms: Union[Transform, List[Transform]]):
def __init__(self, transforms: Union[Transform, Sequence[Transform]]):
super().__init__()
if not isinstance(transforms, list):
if not isinstance(transforms, Sequence):
transforms = [transforms]
self.transforms: List = []
for transform in transforms:
......@@ -85,7 +84,7 @@ class Compose(BaseTransform):
dict or None: Transformed results.
"""
for t in self.transforms:
results = t(results)
results = t(results) # type: ignore
if results is None:
return None
return results
......@@ -331,7 +330,7 @@ class KeyMapper(BaseTransform):
# Apply remapping
outputs = self._map_output(outputs, self.remapping)
results.update(outputs)
results.update(outputs) # type: ignore
return results
......@@ -445,8 +444,7 @@ class TransformBroadcaster(KeyMapper):
def scatter_sequence(self, data: Dict) -> List[Dict]:
"""Scatter the broadcasting targets to a list of inputs of the wrapped
transforms.
"""
transforms."""
# infer split number from input
seq_len = 0
......@@ -458,7 +456,6 @@ class TransformBroadcaster(KeyMapper):
keys = data.keys()
for key in keys:
assert isinstance(data[key], Sequence)
if seq_len:
if len(data[key]) != seq_len:
......@@ -472,7 +469,7 @@ class TransformBroadcaster(KeyMapper):
assert seq_len > 0, 'Fail to get the number of broadcasting targets'
scatters = []
for i in range(seq_len):
for i in range(seq_len): # type: ignore
scatter = data.copy()
for key in keys:
scatter[key] = data[key][i]
......@@ -494,7 +491,7 @@ class TransformBroadcaster(KeyMapper):
# cacheable method of the transforms cache their outputs. Thus
# the random parameters will only generated once and shared
# by all data items.
ctx = cache_random_params
ctx = cache_random_params # type: ignore
else:
ctx = nullcontext # type: ignore
......@@ -602,13 +599,13 @@ class RandomApply(BaseTransform):
@cache_randomness
def random_apply(self) -> bool:
"""Return a random bool value indicating whether apply the transform.
"""
"""Return a random bool value indicating whether apply the
transform."""
return np.random.rand() < self.prob
def transform(self, results: Dict) -> Optional[Dict]:
"""Randomly apply the transform."""
if self.random_apply():
return self.transforms(results)
return self.transforms(results) # type: ignore
else:
return results
Markdown is supported
Attach a file by drag &amp; drop or click to upload.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment