Unverified Commit 4a874a6b authored by msbaines's avatar msbaines Committed by GitHub
Browse files

[cleanup] fix pre-commit mypy issues (#87)

parent d16e9f61
...@@ -5,7 +5,7 @@ import argparse ...@@ -5,7 +5,7 @@ import argparse
import math import math
import os import os
import time import time
from typing import Any, List, Union, cast from typing import Any, List, cast
import torch import torch
import torch.distributed as dist import torch.distributed as dist
...@@ -62,7 +62,7 @@ def train( ...@@ -62,7 +62,7 @@ def train(
torch.cuda.reset_peak_memory_stats(rank) torch.cuda.reset_peak_memory_stats(rank)
# Shard the optimizer # Shard the optimizer
optimizer: Union[OSS, OPTIM] = OSS( optimizer: torch.optim.Optimizer = OSS(
params=model.parameters(), optim=OPTIM, lr=1e-4, momentum=0.9 params=model.parameters(), optim=OPTIM, lr=1e-4, momentum=0.9
) if use_oss else OPTIM(model.parameters(), lr=1e-4, momentum=0.9) ) if use_oss else OPTIM(model.parameters(), lr=1e-4, momentum=0.9)
......
# type: ignore
# Configuration file for the Sphinx documentation builder. # Configuration file for the Sphinx documentation builder.
# #
# This file only contains a selection of the most common options. For a full # This file only contains a selection of the most common options. For a full
...@@ -73,10 +74,7 @@ html_static_path = ["_static"] ...@@ -73,10 +74,7 @@ html_static_path = ["_static"]
def setup(app): def setup(app):
app.add_config_value( app.add_config_value(
"recommonmark_config", "recommonmark_config",
{ {"url_resolver": lambda url: github_doc_root + url, "auto_toc_tree_section": "Contents"},
"url_resolver": lambda url: github_doc_root + url,
"auto_toc_tree_section": "Contents",
},
True, True,
) )
app.add_transform(AutoStructify) app.add_transform(AutoStructify)
......
...@@ -39,6 +39,7 @@ def manual_seed(seed: int) -> None: ... ...@@ -39,6 +39,7 @@ def manual_seed(seed: int) -> None: ...
def memory_allocated(device: Optional[_device_t]=...) -> int: ... def memory_allocated(device: Optional[_device_t]=...) -> int: ...
def max_memory_allocated(device: Optional[_device_t]=...) -> int: ... def max_memory_allocated(device: Optional[_device_t]=...) -> int: ...
def reset_max_memory_allocated(device: Optional[_device_t]=...) -> None: ... def reset_max_memory_allocated(device: Optional[_device_t]=...) -> None: ...
def reset_peak_memory_stats(device: Union[_device_t, int] = None) -> None: ...
def memory_cached(device: Optional[_device_t]=...) -> int: ... def memory_cached(device: Optional[_device_t]=...) -> int: ...
def max_memory_cached(device: Optional[_device_t]=...) -> int: ... def max_memory_cached(device: Optional[_device_t]=...) -> int: ...
def reset_max_memory_cached(device: Optional[_device_t]=...) -> None: ... def reset_max_memory_cached(device: Optional[_device_t]=...) -> None: ...
......
...@@ -8,7 +8,7 @@ T = TypeVar('T') ...@@ -8,7 +8,7 @@ T = TypeVar('T')
class Dataset(Generic[T_co]): class Dataset(Generic[T_co]):
def __getitem__(self, index: int) -> T_co: ... def __getitem__(self, index: int) -> T_co: ...
def __len__(self) -> int: ... def __len__(self) -> int: ...
def __add__(self, other: T_co) -> 'ConcatDataset[T_co]': ... def __add__(self, other: 'Dataset[T_co]') -> 'ConcatDataset[T_co]': ...
class IterableDataset(Dataset[T_co]): class IterableDataset(Dataset[T_co]):
def __iter__(self) -> Iterable[T_co]: ... def __iter__(self) -> Iterable[T_co]: ...
......
...@@ -6,6 +6,6 @@ from . import Sampler, Dataset ...@@ -6,6 +6,6 @@ from . import Sampler, Dataset
T_co = TypeVar('T_co', covariant=True) T_co = TypeVar('T_co', covariant=True)
class DistributedSampler(Sampler[T_co]): class DistributedSampler(Sampler[T_co]):
def __init__(self, dataset: Dataset, num_replicas: Optional[int]=..., rank: Optional[int]=...): ... def __init__(self, dataset: Dataset, num_replicas: Optional[int]=..., rank: Optional[int]=...): ...
def __iter__(self) -> Iterator[int]: ... def __iter__(self) -> Iterator[T_co]: ...
def __len__(self) -> int: ... def __len__(self) -> int: ...
def set_epoch(self, epoch: int) -> None: ... def set_epoch(self, epoch: int) -> None: ...
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment