"vscode:/vscode.git/clone" did not exist on "6e01e8c1c8ea323d30e3f57050469b2df66b56c6"
Unverified commit 304f184c, authored by Yosuke Shinya and committed by GitHub

[Fix] Fix gradient cumulative optimizer when resuming (#2093)

* Fix gradient cumulative optimizer when resuming

* add unit test
parent f527e43c
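
In short: before this change, _init() derived divisible_iters and remainder_iters from residual_iters = runner.max_iters - runner.iter, i.e. from whatever iteration the hook happened to be initialized at, while after_train_iter() compared the absolute runner.iter against that residual-based threshold. On a fresh run the two coincide, but after resuming from a checkpoint the threshold is wrong: remaining losses can be divided by the wrong factor, and when the residual is an exact multiple of cumulative_iters the factor becomes zero. The patch derives both values from runner.max_iters alone, so the schedule depends only on the absolute iteration. A standalone sketch of the before/after arithmetic (not MMCV code; the constant and names just mirror the hook's attributes):

    # Loss-factor schedule before and after this commit, as plain
    # functions mirroring _init() and the per-iteration factor choice.
    CUMULATIVE_ITERS = 3

    def loss_factors(max_iters, resume_iter, fixed):
        if fixed:
            # After the fix: thresholds depend only on max_iters.
            divisible = max_iters // CUMULATIVE_ITERS * CUMULATIVE_ITERS
            remainder = max_iters - divisible
            threshold = max_iters - remainder
        else:
            # Before the fix: thresholds depend on the iters left when
            # _init() ran, i.e. on the resume point.
            residual = max_iters - resume_iter
            divisible = residual // CUMULATIVE_ITERS * CUMULATIVE_ITERS
            remainder = residual - divisible
            threshold = divisible  # compared against the *absolute* iter
        return [CUMULATIVE_ITERS if it < threshold else remainder
                for it in range(resume_iter, max_iters)]

    # Fresh 10-iter run: both versions agree.
    assert loss_factors(10, 0, fixed=False) == [3] * 9 + [1]
    assert loss_factors(10, 0, fixed=True) == [3] * 9 + [1]

    # Resume at iter 5: the old code divides every remaining loss by 2,
    # because the absolute iters 5..9 are already past the residual-based
    # threshold 3. The fix reproduces an uninterrupted run's schedule,
    # exactly what the new unit test asserts.
    assert loss_factors(10, 5, fixed=False) == [2] * 5
    assert loss_factors(10, 5, fixed=True) == [3] * 4 + [1]

    # Resume at iter 4: residual 6 is an exact multiple of 3, so the old
    # remainder is 0 and iters 6..9 would divide the loss by zero; the
    # new `assert loss_factor > 0` can never trip after the fix.
    assert loss_factors(10, 4, fixed=False) == [3, 3, 0, 0, 0, 0]
    assert loss_factors(10, 4, fixed=True) == [3] * 5 + [1]
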
@@ -147,24 +147,31 @@ class GradientCumulativeOptimizerHook(OptimizerHook):
                 'GradientCumulativeOptimizerHook may slightly decrease '
                 'performance if the model has BatchNorm layers.')
 
-        residual_iters = runner.max_iters - runner.iter
-
         self.divisible_iters = (
-            residual_iters // self.cumulative_iters * self.cumulative_iters)
-        self.remainder_iters = residual_iters - self.divisible_iters
+            runner.max_iters // self.cumulative_iters * self.cumulative_iters)
+        self.remainder_iters = runner.max_iters - self.divisible_iters
 
         self.initialized = True
 
+    def _get_loss_factor(self, runner):
+        """Get loss division factor for the current iteration."""
+        if runner.iter < runner.max_iters - self.remainder_iters:
+            loss_factor = self.cumulative_iters
+        else:
+            loss_factor = self.remainder_iters
+            runner.logger.warning(
+                f'Loss will be divided by {loss_factor} in the last '
+                f'{self.remainder_iters} iterations because they are not '
+                f'enough for {self.cumulative_iters} cumulative_iters.')
+
+        assert loss_factor > 0
+        return loss_factor
+
     def after_train_iter(self, runner):
         if not self.initialized:
             self._init(runner)
 
-        if runner.iter < self.divisible_iters:
-            loss_factor = self.cumulative_iters
-        else:
-            loss_factor = self.remainder_iters
-        loss = runner.outputs['loss']
-        loss = loss / loss_factor
+        loss = runner.outputs['loss'] / self._get_loss_factor(runner)
         loss.backward()
 
         if (self.every_n_iters(runner, self.cumulative_iters)
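
For intuition on why the loss is divided at all (background, not part of the patch): over a full window the hook backpropagates loss / cumulative_iters each iteration and steps the optimizer once, so the accumulated gradient equals the gradient of the mean loss over that window; the final partial window holds only remainder_iters losses, hence the smaller divisor from _get_loss_factor(). A minimal PyTorch check of that identity:

    import torch

    # Accumulating grads of loss/k over a window of k losses equals one
    # backward pass on the mean loss - the identity behind loss_factor.
    w = torch.zeros(1, requires_grad=True)
    targets = torch.tensor([1.0, 2.0, 3.0])

    for t in targets:  # simulate cumulative_iters = 3
        loss = (w - t).pow(2).sum()
        (loss / 3).backward()
    accumulated = w.grad.clone()

    w.grad = None
    mean_loss = torch.stack([(w - t).pow(2).sum() for t in targets]).mean()
    mean_loss.backward()
    assert torch.allclose(accumulated, w.grad)
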
@@ -310,13 +317,7 @@ if (TORCH_VERSION != 'parrots'
             if not self.initialized:
                 self._init(runner)
 
-            if runner.iter < self.divisible_iters:
-                loss_factor = self.cumulative_iters
-            else:
-                loss_factor = self.remainder_iters
-            loss = runner.outputs['loss']
-            loss = loss / loss_factor
-
+            loss = runner.outputs['loss'] / self._get_loss_factor(runner)
             self.loss_scaler.scale(loss).backward()
 
             if (self.every_n_iters(runner, self.cumulative_iters)
@@ -504,15 +505,7 @@ else:
             if not self.initialized:
                 self._init(runner)
 
-            if runner.iter < self.divisible_iters:
-                loss_factor = self.cumulative_iters
-            else:
-                loss_factor = self.remainder_iters
-            loss = runner.outputs['loss']
-            loss = loss / loss_factor
-
+            loss = runner.outputs['loss'] / self._get_loss_factor(runner)
             # scale the loss value
             scaled_loss = loss * self.loss_scaler.loss_scale
             scaled_loss.backward()
@@ -1927,6 +1927,48 @@ def test_gradient_cumulative_optimizer_hook():
         grad_clip=dict(max_norm=0.2), cumulative_iters=3)
     assert optimizer_hook.has_batch_norm(model)
 
+    def calc_loss_factors(runner):
+        optimizer_hook = GradientCumulativeOptimizerHook(
+            grad_clip=dict(max_norm=0.2), cumulative_iters=3)
+        optimizer_hook._init(runner)
+        loss_factors = []
+        for current_iter in range(runner._iter, runner._max_iters):
+            runner._iter = current_iter
+            loss_factor = optimizer_hook._get_loss_factor(runner)
+            loss_factors.append(loss_factor)
+        shutil.rmtree(runner.work_dir)
+        return loss_factors
+
+    # test loss_factor with EpochBasedRunner
+    runner = build_toy_runner(dict(type='EpochBasedRunner', max_epochs=2))
+    runner._max_iters = 6  # max_epochs * len(data_loader)
+    assert calc_loss_factors(runner) == [3] * 6
+
+    runner = build_toy_runner(dict(type='EpochBasedRunner', max_epochs=2))
+    runner._max_iters = 8  # max_epochs * len(data_loader)
+    assert calc_loss_factors(runner) == [3] * 6 + [2, 2]
+
+    runner = build_toy_runner(dict(type='EpochBasedRunner', max_epochs=2))
+    runner._max_iters = 10  # max_epochs * len(data_loader)
+    assert calc_loss_factors(runner) == [3] * 9 + [1]
+
+    runner = build_toy_runner(dict(type='EpochBasedRunner', max_epochs=2))
+    runner._max_iters = 10  # max_epochs * len(data_loader)
+    runner._iter = 5  # resume
+    assert calc_loss_factors(runner) == [3] * 4 + [1]
+
+    # test loss_factor with IterBasedRunner
+    runner = build_toy_runner(dict(type='IterBasedRunner', max_iters=6))
+    assert calc_loss_factors(runner) == [3] * 6
+
+    runner = build_toy_runner(dict(type='IterBasedRunner', max_iters=7))
+    assert calc_loss_factors(runner) == [3] * 6 + [1]
+
+    runner = build_toy_runner(dict(type='IterBasedRunner', max_iters=8))
+    assert calc_loss_factors(runner) == [3] * 6 + [2, 2]
+
+    runner = build_toy_runner(dict(type='IterBasedRunner', max_iters=6))
+    runner._iter = 3  # resume
+    assert calc_loss_factors(runner) == [3] * 3
+
+    runner = build_toy_runner(dict(type='IterBasedRunner', max_iters=8))
+    runner._iter = 3  # resume
+    assert calc_loss_factors(runner) == [3] * 3 + [2, 2]
+
 
 @pytest.mark.skipif(
     not torch.cuda.is_available(), reason='requires CUDA support')
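
For completeness, a usage sketch (assuming an already-built MMCV runner; the constructor arguments are the same ones the tests above use):

    from mmcv.runner import GradientCumulativeOptimizerHook

    # Accumulate gradients over 3 iterations before each optimizer step,
    # clipping as usual; registered in place of a plain OptimizerHook.
    optimizer_hook = GradientCumulativeOptimizerHook(
        grad_clip=dict(max_norm=0.2), cumulative_iters=3)
    runner.register_hook(optimizer_hook)  # `runner` assumed built elsewhere
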