Unverified commit a6ed6da8 authored by Myle Ott, committed by GitHub

[fix] lint/typing in FlattenParamsWrapper (#318)

parent 35fdf537
# Copyright (c) Tongzhou Wang
# Licensed under the MIT License.
from collections import namedtuple
+from collections import OrderedDict
from contextlib import contextmanager
-from typing import List, Optional
+from typing import Any, Dict, Generator, List, Optional, Tuple
import torch
import torch.nn as nn
+from torch import Tensor

class FlattenParamsWrapper(nn.Module):
@@ -36,32 +37,32 @@ class FlattenParamsWrapper(nn.Module):
        if param_list is not None:
            assert len(param_list) > 0, "param_list can't be empty"
        else:
-            param_list = module.parameters()
-        param_list = set(param_list)
+            param_list = list(module.parameters())
+        param_set = set(param_list)

        # convert from list of Parameters to set of (Module, name) tuples, which
        # will survive in case the Parameter instances are reset
-        self._param_list = set()
+        self._param_set = set()
        for m in self.modules():
            for n, p in m.named_parameters(recurse=False):
-                if p in param_list:
-                    self._param_list.add((m, n))
+                if p in param_set:
+                    self._param_set.add((m, n))

        self._flatten_params()

        # register the views as plain attributes
        self._unflatten_params_as_views()

-    def _flatten_params(self):
+    def _flatten_params(self) -> None:
        param_infos = []
-        shared_param_memo = {}
+        shared_param_memo: Dict[nn.Parameter, Tuple[nn.Module, str]] = {}
        shared_param_infos = []
        params = []
        param_numels = []
        param_shapes = []
        for m in self.modules():
            for n, p in m.named_parameters(recurse=False):
-                if p is not None and (m, n) in self._param_list:
+                if p is not None and (m, n) in self._param_set:
                    if p in shared_param_memo:
                        shared_m, shared_n = shared_param_memo[p]
                        shared_param_infos.append((m, n, shared_m, shared_n))
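The shared_param_memo bookkeeping above keeps tied weights from being flattened twice: the first (module, name) pair that owns a parameter is memoized, and later aliases go into shared_param_infos so they can be re-pointed at the same storage when unflattening. A standalone sketch of that pattern, using a hypothetical tied embedding/decoder pair rather than code from this commit:

```python
# Sketch only: demonstrates the dedup logic, with made-up modules.
import torch.nn as nn

emb = nn.Embedding(10, 4)
dec = nn.Linear(4, 10, bias=False)
dec.weight = emb.weight  # weight tying: the same nn.Parameter in two places
model = nn.ModuleDict({"emb": emb, "dec": dec})

shared_param_memo = {}   # Dict[nn.Parameter, Tuple[nn.Module, str]]
shared_param_infos = []  # (module, name, owner_module, owner_name)
unique_params = []

for m in model.modules():
    for n, p in m.named_parameters(recurse=False):
        if p in shared_param_memo:
            owner_m, owner_n = shared_param_memo[p]
            shared_param_infos.append((m, n, owner_m, owner_n))
        else:
            shared_param_memo[p] = (m, n)
            unique_params.append(p)

print(len(unique_params), len(shared_param_infos))  # 1 1: one tensor, one alias
```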
@@ -95,7 +96,7 @@ class FlattenParamsWrapper(nn.Module):
        for m, n, _, _ in self._shared_param_infos:
            delattr(m, n)

-    def _get_param_views(self):
+    def _get_param_views(self) -> Generator:
        return (
            t.view(s)
            for (t, s) in zip(
@@ -103,7 +104,7 @@ class FlattenParamsWrapper(nn.Module):
            )
        )

-    def _unflatten_params(self):
+    def _unflatten_params(self) -> None:
        ps = self._get_param_views()
        for (m, n), p in zip(self._param_infos, ps):
            if hasattr(m, n):
@@ -115,7 +116,7 @@ class FlattenParamsWrapper(nn.Module):
                m.register_parameter(n, getattr(shared_m, shared_n))
        del self.flat_param

-    def _unflatten_params_as_views(self):
+    def _unflatten_params_as_views(self) -> None:
        ps = self._get_param_views()
        for (m, n), p in zip(self._param_infos, ps):
            setattr(m, n, p)  # This will set as plain attr
@@ -123,33 +124,39 @@ class FlattenParamsWrapper(nn.Module):
            setattr(m, n, getattr(shared_m, shared_n))

    @contextmanager
-    def unflatten_params(self):
+    def unflatten_params(self) -> Generator:
        self._unflatten_params()
        yield
        self._flatten_params()
        self._unflatten_params_as_views()

-    def __getattr__(self, name):
+    def __getattr__(self, name: str) -> Any:
        """Forward missing attributes to wrapped module."""
        try:
            return super().__getattr__(name)  # defer to nn.Module's logic
        except AttributeError:
            return getattr(self.module, name)  # fallback to wrapped module

-    def state_dict(self, *args, unflatten_params=True, **kwargs):
-        if unflatten_params:
-            with self.unflatten_params():
-                return self.module.state_dict()
-        else:
-            return super().state_dict()
+    def state_dict(
+        self, prefix: str = "", keep_vars: bool = False,
+    ) -> OrderedDict[str, Tensor]:
+        """Return an unflattened state_dict."""
+        with self.unflatten_params():
+            return self.module.state_dict(prefix=prefix, keep_vars=keep_vars)
+
+    def flat_state_dict(self, *args: Any, **kwargs: Any) -> Dict[str, Any]:
+        """Return the flattened state_dict."""
+        return super().state_dict(*args, **kwargs)

-    def load_state_dict(self, state_dict, *args, **kwargs):
+    def load_state_dict(
+        self, state_dict: Dict[str, Any], *args: Any, **kwargs: Any
+    ) -> None:
        if "flat_param" in state_dict:
            super().load_state_dict(state_dict, strict=True)
        else:
            with self.unflatten_params():
                return self.module.load_state_dict(state_dict, *args, **kwargs)

-    def forward(self, *inputs, **kwinputs):
+    def forward(self, *inputs: Any, **kwinputs: Any) -> Any:
        self._unflatten_params_as_views()
        return self.module(*inputs, **kwinputs)
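With these changes, state_dict() always returns the wrapped module's unflattened keys, flat_state_dict() returns the single flat_param entry, and load_state_dict() accepts either form. A minimal usage sketch, assuming the wrapper is importable as fairscale.nn.misc.FlattenParamsWrapper (import path not shown in this diff):

```python
import torch
import torch.nn as nn
from fairscale.nn.misc import FlattenParamsWrapper  # assumed import path

module = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
n_elems = sum(p.numel() for p in module.parameters())

wrapped = FlattenParamsWrapper(module)        # flattens all params by default
assert wrapped.flat_param.numel() == n_elems  # one flat parameter holds them all

unflat_sd = wrapped.state_dict()     # original keys, e.g. "0.weight", "2.bias"
flat_sd = wrapped.flat_state_dict()  # single "flat_param" entry

wrapped.load_state_dict(flat_sd)     # strict load of the flattened form
wrapped.load_state_dict(unflat_sd)   # falls back to the wrapped module's keys

with wrapped.unflatten_params():     # temporarily restore real nn.Parameters
    assert isinstance(module[0].weight, nn.Parameter)

out = wrapped(torch.randn(3, 4))     # forward() re-installs views, then delegates
```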
@@ -366,7 +366,7 @@ class GPT2(nn.Module):
        return self.clf_head(h), logits


-def objects_are_equal(a, b, raise_exception=False) -> bool:
+def objects_are_equal(a: Any, b: Any, raise_exception: bool = False) -> bool:
    """
    Test that two objects are equal. Tensors are compared to ensure matching
    size, dtype, device and values.
......
@@ -27,6 +27,7 @@ from .autograd import no_grad as no_grad, enable_grad as enable_grad, \
from . import cuda as cuda
from . import optim as optim
from . import nn as nn
+from . import testing as testing

#MODIFIED BY TORCHGPIPE
from . import backends
......
+# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
+#MODIFIED FOR FlattenParamsWrapper
+from typing import Any
+def assert_allclose(actual: Any, expected: Any, rtol: float = ..., atol: float = ..., equal_nan: bool = ..., msg: str = ...) -> None: ...
+#END
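This new stub declares just enough of torch.testing for mypy to resolve assert_allclose, presumably for the type-annotated test helpers touched in this change; at runtime calls still go to the real torch.testing.assert_allclose. A tiny illustrative check (tolerances are arbitrary):

```python
import torch

actual = torch.tensor([1.0, 2.0, 3.0])
expected = actual + 1e-7  # perturbation well within tolerance
torch.testing.assert_allclose(actual, expected, rtol=1e-5, atol=1e-6)
```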
@@ -153,7 +153,7 @@ class TestFlattenParams(unittest.TestCase):
        flat_module = FlattenParamsWrapper(flat_module)
        ref_output = self._get_output(flat_module)

-        flat_state_dict = flat_module.state_dict(unflatten_params=False)
+        flat_state_dict = flat_module.flat_state_dict()

        new_module = self._get_shared_params_transformer(seed=1234)
        new_module = FlattenParamsWrapper(new_module)
......