"vscode:/vscode.git/clone" did not exist on "f3adf4f6a7381b992c4c86e7d56737ca11bee29c"
Commit 31197c3e authored by Yanghan Wang, committed by Facebook GitHub Bot

add type annotations to preserve return type

Summary:
X-link: https://github.com/facebookresearch/mobile-vision/pull/137

Pull Request resolved: https://github.com/facebookresearch/d2go/pull/475

Reviewed By: YanjunChen329

Differential Revision: D42148563

fbshipit-source-id: 76b794988bda7f773a734838c79d2de087d7ce94
parent 07ddd262
@@ -10,7 +10,7 @@ features, functions in this module share the same signatures as the ones from mo
 import logging
 from dataclasses import dataclass
 from datetime import timedelta
-from typing import Any, Callable, Dict, Optional, Tuple
+from typing import Any, Callable, Dict, Optional, Tuple, TypeVar
 
 import detectron2.utils.comm as d2_comm
 import mobile_cv.torch.utils_pytorch.comm as mcv_comm
@@ -32,6 +32,7 @@ from mobile_cv.torch.utils_pytorch.distributed_helper import (
 logger = logging.getLogger(__name__)
 
+_RT = TypeVar("_RT")  # return type
 
 @dataclass
@@ -45,18 +46,18 @@ class D2GoSharedContext(BaseSharedContext):
 # BC-compatible
-def get_local_rank():
+def get_local_rank() -> int:
     return mcv_comm.get_local_rank()
 
 
 # BC-compatible
-def get_num_processes_per_machine():
+def get_num_processes_per_machine() -> int:
     return mcv_comm.get_local_size()
 
 
 # Modify mobile_cv's `default_distributed_worker` to also setup D2's comm module
 def distributed_worker(
-    main_func: Callable,
+    main_func: Callable[..., _RT],
     args: Tuple[Any, ...],
     kwargs: Dict[str, Any],
     backend: str,
@@ -65,7 +66,7 @@ def distributed_worker(
     return_save_file: Optional[str] = None,
     timeout: timedelta = DEFAULT_TIMEOUT,
     shared_context: Optional[BaseSharedContext] = None,
-):
+) -> _RT:
     if shared_context:
         set_shared_context(
             shared_context
@@ -79,7 +80,7 @@ def distributed_worker(
 def launch(
-    main_func: Callable,
+    main_func: Callable[..., _RT],
     num_processes_per_machine: int,
     num_machines: int = 1,
     machine_rank: int = 0,
@@ -91,7 +92,7 @@ def launch(
     timeout: timedelta = DEFAULT_TIMEOUT,
     args: Tuple[Any, ...] = (),
     kwargs: Dict[str, Any] = None,
-) -> Dict[int, Any]:
+) -> Dict[int, _RT]:
     """
     D2Go's specialized launch method, it does a few more things on top of mcv's launch:
     - Automatically convert GPU to CPU if CUDA is not available.
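For context, the change applies the standard TypeVar pattern: binding main_func's return type to _RT lets a type checker propagate that type through distributed_worker and launch instead of erasing it to Any. Below is a minimal, self-contained sketch of the same pattern; the single-process launch stand-in and the train function are hypothetical illustrations, not d2go's actual implementation.

from typing import Any, Callable, Dict, Tuple, TypeVar

_RT = TypeVar("_RT")  # return type of the user-provided main_func

def launch(
    main_func: Callable[..., _RT],
    args: Tuple[Any, ...] = (),
) -> Dict[int, _RT]:
    # Single-process stand-in for the real multi-process launcher: run
    # main_func once and key its result by rank 0, preserving its type.
    return {0: main_func(*args)}

def train(base_lr: float) -> float:
    return base_lr * 2

results = launch(train, args=(0.1,))
# With Callable[..., _RT], a type checker infers `results` as Dict[int, float];
# with the previous bare `Callable` annotation it would have been Dict[int, Any].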