Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
54ee8d12
Commit
54ee8d12
authored
Mar 09, 2022
by
Xu Kai
Committed by
Frank Lee
Mar 11, 2022
Browse files
Fix/format colossalai/engine/paramhooks/(#350)
parent
e83970e3
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
8 additions
and
5 deletions
+8
-5
colossalai/engine/paramhooks/__init__.py
colossalai/engine/paramhooks/__init__.py
+2
-1
colossalai/engine/paramhooks/_param_hookmgr.py
colossalai/engine/paramhooks/_param_hookmgr.py
+6
-4
No files found.
colossalai/engine/paramhooks/__init__.py
View file @
54ee8d12
from
._param_hookmgr
import
BaseParamHookMgr
__all__
=
[
"BaseParamHookMgr"
]
\ No newline at end of file
__all__
=
[
"BaseParamHookMgr"
]
colossalai/engine/paramhooks/_param_hookmgr.py
View file @
54ee8d12
...
...
@@ -2,7 +2,9 @@ from typing import Callable, List
import
torch
import
functools
class
BaseParamHookMgr
(
object
):
def
__init__
(
self
,
param_list
:
List
[
torch
.
nn
.
Parameter
])
->
None
:
r
"""
register backward hook on every parameters of module
...
...
@@ -10,17 +12,17 @@ class BaseParamHookMgr(object):
self
.
_param_list
=
param_list
self
.
_hook_list
=
[]
def
register_backward_hooks
(
self
,
hook_call
:
Callable
)
->
None
:
def
register_backward_hooks
(
self
,
hook_call
:
Callable
)
->
None
:
r
"""
The hook_call will be called every time a gradient with respect to the a param in self.param_list
is computed.
The hook_call will be called every time a gradient with respect to the a param in self.param_list
is computed.
The hook should have the following signature:
```
hook(param, grad) -> Tensor or None
```
"""
if
not
torch
.
is_grad_enabled
():
return
# don't register grad hooks if grad isn't enabled
return
# don't register grad hooks if grad isn't enabled
for
p
in
self
.
_param_list
:
if
p
.
requires_grad
and
not
hasattr
(
p
,
'_base_param_hook'
):
handle
=
p
.
register_hook
(
functools
.
partial
(
hook_call
,
p
))
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment