"git@developer.sourcefind.cn:one/TransferBench.git" did not exist on "8b2cd85d204c5fbe6b20f1fa6f6e46ff4898331d"
Unverified commit ce5860ea authored by Min Xu, committed by GitHub

[feat] AdaScale: Gradient Accumulation and Add PyTest unit tests (#202)

* added AdaScale to README

* [adascale] added gradient accumulation

- added gradient accumulation
- tested with full CIFAR trainings with different values of accumulation
and verified that full accuracy is obtained
- also removed the patch_optimizer flag until we need it

* [adascale] adding pytest

- added basic, DDP, and grad_accum tests
- closes #195

* added changelog

* added ddp grad_accum test

* moved ddp and non-ddp tests into separate files

* added checkpoint test

* more doc

* addressed Mike's comments
parent 867cc2df
@@ -5,6 +5,13 @@
 The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
 and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
 
+## [next rel] - TBD
+### Added
+- AdaScale: Added gradient accumulation feature
+
+### Fixed
+- tbd
+
 ## [0.1.1] - 2020-12-01
 ### Fixed
 - make sure pip package includes header files (#221)
@@ -12,6 +12,7 @@ fairscale supports:
   * tensor parallelism (fairscale.nn.model_parallel)
 * Optimization:
   * optimizer state sharding (fairscale.optim.oss)
+  * AdaScale SGD (from fairscale.optim import AdaScale)
 
 ## Requirements
@@ -103,7 +104,18 @@ if __name__ == "__main__":
     )
 ```
 
+### AdaScale SGD
+
+AdaScale can be used to wrap an SGD optimizer and is usable in DDP (Distributed Data Parallel)
+training or non-DDP training with gradient accumulation. The benefit is to re-use the same LR
+schedule from a baseline batch size when the effective batch size is bigger.
+The primary goal is to allow scaling to bigger batch sizes without losing model accuracy.
+
+At a high level, we want ML researchers to:
+  * go parallel more easily (i.e. reuse the same LR schedule)
+  * not worry about losing accuracy
+  * get the same (or higher) GPU efficiency (fewer steps, less networking, etc.)
+
 # Testing
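The README section above describes the usage only in prose. The following is a minimal, editor-added sketch of the non-DDP gradient-accumulation case; the toy model, data, and loop are illustrative (they mirror the unit tests later in this commit) and are not text from the README itself:

```python
import torch
from torch.nn import Linear
from torch.optim import SGD

from fairscale.optim import AdaScale

# Toy model and optimizer; AdaScale wraps SGD and is told how many
# backward passes are accumulated locally before each step.
model = Linear(2, 2, bias=False)
optim = AdaScale(SGD(model.parameters(), lr=0.1), num_gradients_to_accumulate=2)

for _ in range(10):
    for _ in range(2):  # must match num_gradients_to_accumulate
        in_data = torch.rand(2)
        model(in_data).sum().backward()
    # AdaScale multiplies each param group's LR by its current gain estimate,
    # calls the wrapped SGD step, then restores the original LR.
    optim.step()
    optim.zero_grad()
```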
@@ -35,8 +35,9 @@
 import functools
 from typing import Any, Dict, Optional
 
 import numpy as np
+import torch
 from torch.autograd import Variable
-import torch.distributed
+import torch.distributed as dist
 
@@ -49,7 +50,7 @@ class AdaScale(object):
     .. code-block:: python
 
-        optim = torch.optim.SGD(model, lr=0.001)
+        optim = torch.optim.SGD(model.parameters(), lr=0.001)
         model = DistributedDataParallel(model)
         adascale = AdaScale(optim)
@@ -65,14 +66,16 @@ class AdaScale(object):
             Optimizer to apply AdaScale to.
         world_size (int):
             Number of world_size for distributed training. If
-            None, defaults to ``torch.distributed.get_world_size()``.
+            None, defaults to ``dist.get_world_size()``.
         scale (float):
             Scaling factor of the batch size, e.g. using a 10x
             larger batch size (summed across all world_size) means a scale of
             10. If None, defaults to ``world_size``.
-        patch_optimizer (bool):
-            If True, monkey-patches the ``step`` method of
-            the optimizer with the AdaScale's ``step`` method.
+        smoothing (float):
+            Smoothing factor between batches. Default value: 0.999.
+        num_gradients_to_accumulate (int):
+            Number of passes that we accumulate gradients locally.
+            Defaults to 1, which does not accumulate gradients.
     """
 
     def __init__(
@@ -81,16 +84,22 @@ class AdaScale(object):
         world_size: Optional[int] = None,
         scale: Optional[float] = None,
         smoothing: float = 0.999,
-        patch_optimizer: bool = False,
+        num_gradients_to_accumulate: int = 1,
     ):
         self._optimizer = optimizer
-        self._optimizer_step = optimizer.step
         self._local_grad_sqr: Optional[torch.Tensor] = None
-        self._world_size: int = (world_size if world_size is not None else torch.distributed.get_world_size())
+        self._world_size: int = (
+            world_size if world_size is not None else dist.get_world_size() if dist.is_initialized() else 1
+        )
+        self._smoothing = smoothing
+        self._num_backward_calls = 0
+        self._num_grads_to_accum = num_gradients_to_accumulate
 
-        if self._world_size <= 1:
-            raise RuntimeError("AdaScale does not support a single worker.")
+        if self._world_size * self._num_grads_to_accum <= 1:
+            # gain will be NaN since we will be dividing by zero in paper's B.3 where (S-1) == 0.
+            raise RuntimeError("AdaScale does not support a single worker without grad accumulation.")
 
+        # Per-param-group sqr & var states (sigma^2 & mu^2 in the paper).
         self._optimizer.state.setdefault(
             "adascale",
             {
@@ -99,19 +108,19 @@ class AdaScale(object):
             },
         )
-        self.set_scale(self._world_size if scale is None else scale)
+        self.set_scale(self._world_size * self._num_grads_to_accum if scale is None else scale)
 
+        # Register the gradient hooks. Note, don't assume every param will generate
+        # a gradient (i.e. triggering the hook) in every backward pass.
         for idx, param_group in enumerate(self._optimizer.param_groups):
             for param in param_group["params"]:
                 param.register_hook(functools.partial(self._backward_hook, idx))
 
-        if patch_optimizer:
-            self.patch_optimizer()
-
-        self._smoothing = smoothing
 
     @property
     def state(self) -> Dict[str, np.ndarray]:
+        """
+        Return the states of AdaScale.
+        """
         return self._optimizer.state["adascale"]
 
     @property
@@ -121,6 +130,10 @@ class AdaScale(object):
         batch size when training with a single worker. For example, if the
         baseline batch size is 32, but using a scaled-up batch size of 80,
         then the scaling factor is 2.5.
+
+        Returns:
+            (float):
+                The current scaling factor.
         """
         return self._scale
@@ -136,46 +149,64 @@ class AdaScale(object):
         """
         self._scale = scale
 
-    def grad_sqr_avg(self) -> float:
+    def grad_sqr_avg(self, pg_idx: Optional[int] = None) -> float:
         """
-        Current estimate of the squared l2-norm of the true gradient (sigma
-        squared in the AdaScale paper).
+        Current estimate of the squared l2-norm of the true gradient
+        (sigma squared in the AdaScale paper).
 
-        Returns
+        Args:
+            pg_idx (Optional[int]):
+                Optional index for a parameter group.
+
+        Returns:
             (float):
                 Estimate of squared l2-norm.
         """
-        return np.sum(self.state["grad_sqr_avg"])
+        if pg_idx is not None:
+            return self.state["grad_sqr_avg"][pg_idx]
+        else:
+            return np.sum(self.state["grad_sqr_avg"])
 
-    def grad_var_avg(self) -> float:
+    def grad_var_avg(self, pg_idx: Optional[int] = None) -> float:
         """
         Current estimate of the trace of the covariance of the true gradient
         (mu squared in the AdaScale paper).
 
-        Returns
+        Args:
+            pg_idx (Optional[int]):
+                Optional index for a parameter group.
+
+        Returns:
             (float):
                 Estimate of trace of the covariance.
         """
-        return np.sum(self.state["grad_var_avg"])
+        if pg_idx is not None:
+            return self.state["grad_var_avg"][pg_idx]
+        else:
+            return np.sum(self.state["grad_var_avg"])
 
-    def gain(self, scale: Optional[float] = None) -> float:
+    def gain(self, scale: Optional[float] = None, pg_idx: Optional[int] = None) -> float:
         """
-        Current estimate of the AdaScale gain ratio (r_t).
+        Current estimate of the AdaScale gain ratio (r_t in the paper).
 
         Args:
             scale (float):
-                The batch size scale to estimate the gain ratio for.
+                Optional batch size scale to estimate the gain ratio for.
+            pg_idx (int):
+                Optional index of a parameter group.
 
-        Returns
-            :(float):
+        Returns:
+            (float):
                 Estimate of gain ratio.
         """
         scale = self._scale if scale is None else scale
-        var = self.grad_var_avg()
-        sqr = self.grad_sqr_avg()
+        var = self.grad_var_avg(pg_idx)
+        sqr = self.grad_sqr_avg(pg_idx)
         return (var + sqr) / (var / scale + sqr)
 
-    def _update_avg(self, name: str, value: float, factor: float) -> None:
+    def _update_avg(self, name: str, value: torch.Tensor, factor: float) -> None:
+        # This function computes and stores the moving average of a vector
+        # using a smoothing factor.
         biased = self.state.get(name + "_biased", 0.0)
         unbias = self.state.get(name + "_unbias", 0.0)
         biased = factor * biased + (1.0 - factor) * value
@@ -184,12 +215,18 @@ class AdaScale(object):
         self.state[name + "_unbias"] = unbias
         self.state[name] = biased / unbias
 
-    def _backward_hook(self, idx: int, grad: torch.Tensor) -> None:
+    def _backward_hook(self, pg_idx: int, grad: torch.Tensor) -> None:
         # This method should be invoked once for each parameter during the
         # backward pass, before gradients are synchronized between world_size.
+
+        # Store the local gradient square sums in a vector.
         if self._local_grad_sqr is None:
             self._local_grad_sqr = torch.zeros(len(self._optimizer.param_groups), device=grad.device)
-        self._local_grad_sqr[idx] += grad.pow(2).sum()
+        self._local_grad_sqr[pg_idx] += grad.pow(2).sum()
 
+        # Now, ensure we queue a callback at the end of the callback queue.
+        # This will fire after all gradient callbacks are done (esp. those
+        # queued by DDP).
         self._final_callback_queued = False
         Variable._execution_engine.queue_callback(self._queue_callback)
@@ -208,26 +245,59 @@ class AdaScale(object):
     def _final_callback(self) -> None:
         # This method should be invoked once for each backward pass, after
-        # gradients have been synchronized between each worker.
+        # gradients have been synchronized between each worker, unless we
+        # are in gradient accumulation mode, where grads are not all_reduced
+        # between the GPUs.
         self._final_callback_queued = False
         assert isinstance(self._local_grad_sqr, torch.Tensor)
 
-        # self._local_grad_sqr is FP32, sum then div shouldn't overflow.
-        torch.distributed.all_reduce(self._local_grad_sqr)  # SUM
-        self._local_grad_sqr.div_(self._world_size)
-        local_grad_sqr = self._local_grad_sqr.cpu().numpy()
+        # Keep track of number of backward calls for gradient accumulation.
+        self._num_backward_calls += 1
+
+        # TODO (min, mike): We need to have a way to check that training loop & DDP
+        #                   is doing the right thing where the gradient is reduced
+        #                   in this backward pass.
+        #                   Longer term, we may compute the gain and then inform
+        #                   the training loop when it is a good time to step().
+        if self._num_backward_calls % self._num_grads_to_accum != 0:
+            return
+
+        # Since self._local_grad_sqr is FP32, sum shouldn't overflow.
+        # This vector has length of # of param_groups, so it is small, but we
+        # use async to hide the all_reduce latency, esp when # of nodes is large.
+        work = None
+        if self._world_size > 1:
+            work = dist.all_reduce(self._local_grad_sqr, async_op=True)  # SUM
 
+        # Compute the sums of squares for reduced gradients.
+        # Divide by _num_grads_to_accum since the gradients are accumulated.
+        #
+        # Note: we are mutating the gradients here!!!
         total_grad_sqr = np.array(
-            [sum(param.grad.pow(2).sum().item() for param in group["params"]) for group in self._optimizer.param_groups]
+            [
+                sum(param.grad.div_(self._num_grads_to_accum).pow(2).sum().item() for param in group["params"])
+                for group in self._optimizer.param_groups
+            ]
         )
 
-        grad_sqr = (self._world_size * total_grad_sqr - local_grad_sqr) / (self._world_size - 1)
-        grad_var = (local_grad_sqr - total_grad_sqr) * self._scale / (self._world_size - 1)
-        grad_sqr = np.maximum(grad_sqr, 0.0)
+        # Wait for all_reduce to be done and move it to cpu & np.
+        if work:
+            work.wait()
+        local_grad_sqr = self._local_grad_sqr.cpu().numpy()
+        self._local_grad_sqr = None
+
+        # See appendix B.3 of the paper.
+        #
+        # local_grad_sqr is \sum_{i=1}^{S} \norm{g_t_i}^2
+        # total_grad_sqr is \norm{g_t}^2
+        S = self._scale
+        grad_var = local_grad_sqr / (S - 1) - total_grad_sqr * S / (S - 1)
+        grad_sqr = total_grad_sqr - grad_var / S
         grad_var = np.maximum(grad_var, 1e-6)
-        theta = self._smoothing ** self._scale
+        grad_sqr = np.maximum(grad_sqr, 0.0)
+        theta = self._smoothing ** S
         self._update_avg("grad_sqr_avg", grad_sqr, theta)
         self._update_avg("grad_var_avg", grad_var, theta)
-        self._local_grad_sqr = None
 
     def step(self, *args: Any, **kwargs: Any) -> Optional[float]:
         """
@@ -235,36 +305,38 @@ class AdaScale(object):
         ``optimizer.step(*args, **kwargs)`` with a scaled learning rate.
 
         Args:
-            args:
+            args (Any):
                 Positional arguments passed to ``optimizer.step``.
-            kwargs:
+            kwargs (Any):
                 Keyword arguments passed to ``optimizer.step``.
 
         Returns:
             (Tensor):
-                loss if a closure is passed to the optimizer to reevaluate the model.
+                The loss tensor if a closure is used to re-evaluate the model.
         """
-        initial_lr = [pg["lr"] for pg in self._optimizer.param_groups]
+        # Save the original LR and set the new LR.
+        original_lr = []
         for idx, param_group in enumerate(self._optimizer.param_groups):
-            grad_sqr = float(self.state["grad_sqr_avg"][idx])
-            grad_var = float(self.state["grad_var_avg"][idx])
-            gain = (grad_var + grad_sqr) / (grad_var / self._scale + grad_sqr)
-            param_group["lr"] = gain * param_group["lr"]
-        res = self._optimizer_step(*args, **kwargs)
-        for lr, param_group in zip(initial_lr, self._optimizer.param_groups):
-            param_group["lr"] = lr
-        return res
+            original_lr.append(param_group["lr"])
+            param_group["lr"] = self.gain(pg_idx=idx) * param_group["lr"]
 
-    def patch_optimizer(self) -> None:
-        """
-        Monkey-patch the optimizer's step function with :meth:`AdaScale.step`.
-        """
+        # Step it.
+        res = self._optimizer.step(*args, **kwargs)
 
-        @functools.wraps(self._optimizer.step)
-        def wrapper(*args: Any, **kwargs: Any) -> Optional[float]:
-            return self.step(*args, **kwargs)
-
-        setattr(self._optimizer, "step", wrapper)
+        # Restore the original LR.
+        for lr, param_group in zip(original_lr, self._optimizer.param_groups):
+            param_group["lr"] = lr
+        return res
 
     def zero_grad(self) -> None:
-        """Proxy function to optimizer"""
-        self._optimizer.zero_grad()
+        """Proxy function to optimizer, because some training loops need this."""
+        return self._optimizer.zero_grad()
+
+    def state_dict(self) -> Dict:
+        """Proxy function to optimizer, checkpointing needs this."""
+        return self._optimizer.state_dict()
+
+    def load_state_dict(self, data: Dict) -> None:
+        """Proxy function to optimizer, checkpointing needs this."""
+        return self._optimizer.load_state_dict(data)
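As an editor's sanity check (not part of the commit), here is the appendix-B.3 arithmetic from `_final_callback` and the `gain()` ratio evaluated on made-up numbers, with scale S = 4 (for example world_size 2 with num_gradients_to_accumulate 2):

```python
import numpy as np

S = 4.0                # scale = world_size * num_gradients_to_accumulate
local_grad_sqr = 4.2   # hypothetical sum_i ||g_i||^2 over the S local gradients
total_grad_sqr = 0.9   # hypothetical ||g||^2 of the averaged gradient

# Same formulas as in _final_callback above.
grad_var = local_grad_sqr / (S - 1) - total_grad_sqr * S / (S - 1)  # 0.2
grad_sqr = total_grad_sqr - grad_var / S                            # 0.85
grad_var = np.maximum(grad_var, 1e-6)
grad_sqr = np.maximum(grad_sqr, 0.0)

# Same formula as in gain(): r_t = (var + sqr) / (var / S + sqr).
gain = (grad_var + grad_sqr) / (grad_var / S + grad_sqr)
print(gain)  # ~1.17: the LR is scaled by about 1.17x rather than the full 4x
```

When the S per-pass gradients are mutually orthogonal with equal norms, as in the unit tests below, these estimators give grad_sqr = 0 and the gain equals S exactly, which is what the tests assert.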
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

# pylint: disable=missing-module-docstring
# pylint: disable=missing-class-docstring
# pylint: disable=missing-function-docstring

""" Test AdaScale with DDP. """

import tempfile

import numpy as np
import pytest
import torch
from torch import Tensor
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn import Linear
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import SGD

from fairscale.optim import AdaScale

skip_if_single_gpu = pytest.mark.skipif(torch.cuda.device_count() < 2, reason="multiple GPUs are required")


def _dist_init(rank, world_size, tempfile_name, backend):
    url = "file://" + tempfile_name
    dist.init_process_group(init_method=url, backend=backend, rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)


def _test_basic_func(rank, world_size, tempfile_name):
    _dist_init(rank, world_size, tempfile_name, backend="nccl")  # Covers nccl

    model = Linear(2, 2, bias=False)
    model.to("cuda")
    model = DDP(model, device_ids=[rank])
    optim = AdaScale(SGD(model.parameters(), lr=0.1))

    # iter 1
    in_data = Tensor([0.0, 0.0])
    in_data[rank] = 1.0
    in_data = in_data.cuda()
    out = model(in_data)
    out.sum().backward()
    assert np.allclose(optim.gain(), 2.0), optim.gain()

    optim.step()
    optim.zero_grad()

    dist.destroy_process_group()


@skip_if_single_gpu
def test_basic():
    """Test adascale with DDP without gradient accumulation"""
    world_size = 2
    temp_file_name = tempfile.mkstemp()[1]

    mp.spawn(_test_basic_func, args=(world_size, temp_file_name), nprocs=world_size, join=True)


def _test_grad_accum_func(rank, world_size, tempfile_name):
    _dist_init(rank, world_size, tempfile_name, backend="gloo")  # Covers gloo

    model = Linear(4, 2, bias=False)
    model.to("cuda")
    model = DDP(model, device_ids=[rank])
    optim = AdaScale(SGD(model.parameters(), lr=0.1), num_gradients_to_accumulate=2)

    with model.no_sync():
        # iter 1, input vectors are pointing dim0 and dim1
        in_data = Tensor([0.0] * 4)
        in_data[rank] = 1.0
        in_data = in_data.cuda()
        out = model(in_data)
        out.sum().backward()
    # iter 2, input vectors are pointing dim2 and dim3
    in_data = Tensor([0.0] * 4)
    in_data[rank + 2] = 1.0
    in_data = in_data.cuda()
    out = model(in_data)
    out.sum().backward()

    # since all inputs are orthogonal, the gain should be exactly 4.0.
    assert np.allclose(optim.gain(), 4.0), optim.gain()

    optim.step()
    optim.zero_grad()

    dist.destroy_process_group()


@skip_if_single_gpu
def test_grad_accum():
    """Test adascale with DDP + gradient accumulation using ddp.no_sync()"""
    world_size = 2
    temp_file_name = tempfile.mkstemp()[1]

    mp.spawn(_test_grad_accum_func, args=(world_size, temp_file_name), nprocs=world_size, join=True)
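The DDP tests above drive AdaScale through `mp.spawn` with a file-based init method. In a real script launched with one process per GPU, the same wiring looks roughly like the editor's sketch below; the environment setup, loop, and data are illustrative assumptions, not code from this commit:

```python
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn import Linear
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import SGD

from fairscale.optim import AdaScale


def train(rank, world_size):
    # Assumed rendezvous settings for a single-node run; a real launcher
    # (e.g. torch.distributed.launch) would normally provide these.
    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)

    model = DDP(Linear(2, 2, bias=False).cuda(), device_ids=[rank])
    optim = AdaScale(SGD(model.parameters(), lr=0.1))

    for _ in range(10):
        in_data = torch.rand(2, device="cuda")
        model(in_data).sum().backward()  # AdaScale's hooks run during backward
        optim.step()                     # LR is scaled by the current gain
        optim.zero_grad()

    dist.destroy_process_group()


if __name__ == "__main__":
    world_size = 2  # assumes at least 2 GPUs
    mp.spawn(train, args=(world_size,), nprocs=world_size, join=True)
```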
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

# pylint: disable=missing-module-docstring
# pylint: disable=missing-class-docstring
# pylint: disable=missing-function-docstring

""" Test AdaScale with a single node (1 CPU or 1 GPU). """

import tempfile

import numpy as np
import pytest
import torch
from torch import Tensor
from torch.nn import Linear
from torch.optim import SGD

from fairscale.optim import AdaScale

skip_if_no_gpu = pytest.mark.skipif(torch.cuda.device_count() < 1, reason="1 GPU is required")


def test_basic_cpu():
    """Test single batch behavior on CPU"""
    model = Linear(2, 2, bias=False)
    try:
        optim = AdaScale(SGD(model.parameters(), lr=0.1))
    except RuntimeError:
        return
    assert False, "Single batch AdaScale should not be supported"


def test_loss_accum_cpu():
    """Test the loss accumulation behavior on CPU

    Loss accumulation is NOT SUPPORTED. This test shows that it does not work.
    """
    model = Linear(2, 2, bias=False)
    # num_gradients_to_accumulate value doesn't matter in this negative test.
    optim = AdaScale(SGD(model.parameters(), lr=0.1), num_gradients_to_accumulate=123)
    # data 1
    in_data = Tensor([0.0, 1.0])
    loss = model(in_data).sum()
    # data 2
    in_data = Tensor([1.0, 0.0])
    loss += model(in_data).sum()
    # data 3
    in_data = Tensor([1.0, 2.0])
    loss += model(in_data).sum()
    # backward, but gradient is only produced once by the autograd engine.
    loss.backward()
    # therefore, the gain will always be 1, which renders adascale a noop.
    optim.step()
    assert np.allclose(optim.gain(), 1.0), optim.gain()


def test_grad_accum_cpu(cpu=True):
    """Test the basic functionality on CPU with gradient accumulation without DDP"""
    model = Linear(2, 2, bias=False)
    if not cpu:
        model = model.cuda()
    optim = AdaScale(SGD(model.parameters(), lr=0.1), num_gradients_to_accumulate=2)
    for expected_gain in [2.0, 2.0]:  # test 2 iterations to catch more corner cases.
        # grad pass 1
        in_data = Tensor([0.0, 1.0])
        if not cpu:
            in_data = in_data.cuda()
        out = model(in_data)
        out.sum().backward()
        # grad pass 2
        in_data = Tensor([1.0, 0.0])
        if not cpu:
            in_data = in_data.cuda()
        out = model(in_data)
        out.sum().backward()
        # stepping it. Note that if we did more than 2 passes as promised by the
        # num_gradients_to_accumulate argument above, AdaScale is not able to
        # detect that mistake for now. The result will just be wrong in that case.
        assert np.allclose(optim.gain(), expected_gain), optim.gain()
        optim.step()
        optim.zero_grad()


@skip_if_no_gpu
def test_grad_accum_gpu():
    """Test the basic functionality on GPU with gradient accumulation without DDP"""
    test_grad_accum_cpu(cpu=False)


@skip_if_no_gpu
def test_state_checkpointing():
    """ Test state checkpointing on GPU since that's the common case.

    AdaScale doesn't have distributed state. Otherwise, it will need
    a unit test for checkpointing with DDP.
    """
    # Constants.
    accum_steps = 3
    in_dim = 5

    # Setup.
    def make_model_and_optim():
        model = Linear(in_dim, 2, bias=False)
        model = model.cuda()
        optim = AdaScale(SGD(model.parameters(), lr=0.1, momentum=0.9), num_gradients_to_accumulate=accum_steps)
        return model, optim

    model, optim = make_model_and_optim()

    # Run a bit.
    def run_a_bit(replay_data=None):
        print("running")
        data = []
        replay_data_idx = 0
        for _ in range(6):  # run some steps
            for i in range(accum_steps):
                if replay_data is None:
                    in_data = torch.rand(in_dim).cuda()
                    data.append(in_data)
                else:
                    in_data = replay_data[replay_data_idx]
                    replay_data_idx += 1
                out = model(in_data)
                out.sum().backward()
                # print(out.sum().item())
                print(model.weight.grad)
                if i == accum_steps - 1:
                    optim.step()
                    optim.zero_grad()
        return out, data

    run_a_bit()

    with tempfile.NamedTemporaryFile() as f:
        temp_file_name = f.name

        # Save a checkpoint.
        torch.save({"model": model.state_dict(), "optim": optim.state_dict()}, temp_file_name)

        # Train more.
        out, replay_data = run_a_bit()

        # Save the gain and out.
        expected_out = out.sum().item()
        expected_gain = optim.gain()

        # Load back the checkpoint.
        model, optim = make_model_and_optim()  # They both need to start afresh.
        ckpt = torch.load(temp_file_name)
        model.load_state_dict(ckpt["model"])
        optim.load_state_dict(ckpt["optim"])

        # Train the same steps.
        out, _ = run_a_bit(replay_data)

    # Assert the results.
    assert np.allclose(out.sum().item(), expected_out), out.sum().item()
    assert np.allclose(optim.gain(), expected_gain), optim.gain()