Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
54ee8d12
Commit
54ee8d12
authored
Mar 09, 2022
by
Xu Kai
Committed by
Frank Lee
Mar 11, 2022
Browse files
Fix/format colossalai/engine/paramhooks/(#350)
parent
e83970e3
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
8 additions
and
5 deletions
+8
-5
colossalai/engine/paramhooks/__init__.py
colossalai/engine/paramhooks/__init__.py
+2
-1
colossalai/engine/paramhooks/_param_hookmgr.py
colossalai/engine/paramhooks/_param_hookmgr.py
+6
-4
No files found.
colossalai/engine/paramhooks/__init__.py
View file @
54ee8d12
from
._param_hookmgr
import
BaseParamHookMgr
from
._param_hookmgr
import
BaseParamHookMgr
__all__
=
[
"BaseParamHookMgr"
]
__all__
=
[
"BaseParamHookMgr"
]
colossalai/engine/paramhooks/_param_hookmgr.py
View file @
54ee8d12
...
@@ -2,7 +2,9 @@ from typing import Callable, List
...
@@ -2,7 +2,9 @@ from typing import Callable, List
import
torch
import
torch
import
functools
import
functools
class
BaseParamHookMgr
(
object
):
class
BaseParamHookMgr
(
object
):
def
__init__
(
self
,
param_list
:
List
[
torch
.
nn
.
Parameter
])
->
None
:
def
__init__
(
self
,
param_list
:
List
[
torch
.
nn
.
Parameter
])
->
None
:
r
"""
r
"""
register backward hook on every parameters of module
register backward hook on every parameters of module
...
@@ -10,7 +12,7 @@ class BaseParamHookMgr(object):
...
@@ -10,7 +12,7 @@ class BaseParamHookMgr(object):
self
.
_param_list
=
param_list
self
.
_param_list
=
param_list
self
.
_hook_list
=
[]
self
.
_hook_list
=
[]
def
register_backward_hooks
(
self
,
hook_call
:
Callable
)
->
None
:
def
register_backward_hooks
(
self
,
hook_call
:
Callable
)
->
None
:
r
"""
r
"""
The hook_call will be called every time a gradient with respect to the a param in self.param_list
The hook_call will be called every time a gradient with respect to the a param in self.param_list
is computed.
is computed.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment