"docs/en/_static/git@developer.sourcefind.cn:OpenDAS/mmcv.git" did not exist on "fdeee889589df413e368b05fd702b5c3f76ac7d3"
Unverified commit bd5d0496, authored by Benjamin Lefaudeux and committed by GitHub

[fix] Lint flattenparams (#320)

* working around broken mypy
parent a6ed6da8
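For context on the "working around broken mypy" note: the diff below quotes the OrderedDict return annotation of state_dict and adds a type: ignore to the override. A minimal, self-contained sketch of that pattern (illustrative only, not the fairscale module itself):

# Sketch of the annotation workaround applied in this commit: the return type
# is written as a string so collections.OrderedDict is never subscripted at
# runtime (subscripting it is not supported on all Python versions), and
# "# type: ignore" suppresses mypy's remaining complaints about overriding
# nn.Module.state_dict.
from collections import OrderedDict

import torch.nn as nn
from torch import Tensor


class Wrapped(nn.Module):
    def __init__(self, module: nn.Module):
        super().__init__()
        self.module = module

    def state_dict(self, prefix: str = "", keep_vars: bool = False) -> "OrderedDict[str, Tensor]":  # type: ignore
        # Delegate to the wrapped module; the quoted annotation is only read
        # by type checkers and never evaluated at runtime.
        return self.module.state_dict(prefix=prefix, keep_vars=keep_vars)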
 # Copyright (c) Tongzhou Wang
 # Licensed under the MIT License.

+from collections import OrderedDict
 from contextlib import contextmanager
 from typing import Any, Dict, Generator, List, Optional, Tuple

 import torch
+import torch.nn as nn
 from torch import Tensor
-import torch.nn as nn


 class FlattenParamsWrapper(nn.Module):
@@ -28,9 +27,7 @@ class FlattenParamsWrapper(nn.Module):
         appearing in the given list (default: flatten all parameters)
     """

-    def __init__(
-        self, module: nn.Module, param_list: Optional[List[nn.Parameter]] = None
-    ):
+    def __init__(self, module: nn.Module, param_list: Optional[List[nn.Parameter]] = None):
         super().__init__()
         self.module = module
@@ -74,9 +71,7 @@ class FlattenParamsWrapper(nn.Module):
             param_shapes.append(p.size())
         del shared_param_memo

-        assert (
-            len(set(p.dtype for p in params)) <= 1
-        ), "expects all parameters in module to have same dtype"
+        assert len(set(p.dtype for p in params)) <= 1, "expects all parameters in module to have same dtype"

         # store the info for unflatten
         self._param_infos = tuple(param_infos)
@@ -97,12 +92,7 @@ class FlattenParamsWrapper(nn.Module):
                 delattr(m, n)

     def _get_param_views(self) -> Generator:
-        return (
-            t.view(s)
-            for (t, s) in zip(
-                self.flat_param.split(self._param_numels), self._param_shapes
-            )
-        )
+        return (t.view(s) for (t, s) in zip(self.flat_param.split(self._param_numels), self._param_shapes))

     def _unflatten_params(self) -> None:
         ps = self._get_param_views()
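The _get_param_views generator above splits the single flat_param tensor by each original parameter's element count and views every chunk back to its stored shape. A standalone sketch of that split-and-view mechanism in plain PyTorch (illustrative only, not fairscale code):

# A flat 1-D tensor is split by each parameter's numel() and every chunk is
# viewed back to its original shape; each chunk is a view into the flat tensor.
import torch

shapes = [(4, 3), (3,), (2, 2)]
params = [torch.randn(s) for s in shapes]

numels = [p.numel() for p in params]
flat = torch.cat([p.reshape(-1) for p in params])  # plays the role of flat_param

views = (t.view(s) for t, s in zip(flat.split(numels), shapes))
for p, v in zip(params, views):
    assert torch.equal(p, v)  # each view matches the original parameter exactly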
@@ -137,9 +127,7 @@ class FlattenParamsWrapper(nn.Module):
         except AttributeError:
             return getattr(self.module, name)  # fallback to wrapped module

-    def state_dict(
-        self, prefix: str = "", keep_vars: bool = False,
-    ) -> OrderedDict[str, Tensor]:
+    def state_dict(self, prefix: str = "", keep_vars: bool = False) -> "OrderedDict[str, Tensor]":  # type: ignore
         """Return an unflattened state_dict."""
         with self.unflatten_params():
             return self.module.state_dict(prefix=prefix, keep_vars=keep_vars)
@@ -148,9 +136,7 @@ class FlattenParamsWrapper(nn.Module):
         """Return the flattened state_dict."""
         return super().state_dict(*args, **kwargs)

-    def load_state_dict(
-        self, state_dict: Dict[str, Any], *args: Any, **kwargs: Any
-    ) -> None:
+    def load_state_dict(self, state_dict: Dict[str, Any], *args: Any, **kwargs: Any) -> None:
         if "flat_param" in state_dict:
             super().load_state_dict(state_dict, strict=True)
         else:
...
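Taken together, the wrapper exposes two serialization paths: state_dict temporarily unflattens and returns the wrapped module's usual keys, while load_state_dict accepts either a flattened dict (one containing "flat_param") or an unflattened one. A hypothetical usage sketch, assuming the behavior shown in the hunks above; the Linear module and variable names are illustrative, not taken from the repository:

# Hypothetical usage of FlattenParamsWrapper based on the methods in this diff.
import torch
from fairscale.nn import FlattenParamsWrapper

wrapped = FlattenParamsWrapper(torch.nn.Linear(4, 4))

print(wrapped.flat_param.shape)   # all parameters live in one 1-D tensor
sd = wrapped.state_dict()         # unflattened keys, as in the original Linear
wrapped.load_state_dict(sd)       # exercises the non-flat branch of load_state_dict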
@@ -66,7 +66,7 @@ class IdentityLayer(torch.nn.Module):

 def set_random_seed(seed: int) -> None:
-    """Set random seed for reproducability."""
+    """Set random seed for reproducibility."""
     random.seed(seed)
     numpy.random.seed(seed)
     torch.manual_seed(seed)

@@ -388,9 +388,7 @@ def objects_are_equal(a: Any, b: Any, raise_exception: bool = False) -> bool:
         try:
             torch.testing.assert_allclose(a, b)
             # assert_allclose doesn't strictly test shape, dtype and device
-            shape_dtype_device_match = (
-                a.size() == b.size() and a.dtype == b.dtype and a.device == b.device
-            )
+            shape_dtype_device_match = a.size() == b.size() and a.dtype == b.dtype and a.device == b.device
             assert shape_dtype_device_match
             return True
         except AssertionError as e:

@@ -400,4 +398,3 @@ def objects_are_equal(a: Any, b: Any, raise_exception: bool = False) -> bool:
             return False
     else:
         return a == b
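objects_are_equal relies on torch.testing.assert_allclose for numeric closeness and then separately asserts matching shape, dtype and device. A small usage sketch (hypothetical, not part of the repository's tests):

# Tensors with equal values but a different dtype compare as unequal, because
# the helper checks shape, dtype and device on top of assert_allclose.
import torch
from fairscale.utils.testing import objects_are_equal

a = torch.ones(3, dtype=torch.float32)
b = torch.ones(3, dtype=torch.float64)

assert objects_are_equal(a, a.clone())
assert not objects_are_equal(a, b)  # same values, different dtype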
@@ -10,6 +10,7 @@ Test FlattenParamsWrapper

 import unittest

 import torch
+
 from fairscale.nn import FlattenParamsWrapper
 from fairscale.utils.testing import objects_are_equal

@@ -18,11 +19,7 @@ class TestFlattenParams(unittest.TestCase):
     def _get_transformer(self, seed=0):
         torch.manual_seed(seed)  # keep everything deterministic
         module = torch.nn.Transformer(
-            d_model=32,
-            num_encoder_layers=2,
-            num_decoder_layers=2,
-            dim_feedforward=128,
-            dropout=0.1,
+            d_model=32, num_encoder_layers=2, num_decoder_layers=2, dim_feedforward=128, dropout=0.1,
         )
         module.register_buffer("dummy_buffer", torch.tensor(1.0))
         return module
@@ -70,10 +67,7 @@ class TestFlattenParams(unittest.TestCase):
         module = self._get_transformer()
         num_params = sum(p.numel() for p in module.parameters())

-        params_to_flatten = (
-            list(module.encoder.layers[1].parameters())
-            + list(module.decoder.layers[0].parameters())
-        )
+        params_to_flatten = list(module.encoder.layers[1].parameters()) + list(module.decoder.layers[0].parameters())
         num_params_to_flatten = sum(p.numel() for p in params_to_flatten)

         module = FlattenParamsWrapper(module, param_list=params_to_flatten)

@@ -92,9 +86,7 @@ class TestFlattenParams(unittest.TestCase):
         orig_dtype = params_to_flatten[0].dtype
         new_dtype = torch.float32 if orig_dtype == torch.float16 else torch.float16
         assert module.flat_param.dtype == orig_dtype
-        assert all(
-            p.dtype == orig_dtype for p in module.encoder.layers[0].parameters()
-        )
+        assert all(p.dtype == orig_dtype for p in module.encoder.layers[0].parameters())
         module = module.to(dtype=new_dtype)
         assert module.flat_param.dtype == new_dtype
         assert all(p.dtype == new_dtype for p in module.encoder.layers[0].parameters())
...