OpenDAS / ColossalAI · Commits

Commit b3b89865 (unverified)
[Gemini] ParamOpHook -> ColoParamOpHook (#2080)

Authored Dec 05, 2022 by Jiarui Fang; committed via GitHub on Dec 05, 2022.
Parent: 4f21c9e8

Showing 7 changed files with 37 additions and 36 deletions (+37 -36):
- colossalai/gemini/memory_tracer/runtime_mem_tracer.py (+3 -3)
- colossalai/gemini/ophooks/runtime_mem_tracer_hook.py (+2 -2)
- colossalai/nn/parallel/data_parallel.py (+4 -4)
- colossalai/tensor/__init__.py (+4 -3)
- colossalai/tensor/colo_parameter.py (+4 -4)
- colossalai/tensor/param_op_hook.py (+18 -18)
- colossalai/zero/utils/gemini_hook.py (+2 -2)
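In short, this commit renames the param-op hook API: `ParamOpHook` becomes `ColoParamOpHook` and `ParamOpHookManager` becomes `ColoParamOpHookManager`, with no behavioral change. For downstream code the migration is a pure rename; a minimal sketch (import paths taken from the diffs below):

```python
# Migration sketch: only the names change, not the behavior.

# Before this commit:
# from colossalai.tensor import ParamOpHook, ParamOpHookManager

# After this commit:
from colossalai.tensor import ColoParamOpHook, ColoParamOpHookManager
```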
colossalai/gemini/memory_tracer/runtime_mem_tracer.py

```diff
@@ -3,7 +3,7 @@ import torch.nn
 from colossalai.gemini.memory_tracer.model_data_memtracer import GLOBAL_CUDA_MEM_INFO
 from colossalai.gemini.ophooks.runtime_mem_tracer_hook import GradMemTracerHook, ParamMemTracerHook
 from colossalai.nn.parallel.data_parallel import _cast_float
-from colossalai.tensor.param_op_hook import ParamOpHookManager
+from colossalai.tensor.param_op_hook import ColoParamOpHookManager
 
 __all__ = ['RuntimeMemTracer']
@@ -53,12 +53,12 @@ class RuntimeMemTracer():
         args, kwargs = _cast_float(args, self.dtype), _cast_float(kwargs, self.dtype)
         self.module.zero_grad(set_to_none=True)
         self._pre_forward()
-        with ParamOpHookManager.use_hooks(self.param_op_hook):
+        with ColoParamOpHookManager.use_hooks(self.param_op_hook):
             outputs = self.module(*args, **kwargs)
         return outputs
 
     def backward(self, loss):
-        with self.param_op_hook.switch_to_backward(), ParamOpHookManager.use_hooks(self.param_op_hook):
+        with self.param_op_hook.switch_to_backward(), ColoParamOpHookManager.use_hooks(self.param_op_hook):
             loss.backward()
         self._post_backward()
```
colossalai/gemini/ophooks/runtime_mem_tracer_hook.py

```diff
@@ -8,7 +8,7 @@ import torch
 from colossalai.gemini.memory_tracer import SyncCudaMemoryMonitor
 from colossalai.gemini.memory_tracer.model_data_memtracer import GLOBAL_CUDA_MEM_INFO
 from colossalai.gemini.tensor_utils import alloc_storage, free_storage
-from colossalai.tensor.param_op_hook import ParamOpHook
+from colossalai.tensor.param_op_hook import ColoParamOpHook
 
 
 class TrainingPhase(Enum):
@@ -39,7 +39,7 @@ class GradMemTracerHook():
             hook.remove()
 
 
-class ParamMemTracerHook(ParamOpHook):
+class ParamMemTracerHook(ColoParamOpHook):
 
     def __init__(self) -> None:
         super().__init__()
```
colossalai/nn/parallel/data_parallel.py

```diff
@@ -12,7 +12,7 @@ from colossalai.logging import get_dist_logger
 from colossalai.nn.parallel.utils import get_temp_total_chunk_on_cuda
 from colossalai.tensor import ProcessGroup as ColoProcessGroup
 from colossalai.tensor.colo_parameter import ColoParameter, ColoTensor, ColoTensorSpec
-from colossalai.tensor.param_op_hook import ParamOpHookManager
+from colossalai.tensor.param_op_hook import ColoParamOpHookManager
 from colossalai.utils import get_current_device
 from colossalai.zero.utils.gemini_hook import GeminiZeROHook
@@ -259,7 +259,7 @@ class ZeroDDP(ColoDDP):
         args, kwargs = _cast_float(args, torch.half), _cast_float(kwargs, torch.half)
         self.module.zero_grad(set_to_none=True)
         self.gemini_manager.pre_iter(*args)
-        with ParamOpHookManager.use_hooks(self.param_op_hook):
+        with ColoParamOpHookManager.use_hooks(self.param_op_hook):
             outputs = self.module(*args, **kwargs)
         if self.force_outputs_fp32:
             return _cast_float(outputs, torch.float)
@@ -280,12 +280,12 @@ class ZeroDDP(ColoDDP):
         self.gemini_manager.post_iter()
 
     def backward(self, loss: torch.Tensor):
-        with self.param_op_hook.switch_to_backward(), ParamOpHookManager.use_hooks(self.param_op_hook):
+        with self.param_op_hook.switch_to_backward(), ColoParamOpHookManager.use_hooks(self.param_op_hook):
             loss.backward()
         self._post_backward()
 
     def backward_by_grad(self, tensor, grad):
-        with self.param_op_hook.switch_to_backward(), ParamOpHookManager.use_hooks(self.param_op_hook):
+        with self.param_op_hook.switch_to_backward(), ColoParamOpHookManager.use_hooks(self.param_op_hook):
             torch.autograd.backward(tensor, grad)
         self._post_backward()
```
colossalai/tensor/__init__.py

```diff
@@ -5,13 +5,14 @@ from .comm_spec import CollectiveCommPattern, CommSpec
 from .compute_spec import ComputePattern, ComputeSpec
 from .dist_spec_mgr import DistSpecManager
 from .distspec import ReplicaSpec, ShardSpec
-from .param_op_hook import ParamOpHook, ParamOpHookManager
+from .param_op_hook import ColoParamOpHook, ColoParamOpHookManager
 from .process_group import ProcessGroup
 from .tensor_spec import ColoTensorSpec
 from .utils import convert_dim_partition_dict, convert_parameter, merge_same_dim_mesh_list, named_params_with_colotensor
 
 __all__ = [
     'ColoTensor', 'convert_parameter', 'ComputePattern', 'ComputeSpec', 'named_params_with_colotensor', 'ColoParameter',
-    'distspec', 'DistSpecManager', 'ParamOpHook', 'ParamOpHookManager', 'ProcessGroup', 'ColoTensorSpec', 'ShardSpec',
-    'ReplicaSpec', 'CommSpec', 'CollectiveCommPattern', 'convert_dim_partition_dict', 'merge_same_dim_mesh_list'
+    'distspec', 'DistSpecManager', 'ColoParamOpHook', 'ColoParamOpHookManager', 'ProcessGroup', 'ColoTensorSpec',
+    'ShardSpec', 'ReplicaSpec', 'CommSpec', 'CollectiveCommPattern', 'convert_dim_partition_dict',
+    'merge_same_dim_mesh_list'
 ]
```
colossalai/tensor/colo_parameter.py

```diff
@@ -4,7 +4,7 @@ import torch
 from colossalai.tensor.colo_tensor import ColoTensor
 from colossalai.tensor.const import TensorType
-from colossalai.tensor.param_op_hook import ParamOpHookManager
+from colossalai.tensor.param_op_hook import ColoParamOpHookManager
 from colossalai.tensor.tensor_spec import ColoTensorSpec
@@ -58,18 +58,18 @@ class ColoParameter(ColoTensor, torch.nn.Parameter):
 
     @classmethod
     def __torch_function__(cls, func, types, args=..., kwargs=None):
-        if ParamOpHookManager.has_hook():
+        if ColoParamOpHookManager.has_hook():
             if not func.__name__.startswith('__'):
                 if kwargs is None:
                     kwargs = {}
                 params = filter_args(lambda arg: isinstance(arg, ColoParameter), *args, *kwargs.values())
                 if len(params) > 0:
                     with torch._C.DisableTorchFunction():
-                        new_args = ParamOpHookManager.pre_op(params, *args, *kwargs.values())
+                        new_args = ColoParamOpHookManager.pre_op(params, *args, *kwargs.values())
                     args, kwargs = replace_args(args, kwargs, new_args)
                     ret = super().__torch_function__(func, types, args, kwargs)
                     with torch._C.DisableTorchFunction():
-                        ret = ParamOpHookManager.post_op(params, ret)
+                        ret = ColoParamOpHookManager.post_op(params, ret)
                     return ret
         return super().__torch_function__(func, types, args, kwargs)
```
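This file is where hooks actually fire: `ColoParameter.__torch_function__` intercepts every torch operation whose operands include a `ColoParameter` and brackets it with `pre_op`/`post_op`. Condensed into a standalone sketch (deliberately simplified: kwargs handling via `replace_args` and the `PreFwdPostBwd`/`PostFwdPreBwd` autograd plumbing from the real code are omitted):

```python
from colossalai.tensor import ColoParamOpHookManager


def intercepted_op(func, params, *args):
    """Simplified model of ColoParameter.__torch_function__ with a hook active."""
    if ColoParamOpHookManager.has_hook():
        # pre_op triggers every registered hook's pre_forward and, through
        # PreFwdPostBwd.apply, schedules post_backward for the backward pass.
        args = ColoParamOpHookManager.pre_op(params, *args)
        ret = func(*args)    # the underlying torch operation
        # post_op triggers post_forward and schedules pre_backward.
        return ColoParamOpHookManager.post_op(params, ret)
    return func(*args)
```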
colossalai/tensor/param_op_hook.py

```diff
@@ -8,7 +8,7 @@ from colossalai.tensor.colo_tensor import ColoTensor
 from colossalai.tensor.tensor_spec import ColoTensorSpec
 
 
-class ParamOpHook(ABC):
+class ColoParamOpHook(ABC):
     """Hook which is triggered by each operation when operands contain ColoParameter.
     To customize it, you must inherit this abstract class, and implement ``pre_forward``,
     ``post_forward``, ``pre_backward`` and ``post_backward``. These four methods take a list
@@ -32,68 +32,68 @@ class ParamOpHook(ABC):
         pass
 
 
-class ParamOpHookManager:
+class ColoParamOpHookManager:
     """Manage your param op hooks. It only has static methods.
     The only static method you should call is ``use_hooks(*hooks)``.
     """
-    hooks: Tuple[ParamOpHook, ...] = tuple()
+    hooks: Tuple[ColoParamOpHook, ...] = tuple()
 
     @staticmethod
     @contextmanager
-    def use_hooks(*hooks: ParamOpHook):
+    def use_hooks(*hooks: ColoParamOpHook):
         """Change the param op hooks you use. Nested calling is allowed.
 
         Example:
-            >>> with ParamOpHookManager.use_hooks(*hooks):
+            >>> with ColoParamOpHookManager.use_hooks(*hooks):
             >>>     do_something()
-            >>>     with ParamOpHookManager.use_hooks():
+            >>>     with ColoParamOpHookManager.use_hooks():
             >>>         // clear hooks
             >>>         do_something()
         """
         try:
-            old_param_op_hooks = ParamOpHookManager.hooks
-            ParamOpHookManager.hooks = hooks
+            old_param_op_hooks = ColoParamOpHookManager.hooks
+            ColoParamOpHookManager.hooks = hooks
             yield
         finally:
-            ParamOpHookManager.hooks = old_param_op_hooks
+            ColoParamOpHookManager.hooks = old_param_op_hooks
 
     @staticmethod
     def _trigger_pre_forward(params: List[torch.Tensor]) -> None:
-        for hook in ParamOpHookManager.hooks:
+        for hook in ColoParamOpHookManager.hooks:
             hook.pre_forward(params)
 
     @staticmethod
     def _trigger_post_forward(params: List[torch.Tensor]) -> None:
-        for hook in ParamOpHookManager.hooks:
+        for hook in ColoParamOpHookManager.hooks:
             hook.post_forward(params)
 
     @staticmethod
     def _trigger_pre_backward(params: List[torch.Tensor]) -> None:
-        for hook in ParamOpHookManager.hooks:
+        for hook in ColoParamOpHookManager.hooks:
             hook.pre_backward(params)
 
     @staticmethod
     def _trigger_post_backward(params: List[torch.Tensor]) -> None:
-        for hook in ParamOpHookManager.hooks:
+        for hook in ColoParamOpHookManager.hooks:
             hook.post_backward(params)
 
     @staticmethod
     def pre_op(params: List[torch.Tensor], *args: Any) -> list:
-        ParamOpHookManager._trigger_pre_forward(params)
+        ColoParamOpHookManager._trigger_pre_forward(params)
         args_info = _get_colo_tensors_info(*args)
         rets = PreFwdPostBwd.apply(params, *args)
         return _update_colo_tensors(args_info, *rets)
 
     @staticmethod
     def post_op(params: List[torch.Tensor], arg: Any) -> Any:
-        ParamOpHookManager._trigger_post_forward(params)
+        ColoParamOpHookManager._trigger_post_forward(params)
         arg_info = _get_colo_tensors_info(arg)
         ret = PostFwdPreBwd.apply(params, arg)
         return _unpack_args(_update_colo_tensors(arg_info, ret))
 
     @staticmethod
     def has_hook() -> bool:
-        return len(ParamOpHookManager.hooks) > 0
+        return len(ColoParamOpHookManager.hooks) > 0
 
 
 class PreFwdPostBwd(torch.autograd.Function):
@@ -105,7 +105,7 @@ class PreFwdPostBwd(torch.autograd.Function):
     @staticmethod
     def backward(ctx, *grads):
-        ParamOpHookManager._trigger_post_backward(ctx.params)
+        ColoParamOpHookManager._trigger_post_backward(ctx.params)
         return (None,) + grads
@@ -118,7 +118,7 @@ class PostFwdPreBwd(torch.autograd.Function):
     @staticmethod
     def backward(ctx, *grads):
-        ParamOpHookManager._trigger_pre_backward(ctx.params)
+        ColoParamOpHookManager._trigger_pre_backward(ctx.params)
         return (None,) + grads
```
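Since `ColoParamOpHook` keeps the same abstract interface, a custom hook under the new name is written exactly as before. A minimal sketch of a hypothetical hook (`CountingHook` is illustrative, not part of the library) that tallies how often parameters participate in ops:

```python
from typing import List

import torch

from colossalai.tensor import ColoParamOpHook, ColoParamOpHookManager


class CountingHook(ColoParamOpHook):
    """Hypothetical hook: counts forward/backward op triggers on ColoParameters."""

    def __init__(self) -> None:
        super().__init__()
        self.fwd_count = 0
        self.bwd_count = 0

    def pre_forward(self, params: List[torch.Tensor]) -> None:
        self.fwd_count += 1

    def post_forward(self, params: List[torch.Tensor]) -> None:
        pass

    def pre_backward(self, params: List[torch.Tensor]) -> None:
        self.bwd_count += 1

    def post_backward(self, params: List[torch.Tensor]) -> None:
        pass


# Usage mirrors the docstring example in the diff above:
hook = CountingHook()
with ColoParamOpHookManager.use_hooks(hook):
    ...  # run forward/backward over a module whose parameters are ColoParameters
```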
colossalai/zero/utils/gemini_hook.py

```diff
@@ -7,7 +7,7 @@ import torch
 from colossalai.gemini import TensorState
 from colossalai.gemini.gemini_mgr import GeminiManager
-from colossalai.tensor.param_op_hook import ParamOpHook
+from colossalai.tensor.param_op_hook import ColoParamOpHook
 
 
 class TrainingPhase(Enum):
@@ -15,7 +15,7 @@ class TrainingPhase(Enum):
     BACKWARD = 1
 
 
-class GeminiZeROHook(ParamOpHook):
+class GeminiZeROHook(ColoParamOpHook):
 
     def __init__(self, gemini_manager: GeminiManager) -> None:
         super().__init__()
```