Unverified Commit a848ecfd authored by Alex Yang's avatar Alex Yang Committed by GitHub
Browse files

[Fix] Fix type hint in file_client (#1942)

* [fix]:fix type hint in file_client and mmcv

* [fix]:fix type hint in tests files

* [fix]:fix type hint in tests files

* [fix]:fix pre-commit.yaml to ignore test for mypy

* [fix]:fix pre-commit.yaml to ignore test for mypy

* [fix]:fix precommit.yml

* [fix]:fix precommit.yml

* Update __init__.py

delete unused type-ignore comment
parent fb9af9f3
...@@ -46,6 +46,15 @@ repos: ...@@ -46,6 +46,15 @@ repos:
hooks: hooks:
- id: check-copyright - id: check-copyright
args: ["mmcv", "tests", "--excludes", "mmcv/ops"] args: ["mmcv", "tests", "--excludes", "mmcv/ops"]
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v0.812
hooks:
- id: mypy
exclude: |-
(?x)(
^test
| ^docs
)
# - repo: local # - repo: local
# hooks: # hooks:
# - id: clang-format # - id: clang-format
......
...@@ -5,9 +5,9 @@ import platform ...@@ -5,9 +5,9 @@ import platform
from .registry import PLUGIN_LAYERS from .registry import PLUGIN_LAYERS
if platform.system() == 'Windows': if platform.system() == 'Windows':
import regex as re import regex as re # type: ignore
else: else:
import re import re # type: ignore
def infer_abbr(class_type): def infer_abbr(class_type):
......
...@@ -8,7 +8,7 @@ import warnings ...@@ -8,7 +8,7 @@ import warnings
from abc import ABCMeta, abstractmethod from abc import ABCMeta, abstractmethod
from contextlib import contextmanager from contextlib import contextmanager
from pathlib import Path from pathlib import Path
from typing import Iterable, Iterator, Optional, Tuple, Union from typing import Any, Generator, Iterator, Optional, Tuple, Union
from urllib.request import urlopen from urllib.request import urlopen
import mmcv import mmcv
...@@ -298,7 +298,10 @@ class PetrelBackend(BaseStorageBackend): ...@@ -298,7 +298,10 @@ class PetrelBackend(BaseStorageBackend):
return '/'.join(formatted_paths) return '/'.join(formatted_paths)
@contextmanager @contextmanager
def get_local_path(self, filepath: Union[str, Path]) -> Iterable[str]: def get_local_path(
self,
filepath: Union[str,
Path]) -> Generator[Union[str, Path], None, None]:
"""Download a file from ``filepath`` and return a temporary path. """Download a file from ``filepath`` and return a temporary path.
``get_local_path`` is decorated by :meth:`contextlib.contextmanager`. It ``get_local_path`` is decorated by :meth:`contextlib.contextmanager`. It
...@@ -646,7 +649,9 @@ class HardDiskBackend(BaseStorageBackend): ...@@ -646,7 +649,9 @@ class HardDiskBackend(BaseStorageBackend):
@contextmanager @contextmanager
def get_local_path( def get_local_path(
self, filepath: Union[str, Path]) -> Iterable[Union[str, Path]]: self,
filepath: Union[str,
Path]) -> Generator[Union[str, Path], None, None]:
"""Only for unified API and do nothing.""" """Only for unified API and do nothing."""
yield filepath yield filepath
...@@ -715,7 +720,8 @@ class HTTPBackend(BaseStorageBackend): ...@@ -715,7 +720,8 @@ class HTTPBackend(BaseStorageBackend):
return value_buf.decode(encoding) return value_buf.decode(encoding)
@contextmanager @contextmanager
def get_local_path(self, filepath: str) -> Iterable[str]: def get_local_path(
self, filepath: str) -> Generator[Union[str, Path], None, None]:
"""Download a file from ``filepath``. """Download a file from ``filepath``.
``get_local_path`` is decorated by :meth:`contextlib.contextmanager`. It ``get_local_path`` is decorated by :meth:`contextlib.contextmanager`. It
...@@ -789,15 +795,17 @@ class FileClient: ...@@ -789,15 +795,17 @@ class FileClient:
# backend appears in the collection, the singleton pattern is disabled for # backend appears in the collection, the singleton pattern is disabled for
# that backend, because if the singleton pattern is used, then the object # that backend, because if the singleton pattern is used, then the object
# returned will be the backend before overwriting # returned will be the backend before overwriting
_overridden_backends = set() _overridden_backends: set = set()
_prefix_to_backends = { _prefix_to_backends: dict = {
's3': PetrelBackend, 's3': PetrelBackend,
'http': HTTPBackend, 'http': HTTPBackend,
'https': HTTPBackend, 'https': HTTPBackend,
} }
_overridden_prefixes = set() _overridden_prefixes: set = set()
_instances = {} _instances: dict = {}
client: Any
def __new__(cls, backend=None, prefix=None, **kwargs): def __new__(cls, backend=None, prefix=None, **kwargs):
if backend is None and prefix is None: if backend is None and prefix is None:
...@@ -1107,7 +1115,10 @@ class FileClient: ...@@ -1107,7 +1115,10 @@ class FileClient:
return self.client.join_path(filepath, *filepaths) return self.client.join_path(filepath, *filepaths)
@contextmanager @contextmanager
def get_local_path(self, filepath: Union[str, Path]) -> Iterable[str]: def get_local_path(
self,
filepath: Union[str,
Path]) -> Generator[Union[str, Path], None, None]:
"""Download data from ``filepath`` and write the data to local path. """Download data from ``filepath`` and write the data to local path.
``get_local_path`` is decorated by :meth:`contextlib.contextmanager`. It ``get_local_path`` is decorated by :meth:`contextlib.contextmanager`. It
......
...@@ -5,7 +5,7 @@ try: ...@@ -5,7 +5,7 @@ try:
from yaml import CDumper as Dumper from yaml import CDumper as Dumper
from yaml import CLoader as Loader from yaml import CLoader as Loader
except ImportError: except ImportError:
from yaml import Loader, Dumper from yaml import Loader, Dumper # type: ignore
from .base import BaseFileHandler # isort:skip from .base import BaseFileHandler # isort:skip
......
...@@ -328,4 +328,4 @@ cast_pytorch_to_onnx = { ...@@ -328,4 +328,4 @@ cast_pytorch_to_onnx = {
# Global set to store the list of quantized operators in the network. # Global set to store the list of quantized operators in the network.
# This is currently only used in the conversion of quantized ops from PT # This is currently only used in the conversion of quantized ops from PT
# -> C2 via ONNX. # -> C2 via ONNX.
_quantized_ops = set() _quantized_ops: set = set()
...@@ -200,7 +200,7 @@ def _process_mmcls_checkpoint(checkpoint): ...@@ -200,7 +200,7 @@ def _process_mmcls_checkpoint(checkpoint):
class CheckpointLoader: class CheckpointLoader:
"""A general checkpoint loader to manage all schemes.""" """A general checkpoint loader to manage all schemes."""
_schemes = {} _schemes: dict = {}
@classmethod @classmethod
def _register_scheme(cls, prefixes, loader, force=False): def _register_scheme(cls, prefixes, loader, force=False):
......
...@@ -342,7 +342,7 @@ if (TORCH_VERSION != 'parrots' ...@@ -342,7 +342,7 @@ if (TORCH_VERSION != 'parrots'
else: else:
@HOOKS.register_module() @HOOKS.register_module()
class Fp16OptimizerHook(OptimizerHook): class Fp16OptimizerHook(OptimizerHook): # type: ignore
"""FP16 optimizer hook (mmcv's implementation). """FP16 optimizer hook (mmcv's implementation).
The steps of fp16 optimizer is as follows. The steps of fp16 optimizer is as follows.
...@@ -484,8 +484,8 @@ else: ...@@ -484,8 +484,8 @@ else:
'fp16', {})['loss_scaler'] = self.loss_scaler.state_dict() 'fp16', {})['loss_scaler'] = self.loss_scaler.state_dict()
@HOOKS.register_module() @HOOKS.register_module()
class GradientCumulativeFp16OptimizerHook(GradientCumulativeOptimizerHook, class GradientCumulativeFp16OptimizerHook( # type: ignore
Fp16OptimizerHook): GradientCumulativeOptimizerHook, Fp16OptimizerHook):
"""Fp16 optimizer Hook (using mmcv implementation) implements multi- """Fp16 optimizer Hook (using mmcv implementation) implements multi-
iters gradient cumulating.""" iters gradient cumulating."""
......
...@@ -22,9 +22,9 @@ from .misc import import_modules_from_strings ...@@ -22,9 +22,9 @@ from .misc import import_modules_from_strings
from .path import check_file_exist from .path import check_file_exist
if platform.system() == 'Windows': if platform.system() == 'Windows':
import regex as re import regex as re # type: ignore
else: else:
import re import re # type: ignore
BASE_KEY = '_base_' BASE_KEY = '_base_'
DELETE_KEY = '_delete_' DELETE_KEY = '_delete_'
......
...@@ -128,4 +128,4 @@ if TORCH_VERSION != 'parrots' and digit_version(TORCH_VERSION) < digit_version( ...@@ -128,4 +128,4 @@ if TORCH_VERSION != 'parrots' and digit_version(TORCH_VERSION) < digit_version(
'loaded in torch<1.5.') 'loaded in torch<1.5.')
raise error raise error
else: else:
from torch.utils.model_zoo import load_url # noqa: F401 from torch.utils.model_zoo import load_url # type: ignore # noqa: F401
...@@ -3,7 +3,7 @@ import logging ...@@ -3,7 +3,7 @@ import logging
import torch.distributed as dist import torch.distributed as dist
logger_initialized = {} logger_initialized: dict = {}
def get_logger(name, log_file=None, log_level=logging.INFO, file_mode='w'): def get_logger(name, log_file=None, log_level=logging.INFO, file_mode='w'):
......
...@@ -103,7 +103,7 @@ _BatchNorm, _InstanceNorm, SyncBatchNorm_ = _get_norm() ...@@ -103,7 +103,7 @@ _BatchNorm, _InstanceNorm, SyncBatchNorm_ = _get_norm()
_AdaptiveAvgPoolNd, _AdaptiveMaxPoolNd, _AvgPoolNd, _MaxPoolNd = _get_pool() _AdaptiveAvgPoolNd, _AdaptiveMaxPoolNd, _AvgPoolNd, _MaxPoolNd = _get_pool()
class SyncBatchNorm(SyncBatchNorm_): class SyncBatchNorm(SyncBatchNorm_): # type: ignore
def _check_input_dim(self, input): def _check_input_dim(self, input):
if TORCH_VERSION == 'parrots': if TORCH_VERSION == 'parrots':
......
...@@ -41,7 +41,7 @@ def digit_version(version_str: str, length: int = 4): ...@@ -41,7 +41,7 @@ def digit_version(version_str: str, length: int = 4):
release.extend([val, 0]) release.extend([val, 0])
elif version.is_postrelease: elif version.is_postrelease:
release.extend([1, version.post]) release.extend([1, version.post]) # type: ignore
else: else:
release.extend([0, 0]) release.extend([0, 0])
return tuple(release) return tuple(release)
......
...@@ -22,9 +22,9 @@ def parse_version_info(version_str: str, length: int = 4) -> tuple: ...@@ -22,9 +22,9 @@ def parse_version_info(version_str: str, length: int = 4) -> tuple:
if len(release) < length: if len(release) < length:
release = release + [0] * (length - len(release)) release = release + [0] * (length - len(release))
if version.is_prerelease: if version.is_prerelease:
release.extend(list(version.pre)) release.extend(list(version.pre)) # type: ignore
elif version.is_postrelease: elif version.is_postrelease:
release.extend(list(version.post)) release.extend(list(version.post)) # type: ignore
else: else:
release.extend([0, 0]) release.extend([0, 0])
return tuple(release) return tuple(release)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment