Unverified Commit 22e73d69 authored by Shilong Zhang, committed by GitHub

[Feature] Detect anomalous parameters (#1547)

* add detect_anomalous_params

* fix default value

* merge two cases

* fix None case

* add unit test

* fix typo

* change level to error

* fix type

* add more details in docstring
parent 519b4ec0
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import logging
from collections import defaultdict
from itertools import chain
@@ -20,9 +21,27 @@ except ImportError:
@HOOKS.register_module()
class OptimizerHook(Hook):
"""A hook contains custom operations for the optimizer.
def __init__(self, grad_clip=None):
Args:
grad_clip (dict, optional): A config dict to control the clip_grad.
Default: None.
detect_anomalous_params (bool): This option is only used for
debugging which will slow down the training speed.
Detect anomalous parameters that are not included in
the computational graph with `loss` as the root.
There are two cases
- Parameters were not used during
forward pass.
- Parameters were not used to produce
loss.
Default: False.
"""
def __init__(self, grad_clip=None, detect_anomalous_params=False):
self.grad_clip = grad_clip
self.detect_anomalous_params = detect_anomalous_params
def clip_grads(self, params):
params = list(
@@ -32,7 +51,10 @@ class OptimizerHook(Hook):
def after_train_iter(self, runner):
runner.optimizer.zero_grad()
        if self.detect_anomalous_params:
            # Check for parameters missing from the loss graph before
            # calling backward.
            self.detect_anomalous_parameters(runner.outputs['loss'], runner)
runner.outputs['loss'].backward()
if self.grad_clip is not None:
grad_norm = self.clip_grads(runner.model.parameters())
if grad_norm is not None:
@@ -41,6 +63,32 @@ class OptimizerHook(Hook):
runner.outputs['num_samples'])
runner.optimizer.step()

    def detect_anomalous_parameters(self, loss, runner):
        logger = runner.logger
        parameters_in_graph = set()
        visited = set()

        def traverse(grad_fn):
            if grad_fn is None:
                return
            if grad_fn not in visited:
                visited.add(grad_fn)
                # Leaf nodes (``AccumulateGrad``) expose the parameter
                # they accumulate into as ``variable``.
                if hasattr(grad_fn, 'variable'):
                    parameters_in_graph.add(grad_fn.variable)
                parents = grad_fn.next_functions
                if parents is not None:
                    for parent in parents:
                        grad_fn = parent[0]
                        traverse(grad_fn)

        # Collect every parameter reachable from ``loss`` and report the
        # trainable parameters that are missing from that graph.
        traverse(loss.grad_fn)
        for n, p in runner.model.named_parameters():
            if p not in parameters_in_graph and p.requires_grad:
                logger.log(
                    level=logging.ERROR,
                    msg=f'{n} with shape {p.size()} is not '
                    f'in the computational graph \n')

@HOOKS.register_module()
class GradientCumulativeOptimizerHook(OptimizerHook):
......
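For reference, a minimal sketch of how the new option can be switched on. The direct construction mirrors the unit test below; the commented config-dict form is an assumption about how downstream codebases typically build this hook (e.g. via `optimizer_config`), not something introduced by this commit.

```python
from mmcv.runner import OptimizerHook

# Direct construction, as exercised in the unit test below.
optimizer_hook = OptimizerHook(
    grad_clip=dict(max_norm=2),
    detect_anomalous_params=True)  # log parameters missing from the loss graph

# Assumed config-style equivalent used by downstream projects:
# optimizer_config = dict(
#     grad_clip=dict(max_norm=35, norm_type=2),
#     detect_anomalous_params=True)
```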
@@ -12,7 +12,7 @@ import re
import shutil
import sys
import tempfile
from unittest.mock import MagicMock, Mock, call, patch
import pytest
import torch
@@ -39,6 +39,85 @@ sys.modules['petrel_client'] = MagicMock()
sys.modules['petrel_client.client'] = MagicMock()
def test_optimizerhook():
class Model(nn.Module):
def __init__(self):
super().__init__()
self.conv1 = nn.Conv2d(
in_channels=1,
out_channels=2,
kernel_size=3,
stride=1,
padding=1,
dilation=1)
self.conv2 = nn.Conv2d(
in_channels=2,
out_channels=2,
kernel_size=3,
stride=1,
padding=1,
dilation=1)
self.conv3 = nn.Conv2d(
in_channels=1,
out_channels=2,
kernel_size=3,
stride=1,
padding=1,
dilation=1)

        def forward(self, x):
            # conv3 is intentionally never used, and conv2 only
            # contributes to x2, so a loss built from x1 alone should
            # flag conv2 and conv3 as anomalous.
            x1 = self.conv1(x)
            x2 = self.conv2(x1)
            return x1, x2
model = Model()
x = torch.rand(1, 1, 3, 3)
dummy_runner = Mock()
dummy_runner.optimizer.zero_grad = Mock(return_value=None)
dummy_runner.optimizer.step = Mock(return_value=None)
dummy_runner.model = model
dummy_runner.outputs = dict()
dummy_runner.outputs['num_samples'] = 0
class DummyLogger():
def __init__(self):
self.msg = ''
def log(self, msg=None, **kwargs):
self.msg += msg
dummy_runner.logger = DummyLogger()
optimizer_hook = OptimizerHook(
dict(max_norm=2), detect_anomalous_params=True)
dummy_runner.outputs['loss'] = model(x)[0].sum()
optimizer_hook.after_train_iter(dummy_runner)
    # The parameters of conv2 and conv3 are not in the computational
    # graph whose root is x1.sum(), so they should be reported.
assert 'conv2.weight' in dummy_runner.logger.msg
assert 'conv2.bias' in dummy_runner.logger.msg
assert 'conv3.weight' in dummy_runner.logger.msg
assert 'conv3.bias' in dummy_runner.logger.msg
assert 'conv1.weight' not in dummy_runner.logger.msg
assert 'conv1.bias' not in dummy_runner.logger.msg
dummy_runner.outputs['loss'] = model(x)[1].sum()
dummy_runner.logger.msg = ''
optimizer_hook.after_train_iter(dummy_runner)
    # Only the parameters of conv3 are outside the computational
    # graph rooted at x2.sum().
assert 'conv3.weight' in dummy_runner.logger.msg
assert 'conv3.bias' in dummy_runner.logger.msg
assert 'conv2.weight' not in dummy_runner.logger.msg
assert 'conv2.bias' not in dummy_runner.logger.msg
assert 'conv1.weight' not in dummy_runner.logger.msg
assert 'conv1.bias' not in dummy_runner.logger.msg
def test_checkpoint_hook(tmp_path):
"""xdoctest -m tests/test_runner/test_hooks.py test_checkpoint_hook."""
......
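The traversal the hook relies on can also be exercised outside a runner. Below is a self-contained sketch of the same idea; the helper name `params_in_loss_graph` and the toy model are illustrative only, and the walk is written iteratively rather than recursively.

```python
import torch
import torch.nn as nn


def params_in_loss_graph(loss):
    """Collect every parameter reachable from ``loss`` via its grad_fn chain."""
    seen, params = set(), set()
    stack = [loss.grad_fn]
    while stack:
        fn = stack.pop()
        if fn is None or fn in seen:
            continue
        seen.add(fn)
        # ``AccumulateGrad`` leaf nodes expose their parameter as ``variable``.
        if hasattr(fn, 'variable'):
            params.add(fn.variable)
        stack.extend(parent for parent, _ in fn.next_functions)
    return params


model = nn.ModuleDict(dict(used=nn.Linear(4, 4), unused=nn.Linear(4, 4)))
loss = model['used'](torch.rand(2, 4)).sum()
reachable = params_in_loss_graph(loss)
for name, p in model.named_parameters():
    if p.requires_grad and p not in reachable:
        print(f'{name} with shape {tuple(p.size())} is not in the loss graph')
```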