Unverified Commit bd5d0496 authored by Benjamin Lefaudeux, committed by GitHub

[fix] Lint flattenparams (#320)

* working around broken mypy
parent a6ed6da8
 # Copyright (c) Tongzhou Wang
 # Licensed under the MIT License.
 from collections import OrderedDict
 from contextlib import contextmanager
 from typing import Any, Dict, Generator, List, Optional, Tuple
 import torch
-import torch.nn as nn
 from torch import Tensor
+import torch.nn as nn
 class FlattenParamsWrapper(nn.Module):
@@ -28,9 +27,7 @@ class FlattenParamsWrapper(nn.Module):
             appearing in the given list (default: flatten all parameters)
     """
-    def __init__(
-        self, module: nn.Module, param_list: Optional[List[nn.Parameter]] = None
-    ):
+    def __init__(self, module: nn.Module, param_list: Optional[List[nn.Parameter]] = None):
         super().__init__()
         self.module = module
@@ -74,9 +71,7 @@ class FlattenParamsWrapper(nn.Module):
             param_shapes.append(p.size())
         del shared_param_memo
-        assert (
-            len(set(p.dtype for p in params)) <= 1
-        ), "expects all parameters in module to have same dtype"
+        assert len(set(p.dtype for p in params)) <= 1, "expects all parameters in module to have same dtype"
         # store the info for unflatten
         self._param_infos = tuple(param_infos)
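The assertion above exists because the selected parameters are packed into a single flat tensor, and a tensor can only hold one dtype. A minimal standalone sketch of the same check (the toy nn.Sequential model is illustrative, not part of the commit):

import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 2))
params = list(model.parameters())

# A single flat buffer can only have one dtype, so mixing float16 and
# float32 parameters is rejected up front; calling model[1].half() first
# would make this assertion fail.
assert len(set(p.dtype for p in params)) <= 1, "expects all parameters in module to have same dtype"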
@@ -97,12 +92,7 @@ class FlattenParamsWrapper(nn.Module):
                 delattr(m, n)
     def _get_param_views(self) -> Generator:
-        return (
-            t.view(s)
-            for (t, s) in zip(
-                self.flat_param.split(self._param_numels), self._param_shapes
-            )
-        )
+        return (t.view(s) for (t, s) in zip(self.flat_param.split(self._param_numels), self._param_shapes))
     def _unflatten_params(self) -> None:
         ps = self._get_param_views()
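_get_param_views rebuilds the per-parameter tensors by splitting the flat buffer at the stored element counts and viewing each chunk with its stored shape. A small sketch of that split-and-view round trip in plain torch (the shapes and names are made up for illustration):

import torch

shapes = [(2, 3), (4,), (3, 2)]
tensors = [torch.randn(s) for s in shapes]

# Flatten: record each tensor's element count and shape, then concatenate
# the 1-D views into one buffer.
numels = [t.numel() for t in tensors]
flat = torch.cat([t.reshape(-1) for t in tensors])

# Unflatten: split at the recorded counts and view each chunk back.
views = [chunk.view(shape) for chunk, shape in zip(flat.split(numels), shapes)]
assert all(torch.equal(v, t) for v, t in zip(views, tensors))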
@@ -137,9 +127,7 @@ class FlattenParamsWrapper(nn.Module):
         except AttributeError:
             return getattr(self.module, name)  # fallback to wrapped module
-    def state_dict(
-        self, prefix: str = "", keep_vars: bool = False,
-    ) -> OrderedDict[str, Tensor]:
+    def state_dict(self, prefix: str = "", keep_vars: bool = False) -> "OrderedDict[str, Tensor]":  # type: ignore
         """Return an unflattened state_dict."""
         with self.unflatten_params():
             return self.module.state_dict(prefix=prefix, keep_vars=keep_vars)
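The quoted return annotation plus the trailing # type: ignore is the "working around broken mypy" part of the commit message: collections.OrderedDict is not subscriptable at runtime before PEP 585 generics (Python 3.9), so the annotation is kept as a string, and the ignore quiets mypy's remaining objection to the override. A minimal sketch of the quoted-annotation pattern on its own, separate from the wrapper (the counts function is invented for illustration):

from collections import OrderedDict


# Quoting the annotation keeps the interpreter from evaluating
# OrderedDict[str, int] when the function is defined; unquoted, that
# subscript raises TypeError on Python versions before 3.9.
def counts(words: list) -> "OrderedDict[str, int]":
    result: "OrderedDict[str, int]" = OrderedDict()
    for word in words:
        result[word] = result.get(word, 0) + 1
    return result


print(counts(["a", "b", "a"]))  # maps "a" -> 2, "b" -> 1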
@@ -148,9 +136,7 @@ class FlattenParamsWrapper(nn.Module):
         """Return the flattened state_dict."""
         return super().state_dict(*args, **kwargs)
-    def load_state_dict(
-        self, state_dict: Dict[str, Any], *args: Any, **kwargs: Any
-    ) -> None:
+    def load_state_dict(self, state_dict: Dict[str, Any], *args: Any, **kwargs: Any) -> None:
         if "flat_param" in state_dict:
             super().load_state_dict(state_dict, strict=True)
         else:
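Together, state_dict() and load_state_dict() support both checkpoint layouts: the unflattened form uses the wrapped module's original keys, while a dict containing "flat_param" is loaded straight into the packed tensor. A rough usage sketch, assuming the fairscale.nn import path used in the test below and a toy nn.Linear module:

import torch.nn as nn
from fairscale.nn import FlattenParamsWrapper

wrapped = FlattenParamsWrapper(nn.Linear(4, 2))

# All parameters now live in one packed tensor (4 * 2 weights + 2 biases).
assert wrapped.flat_param.numel() == 4 * 2 + 2

# state_dict() round-trips through the original (unflattened) keys.
sd = wrapped.state_dict()
wrapped.load_state_dict(sd)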
@@ -66,7 +66,7 @@ class IdentityLayer(torch.nn.Module):
 def set_random_seed(seed: int) -> None:
-    """Set random seed for reproducability."""
+    """Set random seed for reproducibility."""
     random.seed(seed)
     numpy.random.seed(seed)
     torch.manual_seed(seed)
@@ -388,9 +388,7 @@ def objects_are_equal(a: Any, b: Any, raise_exception: bool = False) -> bool:
        try:
            torch.testing.assert_allclose(a, b)
            # assert_allclose doesn't strictly test shape, dtype and device
-            shape_dtype_device_match = (
-                a.size() == b.size() and a.dtype == b.dtype and a.device == b.device
-            )
+            shape_dtype_device_match = a.size() == b.size() and a.dtype == b.dtype and a.device == b.device
            assert shape_dtype_device_match
            return True
        except AssertionError as e:
@@ -400,4 +398,3 @@ def objects_are_equal(a: Any, b: Any, raise_exception: bool = False) -> bool:
                return False
    else:
        return a == b
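The explicit metadata comparison is needed because, as the in-code comment says, assert_allclose does not strictly test shape, dtype and device. A small illustration of the kind of mismatch the extra assert is there to catch (the example tensors are invented):

import torch

a = torch.ones(3, dtype=torch.float32)
b = torch.ones(3, dtype=torch.float16)

# Same values, different dtype: a value-only comparison is satisfied, so the
# helper's explicit check is what flags the difference.
shape_dtype_device_match = a.size() == b.size() and a.dtype == b.dtype and a.device == b.device
print(shape_dtype_device_match)  # False, because the dtypes differ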
@@ -10,6 +10,7 @@ Test FlattenParamsWrapper
 import unittest
 import torch
+
 from fairscale.nn import FlattenParamsWrapper
 from fairscale.utils.testing import objects_are_equal
@@ -18,11 +19,7 @@ class TestFlattenParams(unittest.TestCase):
     def _get_transformer(self, seed=0):
         torch.manual_seed(seed)  # keep everything deterministic
         module = torch.nn.Transformer(
-            d_model=32,
-            num_encoder_layers=2,
-            num_decoder_layers=2,
-            dim_feedforward=128,
-            dropout=0.1,
+            d_model=32, num_encoder_layers=2, num_decoder_layers=2, dim_feedforward=128, dropout=0.1,
         )
         module.register_buffer("dummy_buffer", torch.tensor(1.0))
         return module
@@ -70,10 +67,7 @@ class TestFlattenParams(unittest.TestCase):
         module = self._get_transformer()
         num_params = sum(p.numel() for p in module.parameters())
-        params_to_flatten = (
-            list(module.encoder.layers[1].parameters())
-            + list(module.decoder.layers[0].parameters())
-        )
+        params_to_flatten = list(module.encoder.layers[1].parameters()) + list(module.decoder.layers[0].parameters())
         num_params_to_flatten = sum(p.numel() for p in params_to_flatten)
         module = FlattenParamsWrapper(module, param_list=params_to_flatten)
@@ -92,9 +86,7 @@ class TestFlattenParams(unittest.TestCase):
         orig_dtype = params_to_flatten[0].dtype
         new_dtype = torch.float32 if orig_dtype == torch.float16 else torch.float16
         assert module.flat_param.dtype == orig_dtype
-        assert all(
-            p.dtype == orig_dtype for p in module.encoder.layers[0].parameters()
-        )
+        assert all(p.dtype == orig_dtype for p in module.encoder.layers[0].parameters())
         module = module.to(dtype=new_dtype)
         assert module.flat_param.dtype == new_dtype
         assert all(p.dtype == new_dtype for p in module.encoder.layers[0].parameters())