Unverified commit 92f27daa, authored by Paul Johnson and committed by GitHub

Improvements to ssd_offload to support pickling/unpickling SsdTensorHandle (and derived classes) (#964)

Verified that FSDP-wrapped models using ssd_offload save and restore checkpoints correctly.
parent 72f373c1
@@ -9,6 +9,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

 ### Added
+- FSDP: Add pickle/unpickle support for SsdTensorHandle (and derived classes),
+  verified that FSDP models w/ ssd_offload enabled can correctly call model.state_dict()
+  and model.load_state_dict(...) and thus successfully checkpoint and recover parameters
+  stored as SsdFlatParameters.
+
 ### Fixed
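For orientation, the checkpoint round trip this entry describes uses the stock torch APIs; a minimal sketch, assuming `model` is an FSDP-wrapped module built with ssd_offload enabled (the filename and wrapping dict are illustrative):

    import torch

    # Saving: model.state_dict() can now be pickled because SsdFlatParameter
    # values (SsdTensorHandle subclasses) support pickling as of this PR.
    torch.save({"model": model.state_dict()}, "checkpoint.pt")

    # Restoring: unpickling rebuilds the SSD-backed parameters.
    checkpoint = torch.load("checkpoint.pt")
    model.load_state_dict(checkpoint["model"])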
This diff is collapsed; its contents are not shown here.
@@ -223,7 +223,8 @@ class FlattenParamsWrapper(nn.Module):
         if ssd_offload:
             assert ssd_directory != ""
             (handle, fname) = tempfile.mkstemp(dir=ssd_directory, suffix="ssd_buf_param")
-            flat_param = SsdFlatParameter(params=params, filename=fname, requires_grad=params[0].requires_grad)
+            flat_param = SsdFlatParameter.from_tensors(tensors=params)
+            flat_param.set_file_params(fname, 0)
         else:
             flat_param = FlatParameter(params, params[0].requires_grad)
         flat_param._param_infos = param_infos
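The one-shot constructor is replaced by a two-step build: create the flat parameter from tensors, then bind it to its backing file. A standalone sketch of the new pattern (shapes are illustrative; the immediate to_file() call mirrors the tests below rather than the wrapper itself):

    import tempfile

    import torch

    import fairscale.experimental.nn.ssd_offload as so

    tensors = [torch.rand((8, 4)) for _ in range(3)]
    handle, fname = tempfile.mkstemp(suffix="ssd_buf_param")
    flat_param = so.SsdFlatParameter.from_tensors(tensors=tensors)  # flatten into one parameter
    flat_param.set_file_params(fname, 0)  # bind to file fname at offset 0
    flat_param.to_file()  # write the flattened data out to SSD

Separating construction from file binding presumably also simplifies unpickling: a handle can be rebuilt first and re-pointed at its (possibly relocated) backing file afterwards.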
@@ -501,7 +502,7 @@ class FlattenParamsWrapper(nn.Module):
         return chain(*gens)

-    def metadata(self, flat_param_idx: int) -> Tuple[List[str], List[torch.Size], List[int]]:
+    def metadata(self, flat_param_idx: int) -> Tuple[List[str], Sequence[torch.Size], List[int]]:
         """Return metadata for a flat param given its index in the flat_params list."""
         return self.flat_params[flat_param_idx].metadata()
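The annotation change from List[torch.Size] to Sequence[torch.Size] loosens only the static type: Sequence accepts tuples as well as lists, which matters if the underlying metadata stores shapes as a tuple. A tiny illustration of what the wider type permits:

    from typing import Sequence

    import torch

    def shapes() -> Sequence[torch.Size]:
        # A tuple satisfies Sequence[torch.Size]; it would be rejected
        # for List[torch.Size] by a type checker such as mypy.
        return (torch.Size([32, 4]), torch.Size([128]))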
@@ -7,6 +7,8 @@
 Testing SsdFlatParameter and SsdTensorHandle modules.
 """

+import filecmp
+import os
 import tempfile

 import numpy as np
@@ -109,7 +111,61 @@ def test_ssd_handle_train_simple():
     assert torch.equal(ssd_handle.to_tensor(), orig_copy)


-def test_ssd_flat_param_train_simple():
+def test_torch_save_load_ssd_flat_param_on_disk():
+    _init()
+    orig_file = tempfile.NamedTemporaryFile(prefix="tensor")
+    checkpoint_file = tempfile.NamedTemporaryFile(prefix="checkpoint", suffix=".pt")
+    checkpoint_load_directory = tempfile.TemporaryDirectory(prefix="checkpoint_dir")
+
+    # TENSOR_SHAPE = (1024, 1024, 2048)
+    # use smaller shape for unit tests
+    TENSOR_SHAPE = (1024, 321)
+
+    ref_tensors = [torch.rand(TENSOR_SHAPE, dtype=torch.float32) for i in range(4)]
+    ssd_handle = so.SsdFlatParameter.from_tensors(ref_tensors, False)
+    ssd_handle.set_file_params(orig_file.name, 0)
+    ssd_handle.to_file()
+    ref_tensors = []
+
+    # after deleting ref_tensors, memory usage should be very low
+    # For save it shouldn't be more than 10x so.DEFAULT_CHUNK_SIZE
+    with so.CheckpointPathContextManager(override_path=checkpoint_load_directory.name):
+        so.torch_saver.save(ssd_handle, checkpoint_file.name)
+        # below line saves file to checkpoint_load_directory/orig_file.name
+        # Memory usage here should be O(1000 * so.DEFAULT_CHUNK_SIZE)
+        # 1000x because that's how many elements the python unpickler
+        # will buffer before passing to the SsdTensor
+        test_ssd_handle = torch.load(checkpoint_file)
+
+    head, tail = os.path.split(orig_file.name)
+    assert filecmp.cmp(orig_file.name, os.path.join(checkpoint_load_directory.name, tail), shallow=False)
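Two details in this test are worth calling out. First, so.torch_saver.save is used rather than plain torch.save, presumably so tensor data can be fed to the pickler in DEFAULT_CHUNK_SIZE-sized pieces instead of being materialized whole (hence the roughly 10x DEFAULT_CHUNK_SIZE bound claimed for saving). Second, torch.load runs inside CheckpointPathContextManager, so the restored backing file is written under checkpoint_load_directory, where the filecmp.cmp assert can compare it byte-for-byte against the original.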
+
+
+def test_torch_save_load_ssd_flat_param_on_mem():
+    _init()
+    orig_file = tempfile.NamedTemporaryFile(prefix="tensor")
+    checkpoint_file = tempfile.NamedTemporaryFile(prefix="checkpoint", suffix=".pt")
+    checkpoint_load_directory = tempfile.TemporaryDirectory(prefix="checkpoint_dir")
+
+    # TENSOR_SHAPE = (1024, 1024, 2048)
+    # use smaller shape for unit tests
+    TENSOR_SHAPE = (1024, 321)
+
+    ref_tensors = [torch.rand(TENSOR_SHAPE, dtype=torch.float32) for i in range(4)]
+    ssd_handle = so.SsdFlatParameter.from_tensors(ref_tensors, False)
+    ssd_handle.set_file_params(orig_file.name, 0)
+    ref_tensors = []
+
+    # after deleting ref_tensors, memory usage should be very low
+    # For save it shouldn't be more than 10x so.DEFAULT_CHUNK_SIZE
+    with so.CheckpointPathContextManager(override_path=checkpoint_load_directory.name):
+        so.torch_saver.save(ssd_handle, checkpoint_file.name)
+        # below line saves file to checkpoint_load_directory/orig_file.name
+        # Memory usage here should be O(1000 * so.DEFAULT_CHUNK_SIZE)
+        # 1000x because that's how many elements the python unpickler
+        # will buffer before passing to the SsdTensor
+        test_ssd_handle = torch.load(checkpoint_file)
+
+    assert torch.equal(ssd_handle, test_ssd_handle)
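The in-memory variant differs from the on-disk one in a single step: to_file() is never called, so the handle's data is still resident in RAM at save time. The final check is accordingly a value comparison (torch.equal) against the reloaded handle rather than a byte comparison of backing files; torch.equal works directly because an SsdTensorHandle behaves like an ordinary tensor.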
+
+
+def test_ssd_param_train_simple():
     _init()
     with tempfile.NamedTemporaryFile() as f:
         orig_tensor = torch.randn((4, 4))
@@ -117,15 +173,18 @@ def test_ssd_flat_param_train_simple():
         with torch.no_grad():
             orig_copy = torch.empty_like(orig_tensor)
             orig_copy.copy_(orig_tensor)
         param = torch.nn.Parameter(orig_copy)
-        ssd_flat_param = so.SsdFlatParameter([param], f.name, True)
+        ssd_param = so.SsdParameter(orig_tensor.shape, orig_tensor.dtype)
+        ssd_param.point_to_tensor(orig_copy)
+        ssd_param.set_file_params(f.name, 0)
+        ssd_param.to_file(release_tensor_after_write=True)

-        assert torch.equal(list(ssd_flat_param.get_param_views())[0], orig_tensor)
+        assert torch.equal(ssd_param.to_tensor(), orig_tensor)

-        optimizer_ssd = torch.optim.SGD([ssd_flat_param], lr=0.1)
+        optimizer_ssd = torch.optim.SGD([ssd_param], lr=0.1)
         optimizer_orig = torch.optim.SGD([param], lr=0.1)

-        y1 = ssd_flat_param + 1
+        y1 = ssd_param + 1
         optimizer_ssd.zero_grad()
         y1.sum().backward()
         optimizer_ssd.step()
@@ -136,8 +195,8 @@ def test_ssd_flat_param_train_simple():
         optimizer_orig.step()

         # make sure we are using the file version not the cached tensor
-        ssd_flat_param.point_to_file(f.name, 0)
-        assert torch.equal(list(ssd_flat_param.get_param_views())[0], param)
+        ssd_param.point_to_file(f.name, 0)
+        assert torch.equal(ssd_param.to_tensor(), param)


 def test_ssd_flat_parameter_basic():
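The renamed test above walks a single-tensor SsdParameter through its whole lifecycle; distilled into a standalone sketch (the 4x4 tensor is illustrative, calls are the ones the test uses):

    import tempfile

    import torch

    import fairscale.experimental.nn.ssd_offload as so

    with tempfile.NamedTemporaryFile() as f:
        t = torch.randn((4, 4))
        p = so.SsdParameter(t.shape, t.dtype)
        p.point_to_tensor(t.clone())  # back the handle with an in-memory tensor
        p.set_file_params(f.name, 0)  # bind it to file f.name at offset 0
        p.to_file(release_tensor_after_write=True)  # flush to SSD, drop the RAM copy
        assert torch.equal(p.to_tensor(), t)  # reads back transparently from the file
        p.point_to_file(f.name, 0)  # force later reads to use the file version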
@@ -146,7 +205,8 @@ def test_ssd_flat_parameter_basic():
         refa_param = torch.nn.Parameter(torch.rand((32, 4), dtype=torch.float32))
         refb_param = torch.nn.Parameter(torch.rand((32, 4), dtype=torch.float32))
         refc_param = torch.nn.Parameter(torch.rand((128), dtype=torch.float32))
-        ssd_flat_param = so.SsdFlatParameter([refa_param, refb_param, refc_param], f.name, False)
+        ssd_flat_param = so.SsdFlatParameter.from_tensors([refa_param, refb_param, refc_param], False)
+        ssd_flat_param.set_file_params(f.name, 0)
         param_views = list(ssd_flat_param.get_param_views())
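from_tensors flattens the three reference parameters into one contiguous buffer, and get_param_views() re-exposes them with their original shapes. A self-contained sketch of that property, mirroring the test's setup:

    import tempfile

    import torch

    import fairscale.experimental.nn.ssd_offload as so

    params = [
        torch.nn.Parameter(torch.rand((32, 4))),
        torch.nn.Parameter(torch.rand((32, 4))),
        torch.nn.Parameter(torch.rand((128,))),
    ]
    with tempfile.NamedTemporaryFile() as f:
        # 32*4 + 32*4 + 128 = 384 elements flattened into one buffer,
        # re-exposed as views with the original shapes:
        flat = so.SsdFlatParameter.from_tensors(params, False)
        flat.set_file_params(f.name, 0)
        views = list(flat.get_param_views())
        assert [v.shape for v in views] == [p.shape for p in params]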
@@ -16,6 +16,7 @@ import torch
 from torch import nn
 import torch.distributed

+import fairscale.experimental.nn.ssd_offload as so
 from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper
 from fairscale.nn.data_parallel import FullyShardedDataParallel, OffloadConfig, TrainingState
 from fairscale.utils import torch_version
@@ -289,19 +290,53 @@ class TestSsdLoading(DistributedTest):
         model = FullyShardedDataParallel(model, **config)
         model_device = torch.device("cuda")
         model.train()
-        optim = torch.optim.SGD(model.parameters(), lr=4, momentum=0.9)
+        optim = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
+        checkpoint_file = tempfile.NamedTemporaryFile()
+        checkpoint_load_directory = tempfile.TemporaryDirectory(prefix="checkpoint_dir")
+
+        pre_checkpoint_last_output = None
+        post_checkpoint_last_output = None
+        ITERATIONS = 10
+
         # Inputs always cuda regardless of move_grads_cpu, or model.device
         with torch.cuda.amp.autocast(enabled=config.get("mixed_precision", False)):
-            for i in range(10):
+            for i in range(ITERATIONS):
+                optim.zero_grad()
+                input = model.get_input(torch.device("cuda"))
+                output = model(*input)
+                pre_checkpoint_last_output = output
+                loss = model.module.get_loss(input, output).to(model_device)
+                assert loss.dtype == torch.float32
+                model.module.run_backward(loss)
+                optim.step()
+                if i == 0:
+                    with so.CheckpointPathContextManager(override_path=checkpoint_load_directory.name):
+                        # so.torch_saver.save({"model": model.state_dict(), "optim": optim.state_dict()}, checkpoint_file.name)
+                        torch.save({"model": model.state_dict()}, checkpoint_file.name)
+                    # reset momentum just after checkpoint save
+                    optim = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
+
+            checkpoint = torch.load(checkpoint_file.name)
+            model.load_state_dict(checkpoint["model"])
+            # reset momentum just after checkpoint load
+            optim = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
+
+            # do more iterations after loading checkpoint
+            for i in range(ITERATIONS - 1):
                 optim.zero_grad()
                 input = model.get_input(torch.device("cuda"))
                 output = model(*input)
+                post_checkpoint_last_output = output
                 loss = model.module.get_loss(input, output).to(model_device)
                 assert loss.dtype == torch.float32
                 model.module.run_backward(loss)
                 optim.step()
+
+            # Verify output of checkpoint load + run is equal to original output
+            assert torch.equal(pre_checkpoint_last_output, post_checkpoint_last_output)

         if isinstance(model, FullyShardedDataParallel):
             model.assert_state(TrainingState.IDLE)
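Note the choreography around the optimizer: momentum is reset immediately after the checkpoint is saved and again immediately after it is loaded, so the remainder of the first run and the post-restore run both continue from the state after step 0 with empty momentum buffers. After ITERATIONS - 1 further steps each, the two paths must therefore produce exactly equal outputs, which is what the final torch.equal assert checks (the lr change from 4 to 0.01 presumably just keeps the runs numerically tame).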
@@ -445,6 +480,12 @@ def spawn_and_init(fn, args=None, **spawn_kwargs):
         args = ()

     run_fn = functools.partial(init_and_run, fn, args)
+
+    # Below 3 lines are to easily enable single-process debugging
+    # _, filename = tempfile.mkstemp()
+    # _, filename_rpc = tempfile.mkstemp()
+    # run_fn(0, 1, filename, filename_rpc)
+
     spawn_for_all_world_sizes(run_fn, **spawn_kwargs)