Commit 54ee8d12 authored by Xu Kai's avatar Xu Kai Committed by Frank Lee
Browse files

Fix/format colossalai/engine/paramhooks/(#350)

parent e83970e3
# Package entry point: re-export the parameter-hook manager so callers can
# write `from colossalai.engine.paramhooks import BaseParamHookMgr` instead
# of reaching into the private `_param_hookmgr` module.
from ._param_hookmgr import BaseParamHookMgr

# Explicit public API of this package (PEP 8: controls `from ... import *`).
__all__ = ["BaseParamHookMgr"]
......@@ -2,7 +2,9 @@ from typing import Callable, List
import torch
import functools
class BaseParamHookMgr(object):
def __init__(self, param_list: List[torch.nn.Parameter]) -> None:
r"""
register backward hook on every parameters of module
......@@ -10,7 +12,7 @@ class BaseParamHookMgr(object):
self._param_list = param_list
self._hook_list = []
def register_backward_hooks(self, hook_call : Callable) -> None:
def register_backward_hooks(self, hook_call: Callable) -> None:
r"""
The hook_call will be called every time a gradient with respect to the a param in self.param_list
is computed.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment