Unverified commit d60fc284 authored by Min Xu and committed by GitHub

[feat] supporting multiple flatten parameter groups (step 1 and step 1.5) (#708)



* refactoring FlattenParamsWrapper

- use a FlatParameter class to encapsulate the logic of
  flattening and expanding into views.
- this will make it easier to have multiple groups of flattened
  parameters

* fixed testing context issues for both temp files and temp dirs

* fixing test_fsdp_metadata

* fix pickling of FlatParameter

* fixed test_fsdp_optimizer_utils.py

* minor

* fix assert

* lint

* remove nesting from the test

* step 1.5: remove the code related to unnecessary nesting support in FPW

* Update fairscale/nn/misc/flatten_params_wrapper.py
Co-authored-by: Sam Shleifer <sshleifer@gmail.com>

* address comment
Co-authored-by: Min Xu <min.xu.public@gmail.com>
Co-authored-by: Sam Shleifer <sshleifer@gmail.com>
parent 1e4a503c
@@ -4,10 +4,12 @@
 # LICENSE file in the root directory of this source tree.
 """These functions are used by FullyShardedDataParallel to help consolidate and shard optimizer states."""
 import copy
-from typing import Any, Dict, Generator, List, Tuple
+from typing import Any, Dict, Generator, List, Tuple, cast

 import torch

+from fairscale.nn.misc import FlattenParamsWrapper
+
 # These return keys are used by fairseq. To change, add @sshleifer as a reviewer.
 UNFLAT_RETURN_KEYS = {"state", "param_groups", "uncollected_local_ids", "param_id_map"}
@@ -96,8 +98,9 @@ def _unflatten_optim_state(
     # we check that these are identical across workers and then take the first
     non_tensor_state = [_extract_non_tensor_state(combined_state, id) for id in combined_state]
-    # local corresponds to flattened, global corresponds to unflattened
-    num_global_params = [len(m._param_numels) for m in instance_list]  # type: ignore
+    # Local corresponds to flattened, global corresponds to unflattened.
+    # Casting needed only for mypy.
+    num_global_params = [cast(int, m.num_params_managed) for m in instance_list]
     global_to_local_id = {}
     for local_id, num_unflat in enumerate(num_global_params):
         for _ in range(num_unflat):
@@ -126,7 +129,8 @@ def _unflatten_optim_state(
             assert isinstance(v, list), f"got {k}: {v} for {local_id}"
             v_unpad = [t[:-np] if np > 0 else t for t, np in zip(v, pad_info[local_id])]
             flat_buffer = torch.cat(v_unpad)
-            param_views: Generator = instance_list[local_id].get_param_views(flat_buffer)  # type: ignore
+            # Casting needed only for mypy.
+            param_views: Generator = cast(FlattenParamsWrapper, instance_list[local_id]).get_param_views([flat_buffer])
             for global_id, param_view in zip(sorted(local_to_global[local_id]), param_views):
                 assert k not in unflat_state[global_id], f"already added {k} to {global_id} {local_id}"
                 unflat_state[global_id][k] = param_view
@@ -154,9 +158,10 @@ def build_unflat_state_dict(
         singleton_state[local_id] = {buffer_name: [x] for buffer_name, x in v.items() if is_singleton_tensor(x)}
     # local ids are in the current state, global_ids will be in returned state.
     unflat_state, global_to_local_id = _unflatten_optim_state(state, instance_list, world_pad_info, singleton_state)
-    # Since there are no tensors in param_groups, deepcopy is fine
+    # Since there are no tensors in param_groups, deepcopy is fine.
     param_groups = copy.deepcopy(param_groups)
-    num_params = sum([len(m._param_numels) for m in instance_list])  # type: ignore
+    # Casting needed only for mypy.
+    num_params = sum([cast(int, m.num_params_managed) for m in instance_list])
     param_groups[0]["params"] = list(range(num_params))
     unflat_optim_state_dict = {
         "state": dict(sorted(unflat_state.items())),  # NOTE: this is probably already sorted
...
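To make the id bookkeeping concrete, here is a small illustration (not from the diff; the loop body is cut off by the hunk, so the consecutive-global-id assignment below is an assumption) of the mapping built from `num_global_params`:

```python
# Two hypothetical FlattenParamsWrapper instances managing 3 and 2 original params.
num_global_params = [3, 2]

# Assumed continuation of the loop shown above: give every unflattened (global)
# parameter a consecutive id and remember which flattened (local) id owns it.
global_to_local_id = {}
global_id = 0
for local_id, num_unflat in enumerate(num_global_params):
    for _ in range(num_unflat):
        global_to_local_id[global_id] = local_id
        global_id += 1

# Global ids 0-2 map to local (flattened) id 0, ids 3-4 map to local id 1.
assert global_to_local_id == {0: 0, 1: 0, 2: 0, 3: 1, 4: 1}
```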
@@ -833,7 +833,7 @@ class FullyShardedDataParallel(nn.Module):
             # latter may contain padding.
             assert len(self.params) == 1
             assert isinstance(self.module, FlattenParamsWrapper)
-            stack.enter_context(self.module.unflatten_params(recurse=False, flat_param=self.params[0]))
+            stack.enter_context(self.module.unflatten_params(flat_param=self.params[0]))
             try:
                 yield
             finally:
@@ -1534,7 +1534,7 @@ class FullyShardedDataParallel(nn.Module):
             # There are as many sharded parameters as there parameters in the
             # consolidated model, so we only need to export how to reshape the
             # parameters to their orginal shape and take care of the padding
-            if not hasattr(m, "_param_numels"):
+            if not m.flatten_parameters:
                 params_metadata.append(
                     {
                         "fsdp_path": _clean_path(path),
@@ -1563,8 +1563,11 @@ class FullyShardedDataParallel(nn.Module):
                         "is_flat": True,
                         "num_padded": m.numel_padded_per_param,
                         "param_names": param_names,
-                        "param_shapes": m._param_shapes,
-                        "param_numels": m._param_numels,
+                        # TODO (Min): we don't want to access the private _param_shapes and
+                        # _param_numels here. We want to dump metadata from FPW when there are
+                        # multiple groups of params.
+                        "param_shapes": m._fsdp_wrapped_module.flat_param._param_shapes,  # type: ignore
+                        "param_numels": m._fsdp_wrapped_module.flat_param._param_numels,  # type: ignore
                         "no_broadcast_optim_state": m.no_broadcast_optim_state,
                     }
                 )
@@ -1589,7 +1592,7 @@ class FullyShardedDataParallel(nn.Module):
         If this behavior is not correct for your module (for instance if buffers
         needs to be reduced instead), you can disable it with `with_module_buffers=False`.
-        This method is very useful to re-assemble checkpoints of shards without
+        This method is used to re-assemble checkpoints of shards without
         having to instantiate FSDP wrappers with the world size originally used
         to save the shards.
         """
...
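For context on how this metadata gets consumed, a minimal sketch (hypothetical helper, not part of FSDP's API) of rebuilding the original parameters from `param_numels`, `param_shapes`, and per-shard padding:

```python
import torch

# Hypothetical helper showing how the exported metadata ("param_numels",
# "param_shapes", "num_padded") can rebuild original parameters from per-rank flat shards.
def rebuild_params(flat_shards, num_padded, param_numels, param_shapes):
    # Drop the per-shard padding, then concatenate into one flat buffer.
    unpadded = [s[:-p] if p > 0 else s for s, p in zip(flat_shards, num_padded)]
    flat = torch.cat(unpadded)
    # Split by the recorded numels and restore the recorded shapes.
    return [t.view(shape) for t, shape in zip(flat.split(param_numels), param_shapes)]

# Two tiny "parameters" of shapes (2, 3) and (4,), sharded across two ranks,
# where the second rank's shard carries 2 elements of padding.
shards = [torch.arange(6.0), torch.arange(6.0)]
params = rebuild_params(shards, num_padded=[0, 2], param_numels=[6, 4], param_shapes=[(2, 3), (4,)])
assert [tuple(p.shape) for p in params] == [(2, 3), (4,)]
```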
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the BSD license found in the
+# LICENSE file in the root directory of this source tree.
+
 # Copyright (c) Tongzhou Wang
 # Licensed under the MIT License.

-from contextlib import ExitStack, contextmanager
+from contextlib import contextmanager
+from itertools import chain
 import typing
-from typing import TYPE_CHECKING, Any, Dict, Generator, List, Mapping, NamedTuple, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Dict, Generator, List, Mapping, NamedTuple, Optional, Sequence, Tuple, Union

 import torch
 from torch import Tensor
@@ -15,6 +21,72 @@ if TYPE_CHECKING:
     from collections import OrderedDict  # noqa: F401
+class FlatParameter(nn.Parameter):
+    """ A parameter that is initialized from a list of parameters and can be
+        turned into a list of views as needed.
+    """
+
+    def __new__(cls, params: Sequence[nn.Parameter], requires_grad: bool = True) -> "FlatParameter":
+        """ Make an object using the parent's __new__ function. """
+        # An empty or non-list input doesn't make sense.
+        if not isinstance(params, (list, tuple)) or len(params) == 0:
+            raise ValueError("A non-empty list or tuple argument is needed")
+
+        # Normally, all items are Parameters. But during pickling, we will have a single
+        # Tensor as the input and later in __init__, the correct _param_numels and _param_shapes
+        # are set.
+        if not all(isinstance(p, (nn.Parameter, Tensor)) for p in params):
+            raise ValueError("List items need to be Parameter types")
+
+        # Flattening involves (1) making a tensor flat (i.e. single dimensional) and (2) making a module
+        # hierarchy flat (using a single tensor to replace a tree of tensors). Therefore,
+        # adding back nesting and hierarchy is counter-productive. If nesting is encountered
+        # in the future, the reasonable thing to do is likely for the top level FlatParameter to
+        # absorb the nested one and keep the result flat, free from hierarchy.
+        if any(isinstance(p, FlatParameter) for p in params):
+            raise ValueError("Nesting FlatParameter is not supported")
+
+        data = torch.cat([p.detach().reshape(-1) if isinstance(p, nn.Parameter) else p.reshape(-1) for p in params], 0)
+        return super(FlatParameter, cls).__new__(cls, data, requires_grad=requires_grad)
+
+    def __init__(self, params: Sequence[nn.Parameter], requires_grad: bool = True):
+        """ Initialize the _param_numels and _param_shapes lists. """
+        self._param_numels = [p.numel() for p in params]
+        assert self.numel() <= sum(
+            self._param_numels
+        ), f"Something wrong with __new__ method, {self.numel()} vs. {sum(self._param_numels)}"
+        self._param_shapes = [p.size() for p in params]
+
+    def get_param_views(self, external_data: Optional[Tensor] = None) -> Generator[Tensor, None, None]:
+        """ Return a generator of views that map to the original parameters. """
+        # Note, self.data could be sharded, so its numel is <= the sum.
+        assert self.data.numel() <= sum(
+            self._param_numels
+        ), f"Incorrect internal state {self.data.numel()} vs. {sum(self._param_numels)}"
+        data = external_data if external_data is not None else self
+        if data.numel() != sum(self._param_numels):
+            raise ValueError(
+                f"Incorrect numel of supplied data: got {data.numel()} but expected {sum(self._param_numels)}"
+            )
+        return (t.view(s) for (t, s) in zip(data.split(self._param_numels), self._param_shapes))
+
+    def __setstate__(self, state: Tuple[Any, Any]) -> None:
+        """ Used by pickle to set the internal states. """
+        self._param_numels, self._param_shapes = state
+        assert self.numel() <= sum(
+            self._param_numels
+        ), f"Incorrect pickling {self.numel()} vs. {sum(self._param_numels)}"
+
+    def __reduce_ex__(self, proto: int) -> Tuple[Any, Any, Any]:
+        """ Support pickling between ranks. """
+        return (
+            FlatParameter,  # Callable
+            ([self.data], self.requires_grad),  # Args to the callable above
+            (self._param_numels, self._param_shapes),  # Args to __setstate__
+        )
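A brief usage sketch of the new class (illustrative only; the shapes are made up, and the import path assumes `fairscale/nn/misc/flatten_params_wrapper.py` as touched by this commit):

```python
import pickle

import torch
import torch.nn as nn

from fairscale.nn.misc.flatten_params_wrapper import FlatParameter

# Two ordinary parameters with different shapes.
p1 = nn.Parameter(torch.randn(2, 3))
p2 = nn.Parameter(torch.randn(4))

# One flat, single-dimensional parameter holding both.
flat = FlatParameter([p1, p2])
assert flat.numel() == p1.numel() + p2.numel()  # 6 + 4 = 10

# Views that map back to the original shapes (backed by flat's storage).
v1, v2 = flat.get_param_views()
assert v1.shape == (2, 3) and v2.shape == (4,)

# An external 1-D buffer of matching numel can be viewed the same way.
w1, w2 = flat.get_param_views(torch.zeros(10))

# __reduce_ex__/__setstate__ let the metadata survive a pickle round trip.
flat2 = pickle.loads(pickle.dumps(flat))
assert flat2._param_shapes == flat._param_shapes
```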
 class FlattenParamsWrapper(nn.Module):
     """
     A wrapper for transparently flattening a Module's parameters.
@@ -24,6 +96,7 @@ class FlattenParamsWrapper(nn.Module):
     - supports shared parameters
     - handles state_dict/load_state_dict transparently
     - is renamed to FlattenParamsWrapper
+    - refactored to use the FlatParameter class

     [1] https://github.com/SsnL/PyTorch-Reparam-Module
@@ -44,6 +117,10 @@ class FlattenParamsWrapper(nn.Module):
         param_list = list(module.parameters())
         param_set = set(param_list)

+        # Since the parameters will be deleted, let's record the number of original
+        # parameters managed by this class.
+        self.num_params_managed = len(param_set)
+
         # convert from list of Parameters to set of (Module, name) tuples, which
         # will survive in case the Parameter instances are reset
         self._param_set = set()
@@ -52,6 +129,7 @@ class FlattenParamsWrapper(nn.Module):
             if p in param_set:
                 self._param_set.add((m, n))

+        # TODO (Min): double check we handle the special case of module without any params.
         self._flatten_params()

         # Register hook to be called after state_dict() to remove the
@@ -68,14 +146,15 @@ class FlattenParamsWrapper(nn.Module):
     def module(self) -> nn.Module:
         return self._fpw_module

-    def _init_flatten_params(self) -> List[Tensor]:
+    def _init_flatten_params(self) -> List[nn.Parameter]:
+        """ Build metadata for the need-to-be-flattened parameters and return a list
+            containing those parameters.
+        """
         param_infos = []
         param_full_infos = []
         shared_param_memo: Dict[nn.Parameter, Tuple[nn.Module, str]] = {}
         shared_param_infos = []
         params = []
-        param_numels = []
-        param_shapes = []
         for module_name, m in self.named_modules():
             for n, p in m.named_parameters(recurse=False):
                 if p is not None and (m, n) in self._param_set:
@@ -86,9 +165,7 @@ class FlattenParamsWrapper(nn.Module):
                         shared_param_memo[p] = (m, n)
                         param_infos.append((m, n))
                         param_full_infos.append((module_name, n))
-                        params.append(p.detach())
-                        param_numels.append(p.numel())
-                        param_shapes.append(p.size())
+                        params.append(p)
         del shared_param_memo

         assert len(set(p.dtype for p in params)) <= 1, "expects all parameters in module to have same dtype"
@@ -97,8 +174,6 @@ class FlattenParamsWrapper(nn.Module):
         self._param_infos = tuple(param_infos)
         self._param_full_infos = tuple(param_full_infos)
         self._shared_param_infos = tuple(shared_param_infos)
-        self._param_numels = tuple(param_numels)
-        self._param_shapes = tuple(param_shapes)
         return params
@@ -109,9 +184,8 @@ class FlattenParamsWrapper(nn.Module):
         if not hasattr(self, "_param_infos"):
             assert flat_param is None
             params = self._init_flatten_params()
-            flat_param = nn.Parameter(torch.cat([p.reshape(-1) for p in params], 0))
+            flat_param = FlatParameter(params)
             self.param_numel = flat_param.numel()
-            del params

         # flatten
         assert flat_param is not None
@@ -126,15 +200,14 @@ class FlattenParamsWrapper(nn.Module):
         # register the views as plain attributes
         self._unflatten_params_as_views()

-    def get_param_views(self, flat_param: Tensor) -> Generator:
-        return (t.view(s) for (t, s) in zip(flat_param.split(self._param_numels), self._param_shapes))
-
-    def _unflatten_params(self, flat_param: Optional[Tensor] = None) -> None:
-        assert self.is_flattened or flat_param is not None
+    def _unflatten_params(self, external_data: Optional[Tensor] = None) -> None:
+        """ Undo flattening and create separate parameters from the already flattened
+            self.flat_param or a user-supplied external data tensor.
+        """
+        assert self.is_flattened or external_data is not None
         self.is_flattened = False
-        flat_param = flat_param if flat_param is not None else self.flat_param
-        ps = self.get_param_views(flat_param)
+        ps = self.get_param_views([external_data])
         for (m, n), p in zip(self._param_infos, ps):
             if hasattr(m, n):
                 delattr(m, n)
@@ -147,22 +220,23 @@ class FlattenParamsWrapper(nn.Module):
         del self.flat_param

     def _unflatten_params_as_views(self) -> None:
+        """ Unlike ``_unflatten_params``, this function unflattens into views and keeps
+            self.flat_param unchanged.
+        """
         assert self.is_flattened
-        ps = self.get_param_views(self.flat_param)
+        ps = self.get_param_views()
         for (m, n), p in zip(self._param_infos, ps):
             setattr(m, n, p)  # This will set as plain attr
         for (m, n, shared_m, shared_n) in self._shared_param_infos:
             setattr(m, n, getattr(shared_m, shared_n))
     @contextmanager
-    def unflatten_params(self, recurse: bool = True, flat_param: Optional[Tensor] = None) -> Generator:
+    def unflatten_params(self, flat_param: Optional[Tensor] = None) -> Generator:
         """
         Unflatten params. If the current instance is already unflattened, then
         it will remain unflattened after the context manager exits.

         Args:
-            recurse (bool, Optional): recursively unflatten all nested instances
-                (default: True)
             flat_param (Tensor, Optional): flat param to use for unflattening.
                 If provided, the current instance must be in a flattened state
                 at the start of the context manager. The provided Tensor must be
@@ -170,31 +244,20 @@ class FlattenParamsWrapper(nn.Module):
                 manager. After the context manager exits, we will revert to
                 using ``self.flat_param`` (default: None).
         """
-        if recurse:
-            with ExitStack() as stack:
-                # unflatten any nested FlattenParamsWrapper instances
-                for name, module in self.named_modules():
-                    if isinstance(module, FlattenParamsWrapper):
-                        is_self = name == ""
-                        stack.enter_context(
-                            module.unflatten_params(recurse=False, flat_param=flat_param if is_self else None)
-                        )
-                # yield to the caller, with unflattened params in all nested instances
-                yield
-            # exiting from the ExitStack will re-flatten params
-            return
-        else:
-            assert (
-                flat_param is None or self.is_flattened
-            ), "Unflattening with custom flat_param requires current instance to be flattened"
+        assert (
+            flat_param is None or self.is_flattened
+        ), "Unflattening with external flat_param requires current instance to be flattened"
         orig_flattened = self.is_flattened
         if orig_flattened:
             orig_flat_param = self.flat_param
         self._unflatten_params(flat_param)
+        # Put the yield in a try...finally so that state is properly restored even if
+        # the caller's code raises an exception inside the context.
+        try:
             yield
+        finally:
             if orig_flattened:
                 self._flatten_params(orig_flat_param)
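A minimal usage sketch of the simplified context manager (not from the diff; it assumes the `fairscale.nn.misc` import path used elsewhere in this commit, and that wrapping without an explicit param list flattens all parameters):

```python
import torch
import torch.nn as nn

from fairscale.nn.misc import FlattenParamsWrapper

# Wrap a small module; its parameters are flattened into a single FlatParameter.
fpw = FlattenParamsWrapper(nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 2)))
assert fpw.is_flattened

with fpw.unflatten_params():
    # Inside the context, the original parameter attributes are visible again,
    # e.g. for inspection or for building an unflattened state_dict.
    names = [n for n, _ in fpw.module.named_parameters()]

assert fpw.is_flattened  # re-flattened on exit, even if the body had raised
```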
@@ -217,30 +280,31 @@ class FlattenParamsWrapper(nn.Module):
     # Since we have overloads above, we can use Any here.
     def state_dict(self, *args: Any, **kwargs: Any) -> Any:
-        """Return the wrapped module's state_dict (unflattened)."""
+        """Return the wrapped module's state_dict."""
         if self.is_flattened and self._auto_unflatten_state_dict:
-            with self.unflatten_params(recurse=False):
+            # Returns the original (unflattened) version.
+            with self.unflatten_params():
                 return super().state_dict(*args, **kwargs)
         else:
+            # Returns the flattened version.
             return super().state_dict(*args, **kwargs)

     def flat_state_dict(self, *args: Any, **kwargs: Any) -> Dict[str, Any]:
         """Return the flattened state_dict."""
         assert self.is_flattened
-        with ExitStack() as stack:
-            # tell any nested FlattenParamsWrapper instances not to auto unflatten
-            for module in self.modules():  # includes self
-                if isinstance(module, FlattenParamsWrapper):
-                    stack.enter_context(module._no_auto_unflatten_state_dict())
-            state_dict = self.state_dict(*args, **kwargs)
-        return state_dict
+        with self._no_auto_unflatten_state_dict():
+            return self.state_dict(*args, **kwargs)

     @contextmanager
     def _no_auto_unflatten_state_dict(self) -> Generator:
         backup = self._auto_unflatten_state_dict
         self._auto_unflatten_state_dict = False
-        yield
-        self._auto_unflatten_state_dict = backup
+        # Put the yield in a try...finally so that state is properly restored even if
+        # the caller's code raises an exception inside the context.
+        try:
+            yield
+        finally:
+            self._auto_unflatten_state_dict = backup
     def load_state_dict(
         self, state_dict: Union[Dict[str, Tensor], "OrderedDict[str, Tensor]"], strict: bool = True
@@ -251,15 +315,31 @@ class FlattenParamsWrapper(nn.Module):
         """
         # unflatten the module automatically if the state_dict is non-flat
         if self.is_flattened and "flat_param" not in state_dict:
-            with self.unflatten_params(recurse=True):
+            # This object is flattened but the state_dict is not, so we unflatten and load.
+            with self.unflatten_params():
                 return super().load_state_dict(state_dict, strict)
         else:
+            # Otherwise, load it as is.
             return super().load_state_dict(state_dict, strict)

     def forward(self, *inputs: Any, **kwinputs: Any) -> Any:
         self._unflatten_params_as_views()
         return self.module(*inputs, **kwinputs)
+    def get_param_views(
+        self, external_data_list: Optional[List[Optional[Tensor]]] = None
+    ) -> Generator[Tensor, None, None]:
+        """ Return a generator over all parameter views, optionally built from a list of
+            external data tensors (one per flat param).
+        """
+        params = [self.flat_param]  # For now, there is only a single flat param.
+        if external_data_list is None:
+            external_data_list = [None] * len(params)
+
+        gens = []
+        for p, data in zip(params, external_data_list):
+            gens.append(p.get_param_views(data))
+        return chain(*gens)  # type: ignore


 def _post_state_dict_hook(
     module: nn.Module, state_dict: "OrderedDict[str, Tensor]", prefix: str, *args: Any
...
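And a small sketch of the wrapper-level `get_param_views`, which now takes a list of external buffers, matching the `get_param_views([flat_buffer])` call in the optimizer-state code above (illustrative only):

```python
import torch
import torch.nn as nn

from fairscale.nn.misc import FlattenParamsWrapper

# nn.Linear(3, 2) has a (2, 3) weight and a (2,) bias: 6 + 2 = 8 numels total.
fpw = FlattenParamsWrapper(nn.Linear(3, 2))

# A consolidated external buffer of the same total numel (e.g. optimizer state).
buf = torch.arange(8, dtype=torch.float32)
views = list(fpw.get_param_views([buf]))
assert [tuple(v.shape) for v in views] == [(2, 3), (2,)]
```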
@@ -674,8 +674,10 @@ def in_temporary_directory() -> Generator:
     old_cwd = os.getcwd()
     with tempfile.TemporaryDirectory() as temp_dir:
         os.chdir(temp_dir)
-        yield temp_dir
-        os.chdir(old_cwd)
+        try:
+            yield temp_dir
+        finally:
+            os.chdir(old_cwd)


 @contextlib.contextmanager
@@ -683,11 +685,12 @@ def temp_files_ctx(num: int) -> Generator:
     """ A context to get tempfiles and ensure they are cleaned up. """
     files = [tempfile.mkstemp()[1] for _ in range(num)]
-    yield tuple(files)
-
-    # temp files could have been removed, so we use rmf.
-    for name in files:
-        rmf(name)
+    try:
+        yield tuple(files)
+    finally:
+        # temp files could have been removed, so we use rmf.
+        for name in files:
+            rmf(name)


 def dump_all_tensors(rank: int) -> None:
...
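The try/finally placement is the whole point of these two fixes: cleanup must run even when the test body raises. A standalone sketch of the pattern (hypothetical `demo_tmpfile` helper, not from the repo):

```python
import contextlib
import os
import tempfile

# Wrap the yield in try/finally so cleanup runs even if the `with` body raises.
@contextlib.contextmanager
def demo_tmpfile():
    fd, path = tempfile.mkstemp()
    os.close(fd)
    try:
        yield path
    finally:
        if os.path.exists(path):
            os.remove(path)

try:
    with demo_tmpfile() as p:
        raise RuntimeError("failure inside the context")
except RuntimeError:
    pass

assert not os.path.exists(p)  # the temp file was still cleaned up
```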
@@ -121,6 +121,9 @@ class Tensor:
     _has_been_cloned: Optional[bool] = ...
     #END

+    @staticmethod
+    def _make_subclass(cls: Any, data: Tensor, requires_grad: builtins.bool) -> Any: ...
+
     def __init__(self, *args, **kwargs) -> None: ...

     @property
...
@@ -15,6 +15,8 @@ class Parameter(Tensor):
     _fp32_shard: Tensor
     _fp16_shard: Optional[Tensor]

+    def __new__(cls, data: Tensor, requires_grad: builtins.bool = True): ...
+
     def __init__(self, data: Tensor, requires_grad: builtins.bool = True): ...
...
@@ -5,6 +5,7 @@
 import pytest
 import torch
 import torch.distributed as dist
+import torch.multiprocessing as mp
 import torch.nn as nn

 from fairscale.nn import FullyShardedDataParallel
@@ -139,7 +140,6 @@ def _worker(gpu_id: int, sync_file: str, world_size: int, embedding_size: int, f
 @pytest.mark.parametrize("embedding_size", [128, 129])
 @pytest.mark.parametrize("flatten_parameters", [True, False])
 def test_consolidation(embedding_size: int, flatten_parameters: bool):
-    import torch.multiprocessing as mp
     world_size = 2
     with in_temporary_directory():
...
@@ -19,8 +19,8 @@ from fairscale.utils.testing import objects_are_equal
 class TestFlattenParams(unittest.TestCase):
     def _get_module_init_fns(self):
         return [
+            self._get_basic_linear_module,
             self._get_shared_params_transformer,
-            self._get_nested_flat_module,
         ]

     def _get_transformer(self, seed=0):
@@ -47,13 +47,11 @@ class TestFlattenParams(unittest.TestCase):
         dec_layer.linear2.weight = enc_layer.linear2.weight
         return module

-    def _get_nested_flat_module(self, seed=0):
-        module = torch.nn.Sequential(
-            FlattenParamsWrapper(
-                torch.nn.Sequential(torch.nn.Linear(4, 8), FlattenParamsWrapper(torch.nn.Linear(8, 8)))
-            ),
-            FlattenParamsWrapper(torch.nn.Sequential(FlattenParamsWrapper(torch.nn.Linear(8, 16)))),
-            FlattenParamsWrapper(torch.nn.Linear(16, 4)),
-        )
+    def _get_basic_linear_module(self, seed=0):
+        module = torch.nn.Sequential(
+            torch.nn.Sequential(torch.nn.Linear(4, 8), torch.nn.Linear(8, 8)),
+            torch.nn.Sequential(torch.nn.Linear(8, 16)),
+            torch.nn.Linear(16, 4),
+        )

     def get_input(device, dtype):
...
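A sketch of what the new, un-nested test module enables (illustrative only, not the actual test code): wrapping it should leave the forward output unchanged.

```python
import torch
import torch.nn as nn

from fairscale.nn.misc import FlattenParamsWrapper

torch.manual_seed(0)
module = nn.Sequential(
    nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 8)),
    nn.Sequential(nn.Linear(8, 16)),
    nn.Linear(16, 4),
)
x = torch.randn(2, 4)
ref = module(x)

# Flattening copies the same values into a single FlatParameter and exposes views,
# so the wrapped forward should match the original output exactly.
wrapped = FlattenParamsWrapper(module)
out = wrapped(x)
assert torch.allclose(ref, out)
```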