"...text-generation-inference.git" did not exist on "895a341d064c9930b2a9bd60cff0df42f91b52fa"
Unverified commit a6ed6da8 authored by Myle Ott, committed by GitHub

[fix] lint/typing in FlattenParamsWrapper (#318)

parent 35fdf537
 # Copyright (c) Tongzhou Wang
 # Licensed under the MIT License.

-from collections import namedtuple
+from collections import OrderedDict
 from contextlib import contextmanager
-from typing import List, Optional
+from typing import Any, Dict, Generator, List, Optional, Tuple

 import torch
 import torch.nn as nn
+from torch import Tensor


 class FlattenParamsWrapper(nn.Module):
@@ -36,32 +37,32 @@ class FlattenParamsWrapper(nn.Module):
         if param_list is not None:
             assert len(param_list) > 0, "param_list can't be empty"
         else:
-            param_list = module.parameters()
-        param_list = set(param_list)
+            param_list = list(module.parameters())
+        param_set = set(param_list)

         # convert from list of Parameters to set of (Module, name) tuples, which
         # will survive in case the Parameter instances are reset
-        self._param_list = set()
+        self._param_set = set()
         for m in self.modules():
             for n, p in m.named_parameters(recurse=False):
-                if p in param_list:
-                    self._param_list.add((m, n))
+                if p in param_set:
+                    self._param_set.add((m, n))

         self._flatten_params()

         # register the views as plain attributes
         self._unflatten_params_as_views()

-    def _flatten_params(self):
+    def _flatten_params(self) -> None:
         param_infos = []
-        shared_param_memo = {}
+        shared_param_memo: Dict[nn.Parameter, Tuple[nn.Module, str]] = {}
         shared_param_infos = []
         params = []
         param_numels = []
         param_shapes = []
         for m in self.modules():
             for n, p in m.named_parameters(recurse=False):
-                if p is not None and (m, n) in self._param_list:
+                if p is not None and (m, n) in self._param_set:
                     if p in shared_param_memo:
                         shared_m, shared_n = shared_param_memo[p]
                         shared_param_infos.append((m, n, shared_m, shared_n))
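The shared_param_memo bookkeeping above exists so that weight-tied parameters are stored only once in the flat buffer and re-exposed as the same object. A minimal sketch of what that guarantees (the Tied module and the import path are illustrative assumptions, not code from this commit):

import torch.nn as nn
from flatten_params_wrapper import FlattenParamsWrapper  # hypothetical path

class Tied(nn.Module):
    def __init__(self):
        super().__init__()
        self.emb = nn.Embedding(10, 4)
        self.out = nn.Linear(4, 10, bias=False)
        self.out.weight = self.emb.weight  # tied / shared parameter

    def forward(self, x):
        return self.out(self.emb(x))

wrapped = FlattenParamsWrapper(Tied())
# The shared tensor appears once among the flattened parameters; the second
# occurrence is recorded in shared_param_infos and re-attached as the same view.
assert wrapped.module.out.weight is wrapped.module.emb.weight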
@@ -95,7 +96,7 @@ class FlattenParamsWrapper(nn.Module):
         for m, n, _, _ in self._shared_param_infos:
             delattr(m, n)

-    def _get_param_views(self):
+    def _get_param_views(self) -> Generator:
         return (
             t.view(s)
             for (t, s) in zip(
@@ -103,7 +104,7 @@ class FlattenParamsWrapper(nn.Module):
             )
         )

-    def _unflatten_params(self):
+    def _unflatten_params(self) -> None:
         ps = self._get_param_views()
         for (m, n), p in zip(self._param_infos, ps):
             if hasattr(m, n):
@@ -115,7 +116,7 @@ class FlattenParamsWrapper(nn.Module):
             m.register_parameter(n, getattr(shared_m, shared_n))
         del self.flat_param

-    def _unflatten_params_as_views(self):
+    def _unflatten_params_as_views(self) -> None:
         ps = self._get_param_views()
         for (m, n), p in zip(self._param_infos, ps):
             setattr(m, n, p)  # This will set as plain attr
@@ -123,33 +124,39 @@ class FlattenParamsWrapper(nn.Module):
             setattr(m, n, getattr(shared_m, shared_n))

     @contextmanager
-    def unflatten_params(self):
+    def unflatten_params(self) -> Generator:
         self._unflatten_params()
         yield
         self._flatten_params()
         self._unflatten_params_as_views()

-    def __getattr__(self, name):
+    def __getattr__(self, name: str) -> Any:
         """Forward missing attributes to wrapped module."""
         try:
             return super().__getattr__(name)  # defer to nn.Module's logic
         except AttributeError:
             return getattr(self.module, name)  # fallback to wrapped module

-    def state_dict(self, *args, unflatten_params=True, **kwargs):
-        if unflatten_params:
-            with self.unflatten_params():
-                return self.module.state_dict()
-        else:
-            return super().state_dict()
+    def state_dict(
+        self, prefix: str = "", keep_vars: bool = False,
+    ) -> OrderedDict[str, Tensor]:
+        """Return an unflattened state_dict."""
+        with self.unflatten_params():
+            return self.module.state_dict(prefix=prefix, keep_vars=keep_vars)
+
+    def flat_state_dict(self, *args: Any, **kwargs: Any) -> Dict[str, Any]:
+        """Return the flattened state_dict."""
+        return super().state_dict(*args, **kwargs)

-    def load_state_dict(self, state_dict, *args, **kwargs):
+    def load_state_dict(
+        self, state_dict: Dict[str, Any], *args: Any, **kwargs: Any
+    ) -> None:
         if "flat_param" in state_dict:
             super().load_state_dict(state_dict, strict=True)
         else:
             with self.unflatten_params():
                 return self.module.load_state_dict(state_dict, *args, **kwargs)

-    def forward(self, *inputs, **kwinputs):
+    def forward(self, *inputs: Any, **kwinputs: Any) -> Any:
         self._unflatten_params_as_views()
         return self.module(*inputs, **kwinputs)
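As a rough usage sketch of the reworked API (the import path and the toy module below are assumptions, not part of this commit), the wrapper can be exercised like this:

import torch
import torch.nn as nn

# Hypothetical import; point this at wherever flatten_params_wrapper.py
# lives in this repository.
from flatten_params_wrapper import FlattenParamsWrapper

module = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
wrapped = FlattenParamsWrapper(module)

# After wrapping, the original parameters are views into one flat parameter.
flat_sd = wrapped.flat_state_dict()   # keys: ["flat_param"]
full_sd = wrapped.state_dict()        # unflattened, original key names
print(sorted(full_sd.keys()))         # ['0.bias', '0.weight', '2.bias', '2.weight']

# load_state_dict() dispatches on the key set: a dict containing "flat_param"
# is loaded directly into the wrapper, anything else is loaded through the
# temporarily unflattened wrapped module.
wrapped.load_state_dict(flat_sd)
wrapped.load_state_dict(full_sd)

# forward() re-exposes the per-module views before calling the module.
out = wrapped(torch.randn(3, 4))      # shape: torch.Size([3, 2])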
@@ -366,7 +366,7 @@ class GPT2(nn.Module):
         return self.clf_head(h), logits


-def objects_are_equal(a, b, raise_exception=False) -> bool:
+def objects_are_equal(a: Any, b: Any, raise_exception: bool = False) -> bool:
     """
     Test that two objects are equal. Tensors are compared to ensure matching
     size, dtype, device and values.
...
@@ -27,6 +27,7 @@ from .autograd import no_grad as no_grad, enable_grad as enable_grad, \
 from . import cuda as cuda
 from . import optim as optim
 from . import nn as nn
+from . import testing as testing

 #MODIFIED BY TORCHGPIPE
 from . import backends
...
+# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
+#MODIFIED FOR FlattenParamsWrapper
+from typing import Any
+def assert_allclose(actual: Any, expected: Any, rtol: float = ..., atol: float = ..., equal_nan: bool = ..., msg: str = ...) -> None: ...
+#END
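The stub above only declares the signature for the type checker; at runtime the real torch.testing.assert_allclose performs the comparison. A minimal sketch of how the declared function behaves (the tensors here are made up for illustration):

import torch

a = torch.tensor([1.0, 2.0, 3.0])
b = a + 1e-7

# Passes: the difference is within the default rtol/atol for float32.
torch.testing.assert_allclose(a, b)

# Raises AssertionError: the values differ far beyond the tolerances.
try:
    torch.testing.assert_allclose(a, b + 1.0)
except AssertionError:
    print("mismatch detected")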
@@ -153,7 +153,7 @@ class TestFlattenParams(unittest.TestCase):
         flat_module = FlattenParamsWrapper(flat_module)
         ref_output = self._get_output(flat_module)

-        flat_state_dict = flat_module.state_dict(unflatten_params=False)
+        flat_state_dict = flat_module.flat_state_dict()

         new_module = self._get_shared_params_transformer(seed=1234)
         new_module = FlattenParamsWrapper(new_module)
...
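The test hunk above relies on the flat state dict being loadable into a second, independently initialized wrapper. A condensed sketch of that round trip (the helper module and import path are assumptions, not code from this commit):

import torch
import torch.nn as nn
from flatten_params_wrapper import FlattenParamsWrapper  # hypothetical path

def make_module(seed: int) -> nn.Module:
    torch.manual_seed(seed)
    return nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 4))

src = FlattenParamsWrapper(make_module(seed=0))
dst = FlattenParamsWrapper(make_module(seed=1234))  # different init

x = torch.randn(2, 4)
ref = src(x)

# Loading the flat state dict copies src's weights into dst's flat_param,
# so the two wrappers now produce identical outputs.
dst.load_state_dict(src.flat_state_dict())
torch.testing.assert_allclose(dst(x), ref)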