Unverified commit 22e73d69, authored by Shilong Zhang and committed by GitHub

[Feature] Detect anomalous parameters (#1547)

* detect detect_anomalous_params

* fix default value

* merge two case

* fix none case

* add unitest

* fix typo

* change level to error

* fix type

* add more details in docstr
parent 519b4ec0
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import logging
from collections import defaultdict
from itertools import chain
@@ -20,9 +21,27 @@ except ImportError:
@HOOKS.register_module()
class OptimizerHook(Hook):
    """A hook containing custom operations for the optimizer.

    Args:
        grad_clip (dict, optional): A config dict to control gradient
            clipping. Default: None.
        detect_anomalous_params (bool): Whether to detect anomalous
            parameters that are not included in the computational graph
            with `loss` as the root. There are two cases:

            - Parameters were not used during the forward pass.
            - Parameters were not used to produce the loss.

            This option is intended for debugging only and will slow down
            training. Default: False.
    """

    def __init__(self, grad_clip=None, detect_anomalous_params=False):
        self.grad_clip = grad_clip
        self.detect_anomalous_params = detect_anomalous_params
    def clip_grads(self, params):
        params = list(
@@ -32,7 +51,10 @@ class OptimizerHook(Hook):

    def after_train_iter(self, runner):
        runner.optimizer.zero_grad()
        if self.detect_anomalous_params:
            self.detect_anomalous_parameters(runner.outputs['loss'], runner)
        runner.outputs['loss'].backward()

        if self.grad_clip is not None:
            grad_norm = self.clip_grads(runner.model.parameters())
            if grad_norm is not None:
@@ -41,6 +63,32 @@ class OptimizerHook(Hook):
                                         runner.outputs['num_samples'])
        runner.optimizer.step()
    def detect_anomalous_parameters(self, loss, runner):
        logger = runner.logger
        parameters_in_graph = set()
        visited = set()

        def traverse(grad_fn):
            if grad_fn is None:
                return
            if grad_fn not in visited:
                visited.add(grad_fn)
                if hasattr(grad_fn, 'variable'):
                    parameters_in_graph.add(grad_fn.variable)
                parents = grad_fn.next_functions
                if parents is not None:
                    for parent in parents:
                        grad_fn = parent[0]
                        traverse(grad_fn)

        traverse(loss.grad_fn)
        for n, p in runner.model.named_parameters():
            if p not in parameters_in_graph and p.requires_grad:
                logger.log(
                    level=logging.ERROR,
                    msg=f'{n} with shape {p.size()} is not '
                    f'in the computational graph \n')
@HOOKS.register_module()
class GradientCumulativeOptimizerHook(OptimizerHook):
......
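For context, the new `detect_anomalous_parameters` method works by walking the autograd graph from `loss.grad_fn`: every leaf parameter that actually contributed to the loss appears as an `AccumulateGrad` node whose `variable` attribute is the parameter tensor, so any `requires_grad` parameter that never shows up in that walk is logged as anomalous. The snippet below is a minimal standalone sketch of the same traversal outside the hook; the toy `ModuleDict` model and the printed message format are illustrative assumptions, not part of this commit.

```python
# Standalone sketch of the traversal used by detect_anomalous_parameters.
# The model and output format below are illustrative, not from the commit.
import torch
import torch.nn as nn

model = nn.ModuleDict({
    'used': nn.Linear(4, 4),
    'unused': nn.Linear(4, 4),  # never called in the forward pass below
})
loss = model['used'](torch.rand(2, 4)).sum()

params_in_graph, visited = set(), set()


def traverse(grad_fn):
    if grad_fn is None or grad_fn in visited:
        return
    visited.add(grad_fn)
    if hasattr(grad_fn, 'variable'):  # AccumulateGrad node -> leaf parameter
        params_in_graph.add(grad_fn.variable)
    for parent, _ in grad_fn.next_functions:
        traverse(parent)


traverse(loss.grad_fn)
for name, param in model.named_parameters():
    if param.requires_grad and param not in params_in_graph:
        print(f'{name} with shape {tuple(param.size())} is not in the graph')
# Expected output: only unused.weight and unused.bias are reported.
```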
@@ -12,7 +12,7 @@ import re
import shutil
import sys
import tempfile
from unittest.mock import MagicMock, Mock, call, patch

import pytest
import torch
@@ -39,6 +39,85 @@ sys.modules['petrel_client'] = MagicMock()
sys.modules['petrel_client.client'] = MagicMock()
def test_optimizerhook():

    class Model(nn.Module):

        def __init__(self):
            super().__init__()
            self.conv1 = nn.Conv2d(
                in_channels=1,
                out_channels=2,
                kernel_size=3,
                stride=1,
                padding=1,
                dilation=1)
            self.conv2 = nn.Conv2d(
                in_channels=2,
                out_channels=2,
                kernel_size=3,
                stride=1,
                padding=1,
                dilation=1)
            self.conv3 = nn.Conv2d(
                in_channels=1,
                out_channels=2,
                kernel_size=3,
                stride=1,
                padding=1,
                dilation=1)

        def forward(self, x):
            x1 = self.conv1(x)
            x2 = self.conv2(x1)
            return x1, x2

    model = Model()
    x = torch.rand(1, 1, 3, 3)

    dummy_runner = Mock()
    dummy_runner.optimizer.zero_grad = Mock(return_value=None)
    dummy_runner.optimizer.step = Mock(return_value=None)
    dummy_runner.model = model
    dummy_runner.outputs = dict()
    dummy_runner.outputs['num_samples'] = 0

    class DummyLogger():

        def __init__(self):
            self.msg = ''

        def log(self, msg=None, **kwargs):
            self.msg += msg

    dummy_runner.logger = DummyLogger()
    optimizer_hook = OptimizerHook(
        dict(max_norm=2), detect_anomalous_params=True)

    dummy_runner.outputs['loss'] = model(x)[0].sum()
    optimizer_hook.after_train_iter(dummy_runner)

    # assert the parameters of conv2 and conv3 are not in the
    # computational graph whose root is x1.sum()
    assert 'conv2.weight' in dummy_runner.logger.msg
    assert 'conv2.bias' in dummy_runner.logger.msg
    assert 'conv3.weight' in dummy_runner.logger.msg
    assert 'conv3.bias' in dummy_runner.logger.msg
    assert 'conv1.weight' not in dummy_runner.logger.msg
    assert 'conv1.bias' not in dummy_runner.logger.msg

    dummy_runner.outputs['loss'] = model(x)[1].sum()
    dummy_runner.logger.msg = ''
    optimizer_hook.after_train_iter(dummy_runner)

    # assert the parameters of conv3 are not in the computational graph
    assert 'conv3.weight' in dummy_runner.logger.msg
    assert 'conv3.bias' in dummy_runner.logger.msg
    assert 'conv2.weight' not in dummy_runner.logger.msg
    assert 'conv2.bias' not in dummy_runner.logger.msg
    assert 'conv1.weight' not in dummy_runner.logger.msg
    assert 'conv1.bias' not in dummy_runner.logger.msg
def test_checkpoint_hook(tmp_path):
    """xdoctest -m tests/test_runner/test_hooks.py test_checkpoint_hook."""
......
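In downstream OpenMMLab repositories the hook is usually built from the `optimizer_config` entry of a training config, so the new flag can be switched on there. The snippet below is a hedged sketch of that usage; the optimizer settings and clipping values are placeholders for illustration, not part of this diff.

```python
# Hedged sketch: enabling the new flag in an OpenMMLab-style config.
# The optimizer and grad_clip values shown are placeholders, not from this commit.
optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001)
optimizer_config = dict(
    grad_clip=dict(max_norm=35, norm_type=2),
    detect_anomalous_params=True)  # debugging aid; slows down training
```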