Unverified Commit 4a874a6b authored by msbaines's avatar msbaines Committed by GitHub
Browse files

[cleanup] fix pre-commit mypy issues (#87)

parent d16e9f61
......@@ -5,7 +5,7 @@ import argparse
import math
import os
import time
from typing import Any, List, Union, cast
from typing import Any, List, cast
import torch
import torch.distributed as dist
......@@ -62,7 +62,7 @@ def train(
torch.cuda.reset_peak_memory_stats(rank)
# Shard the optimizer
optimizer: Union[OSS, OPTIM] = OSS(
optimizer: torch.optim.Optimizer = OSS(
params=model.parameters(), optim=OPTIM, lr=1e-4, momentum=0.9
) if use_oss else OPTIM(model.parameters(), lr=1e-4, momentum=0.9)
......
# type: ignore
# Configuration file for the Sphinx documentation builder.
#
# This file only contains a selection of the most common options. For a full
......@@ -73,10 +74,7 @@ html_static_path = ["_static"]
def setup(app):
app.add_config_value(
"recommonmark_config",
{
"url_resolver": lambda url: github_doc_root + url,
"auto_toc_tree_section": "Contents",
},
{"url_resolver": lambda url: github_doc_root + url, "auto_toc_tree_section": "Contents"},
True,
)
app.add_transform(AutoStructify)
......
......@@ -39,6 +39,7 @@ def manual_seed(seed: int) -> None: ...
def memory_allocated(device: Optional[_device_t]=...) -> int: ...
def max_memory_allocated(device: Optional[_device_t]=...) -> int: ...
def reset_max_memory_allocated(device: Optional[_device_t]=...) -> None: ...
def reset_peak_memory_stats(device: Union[_device_t, int] = None) -> None: ...
def memory_cached(device: Optional[_device_t]=...) -> int: ...
def max_memory_cached(device: Optional[_device_t]=...) -> int: ...
def reset_max_memory_cached(device: Optional[_device_t]=...) -> None: ...
......
......@@ -8,7 +8,7 @@ T = TypeVar('T')
class Dataset(Generic[T_co]):
def __getitem__(self, index: int) -> T_co: ...
def __len__(self) -> int: ...
def __add__(self, other: T_co) -> 'ConcatDataset[T_co]': ...
def __add__(self, other: 'Dataset[T_co]') -> 'ConcatDataset[T_co]': ...
class IterableDataset(Dataset[T_co]):
def __iter__(self) -> Iterable[T_co]: ...
......
......@@ -6,6 +6,6 @@ from . import Sampler, Dataset
T_co = TypeVar('T_co', covariant=True)
class DistributedSampler(Sampler[T_co]):
def __init__(self, dataset: Dataset, num_replicas: Optional[int]=..., rank: Optional[int]=...): ...
def __iter__(self) -> Iterator[int]: ...
def __iter__(self) -> Iterator[T_co]: ...
def __len__(self) -> int: ...
def set_epoch(self, epoch: int) -> None: ...
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment