Commit 961c3388 authored by Kai Chen's avatar Kai Chen
Browse files

minor update of hooks

parent c211ab13
...@@ -4,4 +4,12 @@ from .closure import ClosureHook ...@@ -4,4 +4,12 @@ from .closure import ClosureHook
from .lr_updater import LrUpdaterHook from .lr_updater import LrUpdaterHook
from .optimizer_stepper import OptimizerHook from .optimizer_stepper import OptimizerHook
from .iter_timer import IterTimerHook from .iter_timer import IterTimerHook
from .logger import * from .sampler_seed import DistSamplerSeedHook
from .logger import (LoggerHook, TextLoggerHook, PaviLoggerHook,
pavi_hook_connect, TensorboardLoggerHook)
__all__ = [
'Hook', 'CheckpointHook', 'ClosureHook', 'LrUpdaterHook', 'OptimizerHook',
'IterTimerHook', 'DistSamplerSeedHook', 'LoggerHook', 'TextLoggerHook',
'PaviLoggerHook', 'pavi_hook_connect', 'TensorboardLoggerHook'
]
from .base import LoggerHook from .base import LoggerHook
from .pavi import PaviClient, PaviLoggerHook from .pavi import PaviLoggerHook, pavi_hook_connect
from .tensorboard import TensorboardLoggerHook from .tensorboard import TensorboardLoggerHook
from .text import TextLoggerHook from .text import TextLoggerHook
__all__ = [
'LoggerHook', 'TextLoggerHook', 'PaviLoggerHook', 'pavi_hook_connect',
'TensorboardLoggerHook'
]
...@@ -8,11 +8,13 @@ class OptimizerHook(Hook): ...@@ -8,11 +8,13 @@ class OptimizerHook(Hook):
def __init__(self, grad_clip=None): def __init__(self, grad_clip=None):
self.grad_clip = grad_clip self.grad_clip = grad_clip
def clip_grads(self, params):
clip_grad.clip_grad_norm_(
filter(lambda p: p.requires_grad, params), **self.grad_clip)
def after_train_iter(self, runner): def after_train_iter(self, runner):
runner.optimizer.zero_grad() runner.optimizer.zero_grad()
runner.outputs['loss'].backward() runner.outputs['loss'].backward()
if self.grad_clip is not None: if self.grad_clip is not None:
clip_grad.clip_grad_norm_( self.clip_grads(runner.model.parameters())
filter(lambda p: p.requires_grad, runner.model.parameters()),
**self.grad_clip)
runner.optimizer.step() runner.optimizer.step()
from .hook import Hook
class DistSamplerSeedHook(Hook):
def before_epoch(self, runner):
runner.data_loader.sampler.set_epoch(runner.epoch)
import sys import sys
from setuptools import find_packages, setup from setuptools import find_packages, setup
install_requires = ['numpy>=1.11.1', 'pyyaml', 'six', 'addict'] install_requires = ['numpy>=1.11.1', 'pyyaml', 'six', 'addict', 'requests']
if sys.version_info < (3, 3): if sys.version_info < (3, 3):
install_requires.append('backports.shutil_get_terminal_size') install_requires.append('backports.shutil_get_terminal_size')
if sys.version_info < (3, 4): if sys.version_info < (3, 4):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment