OpenDAS / ColossalAI · Commit 8d653af4

Authored Mar 02, 2022 by Jiarui Fang; committed by Frank Lee, Mar 11, 2022
add a common util for hooks registered on parameter. (#292)

Parent: f867365a
Showing 3 changed files with 120 additions and 0 deletions (+120 −0)
colossalai/engine/paramhooks/__init__.py (+2 −0)
colossalai/engine/paramhooks/_param_hookmgr.py (+32 −0)
tests/test_engine/test_engine/test_param_hook.py (+86 −0)
colossalai/engine/paramhooks/__init__.py (new file, mode 100644)

```python
from ._param_hookmgr import BaseParamHookMgr

__all__ = ["BaseParamHookMgr"]
```
colossalai/engine/paramhooks/_param_hookmgr.py (new file, mode 100644)

````python
from typing import Callable, List
import torch
import functools


class BaseParamHookMgr(object):

    def __init__(self, param_list: List[torch.nn.Parameter]) -> None:
        r"""
        Register a backward hook on every parameter of the module.
        """
        self._param_list = param_list
        self._hook_list = []

    def register_backward_hooks(self, hook_call: Callable) -> None:
        r"""
        The hook_call will be called every time a gradient with respect to a param in
        self._param_list is computed.
        The hook should have the following signature:
        ```
        hook(param, grad) -> Tensor or None
        ```
        """
        if not torch.is_grad_enabled():
            return    # don't register grad hooks if grad isn't enabled
        for p in self._param_list:
            if p.requires_grad and not hasattr(p, '_base_param_hook'):
                # Bind the parameter as the hook's first argument, since
                # Tensor.register_hook only passes the gradient itself.
                handle = p.register_hook(functools.partial(hook_call, p))
                p._base_param_hook = handle

    def remove_hooks(self):
        for p in self._param_list:
            if p.requires_grad and hasattr(p, '_base_param_hook'):
                p._base_param_hook.remove()
````
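For context, a minimal usage sketch (illustrative, not part of this commit): it registers a hook through `BaseParamHookMgr` that logs each parameter's gradient norm during backward. The toy `nn.Linear` model and the `log_grad_norm` hook are assumptions for the example, and it presumes a `colossalai` install containing this commit.

```python
# Minimal usage sketch (illustrative, not part of this commit).
import torch
from torch import nn
from colossalai.engine.paramhooks import BaseParamHookMgr

model = nn.Linear(4, 2)

def log_grad_norm(param, grad):
    # The manager binds `param` via functools.partial, so the hook sees
    # both the parameter and its freshly computed gradient.
    print(f"param {tuple(param.shape)}: grad norm {grad.norm().item():.4f}")
    return grad    # returning the grad (or None) leaves it unchanged

mgr = BaseParamHookMgr(list(model.parameters()))
mgr.register_backward_hooks(log_grad_norm)

model(torch.randn(3, 4)).sum().backward()
mgr.remove_hooks()    # detach hooks once they are no longer needed
```

Note that because the handle is stashed on the parameter itself (`p._base_param_hook`), a second `register_backward_hooks` call on the same parameters is a no-op until `remove_hooks()` is called.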
tests/test_engine/test_engine/test_param_hook.py (new file, mode 100644)

```python
import pytest
from colossalai.engine.paramhooks import BaseParamHookMgr
from torch import nn
import torch
import torch.nn.functional as F
import copy


class SubNet(nn.Module):

    def __init__(self, out_features) -> None:
        super().__init__()
        self.bias = nn.Parameter(torch.zeros(out_features))

    def forward(self, x, weight):
        return F.linear(x, weight, self.bias)


class Net(nn.Module):

    def __init__(self, checkpoint=False) -> None:
        super().__init__()
        self.fc1 = nn.Linear(5, 5)
        self.sub_fc = SubNet(5)
        self.fc2 = nn.Linear(5, 1)

    def forward(self, x):
        x = self.fc1(x)
        x = self.sub_fc(x, self.fc1.weight)
        x = self.fc1(x)
        x = self.fc2(x)
        return x


def net_data():
    return (torch.randn(2, 5, dtype=torch.float, device='cuda'),)


def allclose(tensor_a: torch.Tensor, tensor_b: torch.Tensor, loose=False) -> bool:
    if loose:
        return torch.allclose(tensor_a, tensor_b, atol=1e-3, rtol=1e-3)
    return torch.allclose(tensor_a, tensor_b)


def test_base_param_hook():
    torch.manual_seed(0)
    model = Net(checkpoint=True).cuda()
    model.train()
    inputs = net_data()

    def run_model(model, inputs, use_param_hook=False):
        if use_param_hook:

            class HooKWrapper:

                def __init__(self) -> None:
                    self.hook_triggered_times = 0

                def wrapper_func(self):

                    def hook(param, grad) -> torch.Tensor or None:
                        self.hook_triggered_times += 1
                        return grad

                    return hook

            hookwrapper = HooKWrapper()
            param_list = [p for p in model.parameters()]
            hook_mgr = BaseParamHookMgr(param_list)
            hook_mgr.register_backward_hooks(hookwrapper.wrapper_func())

        model.zero_grad(set_to_none=True)

        with torch.cuda.amp.autocast():
            y = model(*inputs)
            loss = y.sum()

        loss.backward()

        if use_param_hook:
            hook_mgr.remove_hooks()
            return hookwrapper.hook_triggered_times

    model_copy = copy.deepcopy(model)

    run_model(model, inputs, False)
    ret2 = run_model(model_copy, inputs, True)

    # Make sure the param hook is fired only once per parameter,
    # even in the presence of parameter sharing.
    assert ret2 == len(list(model.parameters()))

    for p, p_copy in zip(model.parameters(), model_copy.parameters()):
        assert allclose(p.grad, p_copy.grad), f"{p.grad} vs {p_copy.grad}"


if __name__ == '__main__':
    test_base_param_hook()
```
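One plausible application of the same API (again an illustration, not something this commit ships) is clipping each parameter's gradient in place as soon as it is computed, rather than walking all parameters after `backward()`. The `make_clip_hook` helper and the model below are hypothetical; the replace-by-return behavior relies on standard `Tensor.register_hook` semantics.

```python
# Hypothetical per-parameter gradient clipping via BaseParamHookMgr.
import torch
from torch import nn
from colossalai.engine.paramhooks import BaseParamHookMgr

def make_clip_hook(max_norm: float):
    def clip_hook(param, grad):
        norm = grad.norm()
        if norm > max_norm:
            # The tensor returned here replaces the original gradient.
            return grad * (max_norm / (norm + 1e-6))
        return grad
    return clip_hook

model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 1))
mgr = BaseParamHookMgr(list(model.parameters()))
mgr.register_backward_hooks(make_clip_hook(max_norm=1.0))

model(torch.randn(2, 8)).sum().backward()
mgr.remove_hooks()
```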