"vscode:/vscode.git/clone" did not exist on "ae63bd08a3c6d8e5a468c13ece2ec04e6745dcdf"
Unverified commit 92f27daa authored by Paul Johnson, committed by GitHub

Improvements to ssd_offload to support pickling/unpickling SsdTensorHandle (and derived classes) (#964)

Verified that FSDP-wrapped models using ssd_offload save and restore checkpoints correctly
parent 72f373c1
@@ -9,6 +9,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Added
- FSDP: Add pickle/unpickle support for SsdTensorHandle (and derived classes),
verified that FSDP models w/ ssd_offload enabled can correctly call model.state_dict()
and model.load_state_dict(...) and thus successfully checkpoint and recover parameters
stored as SsdFlatParameters.
### Fixed
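As a concrete illustration of the round-trip the CHANGELOG entry above describes, here is a minimal, hedged sketch. It uses the module alias `so` from the tests further down; the `/tmp` paths are placeholders, and the behavior relied on (streaming bytes out on save, and unpickling writing them back into the handle's backing file) is the one implemented in the diff below.

```python
# Hedged sketch of the new pickle round-trip (paths are placeholders).
import torch
import fairscale.experimental.nn.ssd_offload as so

t = torch.rand(128, 4)
handle = so.SsdTensorHandle.from_tensor(t)    # tensor lives in memory
handle.set_file_params("/tmp/handle.bin", 0)  # choose the SSD backing file
handle.to_file()                              # flush to disk, release the in-memory tensor

so.torch_saver.save(handle, "/tmp/ckpt.pt")   # streams the bytes chunk by chunk
restored = torch.load("/tmp/ckpt.pt")         # unpickling writes the bytes back to handle.filename
assert torch.equal(handle, restored)
```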
@@ -9,10 +9,13 @@ from enum import Enum, auto
from functools import reduce
import io
import os
from typing import Any, Iterator, List, Optional, Sequence, Tuple, Type
import pickle
from types import TracebackType
from typing import IO, Any, BinaryIO, Iterator, List, Optional, Sequence, Tuple, Type, Union
import numpy as np
import torch
from torch.serialization import DEFAULT_PROTOCOL as DEFAULT_PROTOCOL
try:
from torch.utils._pytree import tree_map
@@ -21,7 +24,7 @@ except ImportError:
pass
DEFAULT_CHUNK_SIZE = 1024 * 1024
DEFAULT_CHUNK_SIZE = 2048 * 2048
def _get_num_chunks(input_tensor: torch.Tensor, chunk_size_bytes: int = DEFAULT_CHUNK_SIZE) -> int:
@@ -89,8 +92,12 @@ class SsdTensorHandle(torch.Tensor):
data into the tensor that is an attribute of the SsdTensorHandle object or write the tensor to file. At any
point in time the Tensor may be in memory or on disk.
Class Variables:
override_directory_path: This variable is used by CheckpointPathContextManager to modify the path to any
SsdTensorHandles that are saved to a checkpoint via pickling (e.g. torch.save)
Args:
shape Tuple[int, ...]: Shape of the tensor that is represented by the handle.
shape torch.Size: Shape of the tensor that is represented by the handle.
dtype: torch.dtype: Dtype of the tensor that is represented by the handle.
requires_grad: bool: Property of the tensor that is represented by the handle.
@@ -98,14 +105,18 @@ class SsdTensorHandle(torch.Tensor):
An SsdTensorHandle object representing a Tensor.
"""
override_directory_path: Optional[str] = None
@staticmethod
def __new__(
cls: SsdTensorHandle, shape: Tuple[int, ...], dtype: torch.dtype, requires_grad: bool = False
cls: Type[SsdTensorHandle], shape: torch.Size, dtype: torch.dtype, requires_grad: bool = False
) -> SsdTensorHandle:
r = torch.Tensor._make_wrapper_subclass(cls, shape, dtype=dtype, requires_grad=requires_grad) # type: ignore
r = super(SsdTensorHandle, cls)._make_wrapper_subclass(cls, shape, dtype=dtype, requires_grad=requires_grad) # type: ignore
return r
def __init__(self, shape: Tuple[int, ...], dtype: torch.dtype, requires_grad: bool) -> None:
def __init__(self, shape: torch.Size, dtype: torch.dtype, requires_grad: bool) -> None:
self._unpickle_f: Optional[Union[BinaryIO, IO[bytes]]] = None
self._shape = shape
if len(shape) == 0:
self._numel = 0
@@ -122,20 +133,18 @@ class SsdTensorHandle(torch.Tensor):
@classmethod
def from_file(
cls, shape: Tuple[int, ...], dtype: torch.dtype, filename: str, requires_grad: bool = False
cls, shape: torch.Size, dtype: torch.dtype, filename: str, offset: int = 0, requires_grad: bool = False
) -> SsdTensorHandle:
"""Returns a new SsdTensorHandle from a file."""
handle = cls(shape=shape, dtype=dtype, requires_grad=requires_grad)
handle.filename = filename
handle.storage_state = StorageState.ON_DISK
handle.point_to_file(filename, offset=offset)
return handle
@classmethod
def from_tensor(cls: Type[SsdTensorHandle], tensor: torch.Tensor) -> SsdTensorHandle:
"""Returns a new SsdTensorHandle from a tensor."""
handle = cls(shape=tensor.shape, dtype=tensor.dtype, requires_grad=tensor.requires_grad)
handle.tensor = tensor
handle.storage_state = StorageState.ON_CPU
handle.point_to_tensor(tensor)
return handle
def is_available(self) -> bool:
@@ -152,12 +161,14 @@ class SsdTensorHandle(torch.Tensor):
def point_to_file(self, filename: str, offset: int) -> None:
self.set_file_params(filename, offset)
self.tensor = None
self.storage_state = StorageState.ON_DISK
def point_to_tensor(self, tensor: torch.Tensor) -> None:
assert self.tensor is None
assert self._shape == tensor.shape
assert self._dtype == tensor.dtype
self.tensor = tensor
self.storage_state = StorageState.ON_CPU
# if resizing a handle that is part of an ssd buffer, care must be taken that the new size
# doesn't conflict with adjacent handles!
@@ -237,61 +248,167 @@ class SsdTensorHandle(torch.Tensor):
e.to_file()
return r
@classmethod
def __unpickle__(
cls: Type[SsdTensorHandle], shape: torch.Size, dtype: torch.dtype, requires_grad: bool, filename: str
) -> SsdTensorHandle:
result = cls(shape, dtype, requires_grad)
result.point_to_file(filename, 0)
result._unpickle_f = io.open(result.filename, "wb")
return result
def __reduce_ex__(self, proto: int) -> Tuple[Any, Any, Any, Any]:
byte_iter = None
filename = self.filename
if self.override_directory_path is not None:
head, tail = os.path.split(self.filename)
filename = os.path.join(self.override_directory_path, tail)
if self.is_available():
byte_iter = iter(TensorChunkingIterator(self.tensor))
else:
byte_iter = iter(
FileChunkingIterator(self.filename, expected_size_bytes=self.numel() * self.element_size())
)
return (
self.__unpickle__, # Callable
# Args to the callable above
(self._shape, self._dtype, self.requires_grad, filename),
None,
byte_iter,
)
def append(self, item: bytes) -> None:
assert self._unpickle_f
self._unpickle_f.write(item)
def extend(self, items: List[bytes]) -> None:
for i in items:
self.append(i)
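The four-element tuple returned by `__reduce_ex__` uses pickle's standard list-items protocol: the unpickler first calls `__unpickle__` to build the object (which opens the destination file for writing), then feeds every chunk yielded by `byte_iter` to the object's `append`/`extend` methods, in batches of up to 1000 items. The standalone sketch below shows the mechanism in isolation; `ByteSink` is hypothetical and not part of this diff.

```python
# Standalone illustration of pickle's list-items protocol (ByteSink is hypothetical).
import pickle

class ByteSink:
    def __init__(self) -> None:
        self.chunks: list = []

    def append(self, item: bytes) -> None:
        # The unpickler calls this (or extend, if defined) for each item
        # yielded by the iterator returned as the 4th element of __reduce_ex__.
        self.chunks.append(item)

    def __reduce_ex__(self, protocol: int) -> tuple:
        # (callable, args to the callable, state, iterator of items to append)
        return (ByteSink, (), None, iter(list(self.chunks)))

sink = ByteSink()
sink.chunks = [b"abc", b"def"]
clone = pickle.loads(pickle.dumps(sink))
assert clone.chunks == [b"abc", b"def"]
```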
class CheckpointPathContextManager:
"""
This context manager allows the user to override the directory path used when pickling an SsdTensorHandle object.
It is needed because the filename that the SsdTensorHandle points to (and that is used when unpickling)
is baked into the pickled data.
Consider the following example code:
ssd_handle = SsdTensorHandle.from_tensor(ref_tensor)
ssd_handle.set_file_params('/home/user/handle.bin', 0)
torch.save(ssd_handle, '/home/user/checkpoint.pkl')
ssd_handle += 1
ssd_handle.to_file()
ssd_handle2 = torch.load('/home/user/checkpoint.pkl')
print(f"handles are equal: {torch.equals(ssd_handle, ssd_handle2)}")
One would expect this to print False; however, unintuitively it prints True. The reason is that
ssd_handle.filename and ssd_handle2.filename are equal, so when we execute torch.load, we read from
the .pkl file and write the result into /home/user/handle.bin, clobbering the updated result from
`ssd_handle += 1`. This context manager gives the user a way to avoid clobbering that data:
ssd_handle = SsdTensorHandle.from_tensor(ref_tensor)
ssd_handle.set_file_params('/home/user/handle.bin', 0)
with CheckpointPathContextManager(override_path='/home/user/checkpoint_data/'):
torch.save(ssd_handle, '/home/user/checkpoint.pkl')
ssd_handle += 1
ssd_handle.to_file()
ssd_handle2 = torch.load('/home/user/checkpoint.pkl')
print(f"handles are equal: {torch.equals(ssd_handle, ssd_handle2)}")
This code results in ssd_handle.filename = '/home/user/handle.bin' and ssd_handle2.filename =
'/home/user/checkpoint_data/handle.bin'. Therefore torch.load won't clobber ssd_handle, and the
printed result is False.
"""
def __init__(self, override_path: str) -> None:
self.old_path = SsdTensorHandle.override_directory_path
self.override_path = override_path
def __enter__(self) -> None:
SsdTensorHandle.override_directory_path = self.override_path
def __exit__(
self,
exc_type: Optional[Type[BaseException]],
exc_value: Optional[BaseException],
exc_traceback: Optional[TracebackType],
) -> None:
SsdTensorHandle.override_directory_path = self.old_path
# Classes supporting torch.save/load
class TorchSaver:
def __init__(self) -> None:
self.pickle_module = DisableMemoizationPicklerModule
def save(
self, obj: Any, f: Union[str, os.PathLike, BinaryIO, IO[bytes]], pickle_protocol: int = DEFAULT_PROTOCOL
) -> None:
torch.serialization.save(
obj, f, self.pickle_module, pickle_protocol=pickle_protocol, _use_new_zipfile_serialization=False
)
class SsdParameter(SsdTensorHandle, torch.nn.Parameter):
@classmethod
def from_tensor(cls: Type[SsdParameter], tensor: SsdTensorHandle) -> SsdParameter: # type: ignore
r = cls(tensor.shape, tensor.dtype, tensor.requires_grad)
r.point_to_tensor(tensor)
return r
@staticmethod
def __new__(
cls: Type[SsdParameter], shape: torch.Size, dtype: torch.dtype, requires_grad: bool = True
) -> SsdParameter:
r = super(SsdParameter, cls).__new__(cls, shape, dtype=dtype, requires_grad=requires_grad)
return r # type: ignore
def __init__(self, shape: torch.Size, dtype: torch.dtype, requires_grad: bool = True) -> None:
super(SsdParameter, self).__init__(shape, dtype, requires_grad)
class SsdFlatParameter(torch.nn.Parameter, SsdTensorHandle):
class SsdFlatParameter(SsdParameter):
"""A parameter that is initialized from a list of parameters and can be
turned into a list of views as needed.
This class should eventually be moved to fairscale/nn/misc/flatten_params_wrapper.py
"""
def __new__(
cls, params: Sequence[torch.nn.Parameter], filename: str, requires_grad: bool = True
) -> "SsdFlatParameter":
cls: Type[SsdFlatParameter], shapes: Sequence[torch.Size], dtype: torch.dtype, requires_grad: bool = True
) -> SsdFlatParameter:
"""Make an object using the parent's __new__ function."""
# An empty or non-list input doesn't make sense.
if not isinstance(params, (list, tuple)) or len(params) == 0:
if not isinstance(shapes, (list, tuple)) or len(shapes) == 0:
raise ValueError("An non-empty list or tuple argument is needed")
# Normally, all items are Parameters. But during pickling, we will have a single
# Tensor as the input and later in __init__, the correct _param_numels and _param_shapes
# are set.
if not all(isinstance(p, (torch.nn.Parameter, torch.Tensor)) for p in params):
raise ValueError("List items need to be Parameter types")
size = sum([np.prod(s) for s in shapes])
r = super(SsdFlatParameter, cls).__new__(cls, torch.Size((size,)), dtype=dtype, requires_grad=requires_grad)
return r # type: ignore
# Flattening involves (1) making a tensor flat (i.e. single dimensional) and (2) making a module
# hierarchy flat (using a single tensor to replace a tree of tensors). Therefore,
# adding back nesting and hierarchy is counter-productive. If nesting is encountered
# in the future, the reasonable thing to do is likely for the top level SsdFlatParameter to
# absorb the nested one and keep the result flat, free from hierarchy.
if any(isinstance(p, SsdFlatParameter) for p in params):
raise ValueError("Nesting SsdFlatParameter is not supported")
dtype = params[0].dtype
size = sum(p.numel() for p in params)
r = SsdTensorHandle._make_wrapper_subclass(cls, (size,), dtype=dtype, requires_grad=requires_grad) # type: ignore
return r
def __init__(self, params: Sequence[torch.nn.Parameter], filename: str, requires_grad: bool = True):
def __init__(self, shapes: Sequence[torch.Size], dtype: torch.dtype, requires_grad: bool = True):
"""Initialize the _param_numels and _param_shapes lists."""
self._param_numels = [p.numel() for p in params]
self._param_shapes = shapes
self._param_numels = [np.prod(s) for s in shapes]
total_numels = sum(self._param_numels)
assert (
self.numel() <= total_numels
), f"Something wrong with __new__ method, {self.numel()} vs. {sum(self._param_numels)}"
self._param_shapes = [p.size() for p in params]
# These are set by FPW class below, not by this class itself.
self._param_infos: List[Tuple[str, torch.nn.Module, str]] = []
self._shared_param_infos: List[Tuple[str, str, torch.nn.Module, str, torch.nn.Module, str]] = []
super(SsdFlatParameter, self).__init__(shape=(total_numels,), dtype=params[0].dtype, requires_grad=requires_grad) # type: ignore
tensor = torch.cat(
[p.detach().reshape(-1) if isinstance(p, torch.nn.Parameter) else p.reshape(-1) for p in params], 0
super(SsdFlatParameter, self).__init__(
shape=torch.Size((total_numels,)), dtype=dtype, requires_grad=requires_grad
)
tensor.requires_grad = requires_grad
self.set_file_params(filename, 0)
self.point_to_tensor(tensor)
def get_param_views(self, external_data: Optional[torch.Tensor] = None) -> Iterator[torch.Tensor]:
"""Return a generator of views that map to the original parameters."""
@@ -301,7 +418,7 @@ class SsdFlatParameter(torch.nn.Parameter, SsdTensorHandle):
self._param_numels
), f"Incorrect internal state {self.data.numel()} vs. {sum(self._param_numels)}"
"""
if external_data:
if external_data is not None:
if external_data.numel() != sum(self._param_numels):
raise ValueError(
f"Incorrect numel of supplied data: got {external_data.numel()} but expected {sum(self._param_numels)}"
@@ -310,24 +427,154 @@ class SsdFlatParameter(torch.nn.Parameter, SsdTensorHandle):
else:
return (t.view(s) for (t, s) in zip(self.split(self._param_numels), self._param_shapes))
def metadata(self) -> Tuple[List[str], List[torch.Size], List[int]]:
def metadata(self) -> Tuple[List[str], Sequence[torch.Size], List[int]]:
"""Return tuple of (names, shapes, numels) metadata for this flat parameter."""
names = [".".join([m, n]) if m else n for (m, _, n) in self._param_infos]
return names, self._param_shapes, self._param_numels
def __setstate__(self, state: Tuple[Any, Any, Any, Any]) -> None:
"""Use by pickle to set the internal states."""
(self._param_numels, self._param_shapes, self._param_infos, self._shared_param_infos) = state
assert self.numel() <= sum(
self._param_numels
), f"Incorrect pickling {self.numel()} vs. {sum(self._param_numels)}"
@classmethod
def from_tensors(
cls: Type[SsdFlatParameter], tensors: Sequence[torch.Tensor], direct_to_file: bool = False
) -> "SsdFlatParameter":
"""Returns a new SsdFlatParameter from a sequence of tensors."""
assert (
len(tensors) > 0
), "SsdFlatParameter.from_tensors must be called with at least one tensor in the tensors argument"
# Flattening involves (1) making a tensor flat (i.e. single dimensional) and (2) making a module
# hierarchy flat (using a single tensor to replace a tree of tensors). Therefore,
# adding back nesting and hierarchy is counter-productive. If nesting is encountered
# in the future, the reasonable thing to do is likely for the top level SsdFlatParameter to
# absorb the nested one and keep the result flat, free from hierarchy.
if any(isinstance(t, SsdFlatParameter) for t in tensors):
raise ValueError("Nesting SsdFlatParameter is not supported")
handle = cls(shapes=[t.size() for t in tensors], dtype=tensors[0].dtype, requires_grad=tensors[0].requires_grad)
if direct_to_file:
assert False, "direct_to_file not implemented yet"
pass
else:
tensor = torch.cat(
[t.detach().reshape(-1) if isinstance(t, torch.nn.Parameter) else t.reshape(-1) for t in tensors], 0
)
handle.point_to_tensor(tensor)
return handle
def __reduce_ex__(self, proto: int) -> Tuple[Any, Any, Any]:
"""Support pickling between ranks."""
@classmethod
def __unpickle_SFP__(
cls: Type[SsdFlatParameter],
shapes: Sequence[torch.Size],
dtype: torch.dtype,
requires_grad: bool,
filename: str,
) -> SsdFlatParameter:
result = cls(shapes, dtype, requires_grad)
result.point_to_file(filename, 0)
result._unpickle_f = io.open(result.filename, "wb")
return result
def __reduce_ex__(self, proto: int) -> Tuple[Any, Any, Any, Any]:
byte_iter = None
filename = self.filename
if self.override_directory_path is not None:
head, tail = os.path.split(self.filename)
filename = os.path.join(self.override_directory_path, tail)
if self.is_available():
byte_iter = iter(TensorChunkingIterator(self.tensor))
else:
byte_iter = iter(
FileChunkingIterator(self.filename, expected_size_bytes=self.numel() * self.element_size())
)
return (
SsdFlatParameter, # Callable
self.__unpickle_SFP__, # Callable
# Args to the callable above
([self.data], self.filename, self.requires_grad),
# Args to __setstate__
(self._param_numels, self._param_shapes, self._param_infos, self._shared_param_infos),
(self._param_shapes, self._dtype, self.requires_grad, filename),
None,
byte_iter,
)
class DisableMemoizationPicklerModule:
@classmethod
def Pickler(cls, data_buf: io.BytesIO, protocol: int) -> pickle.Pickler:
p = pickle.Pickler(data_buf, protocol)
p.fast = True
return p
@classmethod
def dump(cls, obj: Any, f: io.BytesIO, protocol: int) -> None:
pickle.dump(obj, f, protocol)
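The `fast` flag (deprecated in CPython but still functional) disables the pickler's memo table. That matters here because the memo would otherwise hold a reference to every byte chunk yielded by the iterators below, defeating the point of streaming; with memoization off, the pickler retains nothing it has already written. The trade-off, shown in this standalone sketch (not code from this diff), is that shared references are re-serialized instead of back-referenced, and self-referential objects cannot be pickled at all.

```python
# Standalone demonstration of pickle "fast" mode.
import io
import pickle

def dump(obj: object, fast: bool) -> bytes:
    buf = io.BytesIO()
    p = pickle.Pickler(buf, protocol=2)
    p.fast = fast  # True: skip the memo table entirely
    p.dump(obj)
    return buf.getvalue()

shared = b"payload"
obj = [shared, shared]
# With the memo, the second occurrence becomes a cheap back-reference;
# in fast mode the full bytes are written twice, but the pickler keeps
# no references to anything it has already serialized.
assert len(dump(obj, fast=True)) > len(dump(obj, fast=False))
```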
class TensorChunkingIterator:
"""
chunk_size_bytes determines the size of each chunk we break the tensor
into. It is important to limit the size because, when Python unpickles
an object, it will by default read up to 1000 list elements at a time.
So memory usage while unpickling will be on the order of
O(min(tensor_size_bytes, 1000 * chunk_size_bytes)).
"""
def __init__(self, tensor: torch.Tensor, chunk_size_bytes: int = DEFAULT_CHUNK_SIZE) -> None:
self.tensor = tensor
self.chunk_size_bytes = chunk_size_bytes
def __iter__(self) -> Iterator[bytes]:
self.num_chunks = _get_num_chunks(self.tensor, self.chunk_size_bytes)
self.num_chunks_read = 0
return self
def __next__(self) -> bytes:
if self.num_chunks_read >= self.num_chunks:
raise StopIteration
next_chunk = _tensor_to_bytes_chunks(
self.tensor, chunk_idx=self.num_chunks_read, chunk_size_bytes=self.chunk_size_bytes
)
self.num_chunks_read += 1
return next_chunk
class FileChunkingIterator:
"""
chunk_size_bytes determines the size of each chunk we break the file
into. It is important to limit the size because, when Python unpickles
an object, it will by default read up to 1000 list elements at a time.
So memory usage while unpickling will be on the order of
O(min(file_size, 1000 * chunk_size_bytes)).
"""
def __init__(
self, filename: str, expected_size_bytes: int = -1, chunk_size_bytes: int = DEFAULT_CHUNK_SIZE
) -> None:
self.filename = filename
self.file: Optional[Union[BinaryIO, IO[bytes]]] = None
self.chunk_size_bytes = chunk_size_bytes
self.expected_size_bytes = expected_size_bytes
def __iter__(self) -> Iterator[bytes]:
if self.expected_size_bytes != -1:
file_size = os.stat(self.filename).st_size
assert (
file_size == self.expected_size_bytes
), f"FileChunkingIterator Failed, expecting file to be of size: {self.expected_size_bytes} but got {file_size}"
self.file = io.open(self.filename, "rb", buffering=0)
self.num_chunks_read = 0
return self
def __next__(self) -> bytes:
assert self.file
next_chunk = self.file.read(self.chunk_size_bytes)
if len(next_chunk) == 0:
raise StopIteration
self.num_chunks_read += 1
return next_chunk
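Both iterators yield raw byte chunks in storage order, which is what lets the unpickler write them straight into the backing file. A quick hedged sanity check of the tensor variant, assuming the `_get_num_chunks` and `_tensor_to_bytes_chunks` helpers at the top of this file return raw bytes and round the chunk count up, as their use in `__next__` implies:

```python
# Hedged sanity check of TensorChunkingIterator (chunk size chosen to divide evenly).
import torch
import fairscale.experimental.nn.ssd_offload as so

t = torch.arange(1024, dtype=torch.float32)              # 4096 bytes of data
chunks = list(so.TensorChunkingIterator(t, chunk_size_bytes=1024))
assert len(chunks) == 4                                  # 4096 bytes / 1024 per chunk
assert sum(len(c) for c in chunks) == t.numel() * t.element_size()
assert b"".join(chunks) == t.numpy().tobytes()           # raw bytes, in order
```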
torch_saver = TorchSaver()
@@ -223,7 +223,8 @@ class FlattenParamsWrapper(nn.Module):
if ssd_offload:
assert ssd_directory != ""
(handle, fname) = tempfile.mkstemp(dir=ssd_directory, suffix="ssd_buf_param")
flat_param = SsdFlatParameter(params=params, filename=fname, requires_grad=params[0].requires_grad)
flat_param = SsdFlatParameter.from_tensors(tensors=params)
flat_param.set_file_params(fname, 0)
else:
flat_param = FlatParameter(params, params[0].requires_grad)
flat_param._param_infos = param_infos
@@ -501,7 +502,7 @@ class FlattenParamsWrapper(nn.Module):
return chain(*gens)
def metadata(self, flat_param_idx: int) -> Tuple[List[str], List[torch.Size], List[int]]:
def metadata(self, flat_param_idx: int) -> Tuple[List[str], Sequence[torch.Size], List[int]]:
"""Return metadata for a flat param given its index in the flat_params list."""
return self.flat_params[flat_param_idx].metadata()
@@ -7,6 +7,8 @@
Testing SsdFlatParameter and SsdTensorHandle modules.
"""
import filecmp
import os
import tempfile
import numpy as np
@@ -109,7 +111,61 @@ def test_ssd_handle_train_simple():
assert torch.equal(ssd_handle.to_tensor(), orig_copy)
def test_ssd_flat_param_train_simple():
def test_torch_save_load_ssd_flat_param_on_disk():
_init()
orig_file = tempfile.NamedTemporaryFile(prefix="tensor")
checkpoint_file = tempfile.NamedTemporaryFile(prefix="checkpoint", suffix=".pt")
checkpoint_load_directory = tempfile.TemporaryDirectory(prefix="checkpoint_dir")
# TENSOR_SHAPE = (1024, 1024, 2048)
# use smaller shape for unit tests
TENSOR_SHAPE = (1024, 321)
ref_tensors = [torch.rand(TENSOR_SHAPE, dtype=torch.float32) for i in range(4)]
ssd_handle = so.SsdFlatParameter.from_tensors(ref_tensors, False)
ssd_handle.set_file_params(orig_file.name, 0)
ssd_handle.to_file()
ref_tensors = []
# after clearing ref_tensors, memory usage should be very low
# For save it shouldn't be more than ~10x so.DEFAULT_CHUNK_SIZE
with so.CheckpointPathContextManager(override_path=checkpoint_load_directory.name):
so.torch_saver.save(ssd_handle, checkpoint_file.name)
# the line below writes the tensor file to checkpoint_load_directory/<basename of orig_file.name>
# Memory usage here should be O(1000 * so.DEFAULT_CHUNK_SIZE):
# 1000x because that's how many elements the Python unpickler
# will buffer before passing them to the SsdTensorHandle
test_ssd_handle = torch.load(checkpoint_file)
head, tail = os.path.split(orig_file.name)
assert filecmp.cmp(orig_file.name, os.path.join(checkpoint_load_directory.name, tail), shallow=False)
def test_torch_save_load_ssd_flat_param_on_mem():
_init()
orig_file = tempfile.NamedTemporaryFile(prefix="tensor")
checkpoint_file = tempfile.NamedTemporaryFile(prefix="checkpoint", suffix=".pt")
checkpoint_load_directory = tempfile.TemporaryDirectory(prefix="checkpoint_dir")
# TENSOR_SHAPE = (1024, 1024, 2048)
# use smaller shape for unit tests
TENSOR_SHAPE = (1024, 321)
ref_tensors = [torch.rand(TENSOR_SHAPE, dtype=torch.float32) for i in range(4)]
ssd_handle = so.SsdFlatParameter.from_tensors(ref_tensors, False)
ssd_handle.set_file_params(orig_file.name, 0)
ref_tensors = []
# after clearing ref_tensors, memory usage should be very low
# For save it shouldn't be more than ~10x so.DEFAULT_CHUNK_SIZE
with so.CheckpointPathContextManager(override_path=checkpoint_load_directory.name):
so.torch_saver.save(ssd_handle, checkpoint_file.name)
# the line below writes the tensor file to checkpoint_load_directory/<basename of orig_file.name>
# Memory usage here should be O(1000 * so.DEFAULT_CHUNK_SIZE):
# 1000x because that's how many elements the Python unpickler
# will buffer before passing them to the SsdTensorHandle
test_ssd_handle = torch.load(checkpoint_file)
assert torch.equal(ssd_handle, test_ssd_handle)
def test_ssd_param_train_simple():
_init()
with tempfile.NamedTemporaryFile() as f:
orig_tensor = torch.randn((4, 4))
@@ -117,15 +173,18 @@ def test_ssd_flat_param_train_simple():
with torch.no_grad():
orig_copy = torch.empty_like(orig_tensor)
orig_copy.copy_(orig_tensor)
param = torch.nn.Parameter(orig_copy)
param = torch.nn.Parameter(orig_copy)
ssd_flat_param = so.SsdFlatParameter([param], f.name, True)
ssd_param = so.SsdParameter(orig_tensor.shape, orig_tensor.dtype)
ssd_param.point_to_tensor(orig_copy)
ssd_param.set_file_params(f.name, 0)
ssd_param.to_file(release_tensor_after_write=True)
assert torch.equal(list(ssd_flat_param.get_param_views())[0], orig_tensor)
optimizer_ssd = torch.optim.SGD([ssd_flat_param], lr=0.1)
assert torch.equal(ssd_param.to_tensor(), orig_tensor)
optimizer_ssd = torch.optim.SGD([ssd_param], lr=0.1)
optimizer_orig = torch.optim.SGD([param], lr=0.1)
y1 = ssd_flat_param + 1
y1 = ssd_param + 1
optimizer_ssd.zero_grad()
y1.sum().backward()
optimizer_ssd.step()
@@ -136,8 +195,8 @@ def test_ssd_flat_param_train_simple():
optimizer_orig.step()
# make sure we are using the file version not the cached tensor
ssd_flat_param.point_to_file(f.name, 0)
assert torch.equal(list(ssd_flat_param.get_param_views())[0], param)
ssd_param.point_to_file(f.name, 0)
assert torch.equal(ssd_param.to_tensor(), param)
def test_ssd_flat_parameter_basic():
@@ -146,7 +205,8 @@ def test_ssd_flat_parameter_basic():
refa_param = torch.nn.Parameter(torch.rand((32, 4), dtype=torch.float32))
refb_param = torch.nn.Parameter(torch.rand((32, 4), dtype=torch.float32))
refc_param = torch.nn.Parameter(torch.rand((128), dtype=torch.float32))
ssd_flat_param = so.SsdFlatParameter([refa_param, refb_param, refc_param], f.name, False)
ssd_flat_param = so.SsdFlatParameter.from_tensors([refa_param, refb_param, refc_param], False)
ssd_flat_param.set_file_params(f.name, 0)
param_views = list(ssd_flat_param.get_param_views())
@@ -16,6 +16,7 @@ import torch
from torch import nn
import torch.distributed
import fairscale.experimental.nn.ssd_offload as so
from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper
from fairscale.nn.data_parallel import FullyShardedDataParallel, OffloadConfig, TrainingState
from fairscale.utils import torch_version
@@ -289,19 +290,53 @@ class TestSsdLoading(DistributedTest):
model = FullyShardedDataParallel(model, **config)
model_device = torch.device("cuda")
model.train()
optim = torch.optim.SGD(model.parameters(), lr=4, momentum=0.9)
optim = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
checkpoint_file = tempfile.NamedTemporaryFile()
checkpoint_load_directory = tempfile.TemporaryDirectory(prefix="checkpoint_dir")
pre_checkpoint_last_output = None
post_checkpoint_last_output = None
ITERATIONS = 10
# Inputs always cuda regardless of move_grads_cpu, or model.device
with torch.cuda.amp.autocast(enabled=config.get("mixed_precision", False)):
for i in range(10):
for i in range(ITERATIONS):
optim.zero_grad()
input = model.get_input(torch.device("cuda"))
output = model(*input)
pre_checkpoint_last_output = output
loss = model.module.get_loss(input, output).to(model_device)
assert loss.dtype == torch.float32
model.module.run_backward(loss)
optim.step()
if i == 0:
with so.CheckpointPathContextManager(override_path=checkpoint_load_directory.name):
# so.torch_saver.save({"model": model.state_dict(), "optim": optim.state_dict()}, checkpoint_file.name)
torch.save({"model": model.state_dict()}, checkpoint_file.name)
# reset momentum just after checkpoint save
optim = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
checkpoint = torch.load(checkpoint_file.name)
model.load_state_dict(checkpoint["model"])
# reset momentum just after checkpoint load
optim = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
# do more iterations after loading checkpoint
for i in range(ITERATIONS - 1):
optim.zero_grad()
input = model.get_input(torch.device("cuda"))
output = model(*input)
post_checkpoint_last_output = output
loss = model.module.get_loss(input, output).to(model_device)
assert loss.dtype == torch.float32
model.module.run_backward(loss)
optim.step()
# Verify output of checkpoint load + run is equal to original output
assert torch.equal(pre_checkpoint_last_output, post_checkpoint_last_output)
if isinstance(model, FullyShardedDataParallel):
model.assert_state(TrainingState.IDLE)
@@ -445,6 +480,12 @@ def spawn_and_init(fn, args=None, **spawn_kwargs):
args = ()
run_fn = functools.partial(init_and_run, fn, args)
# Below 3 lines are to easily enable single-process debugging
# _, filename = tempfile.mkstemp()
# _, filename_rpc = tempfile.mkstemp()
# run_fn(0, 1, filename, filename_rpc)
spawn_for_all_world_sizes(run_fn, **spawn_kwargs)