ColossalAI, commit 895c1c5e (unverified)
Authored by ver217 on Jun 13, 2022; committed by GitHub on Jun 13, 2022
Parent commit: 1e9f9c22

[tensor] refactor param op hook (#1097)

* refactor param op hook
* add docstr
* fix bug
Showing 4 changed files with 76 additions and 31 deletions:

    colossalai/nn/parallel/data_parallel.py    +5   -5
    colossalai/tensor/__init__.py              +2   -2
    colossalai/tensor/colo_parameter.py        +4   -4
    colossalai/tensor/param_op_hook.py         +65  -20
colossalai/nn/parallel/data_parallel.py

@@ -4,8 +4,8 @@ from colossalai.core import global_context as gpc
 from colossalai.context import ParallelMode
 from functools import partial
 from colossalai.zero.utils.zero_hook_v2 import ZeROHookV2
-from colossalai.tensor.chunk import ChunkManager, TensorState, Chunk
-from colossalai.tensor.param_op_hook import use_param_op_hooks
+from colossalai.tensor.chunk import TensorState, Chunk
+from colossalai.tensor.param_op_hook import ParamOpHookManager
 from colossalai.gemini.gemini_mgr import GeminiManager
 from typing import Dict
 from colossalai.logging import get_dist_logger

@@ -113,7 +113,7 @@ class ColoDDPV2(ColoDDP):
     def forward(self, *args, **kwargs):
         self.module.zero_grad(set_to_none=True)
         self.gemini_manager.pre_iter()
-        with use_param_op_hooks(self.param_op_hook):
+        with ParamOpHookManager.use_hooks(self.param_op_hook):
             outputs = self.module(*args, **kwargs)
         self.chunk_manager.exec_lazy_release()
         return outputs

@@ -134,12 +134,12 @@ class ColoDDPV2(ColoDDP):
         self.gemini_manager.post_iter()

     def backward(self, loss: torch.Tensor):
-        with self.param_op_hook.switch_to_backward(), use_param_op_hooks(self.param_op_hook):
+        with self.param_op_hook.switch_to_backward(), ParamOpHookManager.use_hooks(self.param_op_hook):
             loss.backward()
         self._post_backward()

     def backward_by_grad(self, tensor, grad):
-        with self.param_op_hook.switch_to_backward(), use_param_op_hooks(self.param_op_hook):
+        with self.param_op_hook.switch_to_backward(), ParamOpHookManager.use_hooks(self.param_op_hook):
             torch.autograd.backward(tensor, grad)
         self._post_backward()
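The call-site change in ColoDDPV2 is mechanical: the free context manager use_param_op_hooks(...) becomes the class-level ParamOpHookManager.use_hooks(...), wrapped around both the forward and the backward pass. A minimal sketch of the new pattern (run_iteration and the hook argument are illustrative, not part of the commit; the switch_to_backward() call used by ColoDDPV2 belongs to its ZeRO hook and is omitted here):

import torch
from colossalai.tensor import ParamOpHookManager

def run_iteration(model: torch.nn.Module, data: torch.Tensor, hook) -> None:
    # Forward: while installed, the hook's pre_forward/post_forward callbacks fire
    # around every torch op whose operands include a ColoParameter.
    with ParamOpHookManager.use_hooks(hook):
        loss = model(data).sum()
    # Backward: reinstall the hook so pre_backward/post_backward fire while autograd
    # runs the PreFwdPostBwd/PostFwdPreBwd nodes recorded during the forward pass.
    with ParamOpHookManager.use_hooks(hook):
        loss.backward()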
colossalai/tensor/__init__.py

@@ -5,10 +5,10 @@ from .colo_parameter import ColoParameter
 from .utils import convert_parameter, named_params_with_colotensor
 from . import distspec
 from .dist_spec_mgr import DistSpecManager
-from .param_op_hook import ParamOpHook, use_param_op_hooks
+from .param_op_hook import ParamOpHook, ParamOpHookManager
 from .chunk import ChunkManager, TensorState

 __all__ = [
     'ColoTensor', 'convert_parameter', 'ComputePattern', 'TensorSpec', 'ParallelAction', 'named_params_with_colotensor',
-    'ColoParameter', 'distspec', 'DistSpecManager', 'ParamOpHook', 'use_param_op_hooks', 'ChunkManager', 'TensorState'
+    'ColoParameter', 'distspec', 'DistSpecManager', 'ParamOpHook', 'ParamOpHookManager', 'ChunkManager', 'TensorState'
 ]
colossalai/tensor/colo_parameter.py

@@ -3,7 +3,7 @@ from colossalai.tensor.const import TensorType
 import torch
 from colossalai.tensor import TensorSpec, distspec
 from copy import copy
-from colossalai.tensor.param_op_hook import _ParamOpHookWrapper, PreFwdPostBwd, PostFwdPreBwd
+from colossalai.tensor.param_op_hook import ParamOpHookManager
 from typing import Optional

@@ -48,17 +48,17 @@ class ColoParameter(ColoTensor, torch.nn.Parameter):
     @classmethod
     def __torch_function__(cls, func, types, args=..., kwargs=None):
-        if len(_ParamOpHookWrapper.hooks) > 0:
+        if ParamOpHookManager.has_hook():
             if not func.__name__.startswith('__'):
                 params = list(filter(lambda arg: isinstance(arg, ColoParameter), args))
                 if kwargs is not None:
                     params.extend(list(filter(lambda arg: isinstance(arg, ColoParameter), kwargs.values())))
                 if len(params) > 0:
                     with torch._C.DisableTorchFunction():
-                        args = PreFwdPostBwd.apply(params, *args)
+                        args = ParamOpHookManager.pre_op(params, *args)
                     ret = super().__torch_function__(func, types, args, kwargs)
                     with torch._C.DisableTorchFunction():
-                        ret = PostFwdPreBwd.apply(params, ret)
+                        ret = ParamOpHookManager.post_op(params, ret)
                     return ret
         return super().__torch_function__(func, types, args, kwargs)
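The __torch_function__ override above is where the hooks are actually triggered: while any hooks are installed, every non-dunder torch operation whose operands include a ColoParameter has its arguments routed through ParamOpHookManager.pre_op before the real call and its result through post_op afterwards. A self-contained sketch of that interception pattern using only PyTorch (HookedTensor is illustrative and not part of ColossalAI):

import torch

class HookedTensor(torch.Tensor):
    # Illustrative tensor subclass mimicking ColoParameter's dispatch: intercept every
    # torch op whose operands include this type and run callbacks around the real call.

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        # Skip dunder calls such as __repr__, as the ColoParameter code above does.
        traced = not func.__name__.startswith('__')
        if traced:
            print(f"pre_op:  {func.__name__}")   # stands in for ParamOpHookManager.pre_op
        ret = super().__torch_function__(func, types, args, kwargs)
        if traced:
            print(f"post_op: {func.__name__}")   # stands in for ParamOpHookManager.post_op
        return ret

t = torch.randn(2, 2).as_subclass(HookedTensor)
torch.matmul(t, t)  # prints the pre_op/post_op markers around matmul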
colossalai/tensor/param_op_hook.py

 import torch
 from contextlib import contextmanager
 from abc import ABC, abstractmethod
-from typing import List, Tuple
+from typing import List, Tuple, Any


 class ParamOpHook(ABC):
+    """Hook which is triggered by each operation when operands contain ColoParameter.
+    To customize it, you must inherit this abstract class, and implement ``pre_forward``,
+    ``post_forward``, ``pre_backward`` and ``post_backward``. These four methods take a list
+    of ColoParameter.
+    """

     @abstractmethod
     def pre_forward(self, params: List[torch.Tensor]) -> None:

@@ -23,25 +28,78 @@ class ParamOpHook(ABC):
         pass


-class _ParamOpHookWrapper:
+class ParamOpHookManager:
+    """Manage your param op hooks. It only has static methods.
+    The only static method you should call is ``use_hooks(*hooks)``.
+    """
     hooks: Tuple[ParamOpHook, ...] = tuple()

+    @staticmethod
+    @contextmanager
+    def use_hooks(*hooks: ParamOpHook):
+        """Change the param op hooks you use. Nested calling is allowed.
+
+        Example::
+            >>> with ParamOpHookManager.use_hooks(*hooks):
+            >>>     do_something()
+            >>>     with ParamOpHookManager.use_hooks():
+            >>>         # clear hooks
+            >>>         do_something()
+        """
+        try:
+            old_param_op_hooks = ParamOpHookManager.hooks
+            ParamOpHookManager.hooks = hooks
+            yield
+        finally:
+            ParamOpHookManager.hooks = old_param_op_hooks
+
+    @staticmethod
+    def _trigger_pre_forward(params: List[torch.Tensor]) -> None:
+        for hook in ParamOpHookManager.hooks:
+            hook.pre_forward(params)
+
+    @staticmethod
+    def _trigger_post_forward(params: List[torch.Tensor]) -> None:
+        for hook in ParamOpHookManager.hooks:
+            hook.post_forward(params)
+
+    @staticmethod
+    def _trigger_pre_backward(params: List[torch.Tensor]) -> None:
+        for hook in ParamOpHookManager.hooks:
+            hook.pre_backward(params)
+
+    @staticmethod
+    def _trigger_post_backward(params: List[torch.Tensor]) -> None:
+        for hook in ParamOpHookManager.hooks:
+            hook.post_backward(params)
+
+    @staticmethod
+    def pre_op(params: List[torch.Tensor], *args: Any) -> Any:
+        ParamOpHookManager._trigger_pre_forward(params)
+        return PreFwdPostBwd.apply(params, *args)
+
+    @staticmethod
+    def post_op(params: List[torch.Tensor], args: Any) -> Any:
+        ParamOpHookManager._trigger_post_forward(params)
+        return PostFwdPreBwd.apply(params, args)
+
+    @staticmethod
+    def has_hook() -> bool:
+        return len(ParamOpHookManager.hooks) > 0


 class PreFwdPostBwd(torch.autograd.Function):

     @staticmethod
     def forward(ctx, params, *args):
         ctx.params = params
-        for hook in _ParamOpHookWrapper.hooks:
-            hook.pre_forward(ctx.params)
         if len(args) == 1:
             return args[0]
         return args

     @staticmethod
     def backward(ctx, *grads):
-        for hook in _ParamOpHookWrapper.hooks:
-            hook.post_backward(ctx.params)
+        ParamOpHookManager._trigger_post_backward(ctx.params)
         return (None,) + grads

@@ -50,22 +108,9 @@ class PostFwdPreBwd(torch.autograd.Function):

     @staticmethod
     def forward(ctx, params, args):
         ctx.params = params
-        for hook in _ParamOpHookWrapper.hooks:
-            hook.post_forward(params)
         return args

     @staticmethod
     def backward(ctx, *grads):
-        for hook in _ParamOpHookWrapper.hooks:
-            hook.pre_backward(ctx.params)
+        ParamOpHookManager._trigger_pre_backward(ctx.params)
         return (None,) + grads
-
-
-@contextmanager
-def use_param_op_hooks(*hooks: ParamOpHook):
-    try:
-        old_param_op_hooks = _ParamOpHookWrapper.hooks
-        _ParamOpHookWrapper.hooks = hooks
-        yield
-    finally:
-        _ParamOpHookWrapper.hooks = old_param_op_hooks
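Taken together, the new public surface is a ParamOpHook subclass plus ParamOpHookManager.use_hooks. A hedged sketch of an end-to-end user hook (TraceHook and its print statements are hypothetical; the four callback names, use_hooks and has_hook come from the code above, and the callbacks only fire for torch ops whose operands include a ColoParameter):

from typing import List

import torch
from colossalai.tensor import ParamOpHook, ParamOpHookManager

class TraceHook(ParamOpHook):
    # Hypothetical hook that logs how many parameters each intercepted op touches.

    def pre_forward(self, params: List[torch.Tensor]) -> None:
        print(f"pre_forward on {len(params)} param(s)")

    def post_forward(self, params: List[torch.Tensor]) -> None:
        print(f"post_forward on {len(params)} param(s)")

    def pre_backward(self, params: List[torch.Tensor]) -> None:
        print(f"pre_backward on {len(params)} param(s)")

    def post_backward(self, params: List[torch.Tensor]) -> None:
        print(f"post_backward on {len(params)} param(s)")

hook = TraceHook()
with ParamOpHookManager.use_hooks(hook):
    assert ParamOpHookManager.has_hook()
    # ... run forward/backward on a model holding ColoParameters here ...
    with ParamOpHookManager.use_hooks():
        # A nested call with no arguments temporarily clears the installed hooks.
        assert not ParamOpHookManager.has_hook()
assert not ParamOpHookManager.has_hook()

Compared with the removed module-level use_param_op_hooks function, routing everything through ParamOpHookManager keeps the hook state, the _trigger_* helpers, and the pre_op/post_op entry points used by ColoParameter.__torch_function__ in one place.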