Unverified Commit 41ca64cb authored by Wang Xinjiang, committed by GitHub

add gradient norm to logger (#326)

parent e7a446ae
...@@ -14,11 +14,15 @@ class OptimizerHook(Hook):
         params = list(
             filter(lambda p: p.requires_grad and p.grad is not None, params))
         if len(params) > 0:
-            clip_grad.clip_grad_norm_(params, **self.grad_clip)
+            return clip_grad.clip_grad_norm_(params, **self.grad_clip)
 
     def after_train_iter(self, runner):
         runner.optimizer.zero_grad()
         runner.outputs['loss'].backward()
         if self.grad_clip is not None:
-            self.clip_grads(runner.model.parameters())
+            grad_norm = self.clip_grads(runner.model.parameters())
+            if grad_norm is not None:
+                # Add grad norm to the logger
+                runner.log_buffer.update({'grad_norm': grad_norm},
+                                         runner.outputs['num_samples'])
         runner.optimizer.step()
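
Not part of the commit, but a minimal sketch of where the returned value comes from: PyTorch's clip_grad.clip_grad_norm_ computes and returns the total gradient norm of the given parameters before clipping them, and that is the value clip_grads now propagates into runner.log_buffer. The model, loss, and grad_clip dict below are made up for illustration.

    import torch
    from torch.nn.utils import clip_grad

    model = torch.nn.Linear(4, 2)
    loss = model(torch.randn(8, 4)).sum()
    loss.backward()

    # Same filtering as OptimizerHook.clip_grads above: only parameters
    # that require gradients and actually received one.
    params = [p for p in model.parameters()
              if p.requires_grad and p.grad is not None]

    # Hypothetical grad_clip config; the hook expands it via **self.grad_clip.
    grad_clip = dict(max_norm=5.0, norm_type=2)
    grad_norm = clip_grad.clip_grad_norm_(params, **grad_clip)
    print(float(grad_norm))  # total L2 norm of all gradients, the value logged as 'grad_norm'

Because log_buffer.update is called with runner.outputs['num_samples'] as the count, the logged grad_norm is accumulated and averaged the same way the losses are before a logger hook reports it.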