"git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "7b18556c4aee85e3905444b4efa93842c727bc1e"
Commit ecfce392 authored by Kai Chen's avatar Kai Chen
Browse files

add EmptyCacheHook and rename some modules

parent f4550cd3
from .hook import Hook from .hook import Hook
from .checkpoint_saver import CheckpointHook from .checkpoint import CheckpointHook
from .closure import ClosureHook from .closure import ClosureHook
from .lr_updater import LrUpdaterHook from .lr_updater import LrUpdaterHook
from .optimizer_stepper import OptimizerHook from .optimizer import OptimizerHook
from .iter_timer import IterTimerHook from .iter_timer import IterTimerHook
from .sampler_seed import DistSamplerSeedHook from .sampler_seed import DistSamplerSeedHook
from .memory import EmptyCacheHook
from .logger import (LoggerHook, TextLoggerHook, PaviLoggerHook, from .logger import (LoggerHook, TextLoggerHook, PaviLoggerHook,
TensorboardLoggerHook) TensorboardLoggerHook)
__all__ = [ __all__ = [
'Hook', 'CheckpointHook', 'ClosureHook', 'LrUpdaterHook', 'OptimizerHook', 'Hook', 'CheckpointHook', 'ClosureHook', 'LrUpdaterHook', 'OptimizerHook',
'IterTimerHook', 'DistSamplerSeedHook', 'LoggerHook', 'TextLoggerHook', 'IterTimerHook', 'DistSamplerSeedHook', 'EmptyCacheHook', 'LoggerHook',
'PaviLoggerHook', 'TensorboardLoggerHook' 'TextLoggerHook', 'PaviLoggerHook', 'TensorboardLoggerHook'
] ]
import torch
from .hook import Hook
class EmptyCacheHook(Hook):
def __init__(self, before_epoch=False, after_epoch=True, after_iter=False):
self._before_epoch = before_epoch
self._after_epoch = after_epoch
self._after_iter = after_iter
def after_iter(self, runner):
if self._after_iter:
torch.cuda.empty_cache()
def before_epoch(self, runner):
if self._before_epoch:
torch.cuda.empty_cache()
def after_epoch(self, runner):
if self._after_epoch:
torch.cuda.empty_cache()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment