"git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "b9feed87958c27074b0618cc543696c05f58e2c9"
Commit 818c40c3 authored by pangjm's avatar pangjm
Browse files

minor edit to build_hook

parent fc5319b6
...@@ -182,12 +182,12 @@ class Runner(object): ...@@ -182,12 +182,12 @@ class Runner(object):
if not inserted: if not inserted:
self._hooks.insert(0, hook) self._hooks.insert(0, hook)
def build_hook(self, hook, args): def build_hook(self, args, hook_type=None):
assert issubclass(hook, Hook), '"hook" must be a Hook object' if isinstance(args, Hook):
if isinstance(args, dict): return args
self.register_hook(hook(**args)) elif isinstance(args, dict):
elif isinstance(args, Hook): assert issubclass(hook_type, Hook)
self.register_hook(args) return hook_type(**args)
else: else:
raise TypeError('"args" must be either a Hook object' raise TypeError('"args" must be either a Hook object'
' or dict, not {}'.format(type(args))) ' or dict, not {}'.format(type(args)))
...@@ -356,8 +356,8 @@ class Runner(object): ...@@ -356,8 +356,8 @@ class Runner(object):
if checkpoint_config is None: if checkpoint_config is None:
checkpoint_config = {} checkpoint_config = {}
self.register_lr_hooks(lr_config) self.register_lr_hooks(lr_config)
self.build_hook(OptimizerHook, optimizer_config) self.register_hook(self.build_hook(optimizer_config, OptimizerHook))
self.build_hook(CheckpointHook, checkpoint_config) self.register_hook(self.build_hook(checkpoint_config, CheckpointHook))
self.register_hook(IterTimerHook()) self.register_hook(IterTimerHook())
if log_config is not None: if log_config is not None:
self.register_logger_hooks(log_config) self.register_logger_hooks(log_config)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment