ColossalAI · Commit b3b89865 (unverified)

[Gemini] ParamOpHook -> ColoParamOpHook (#2080)

Authored by Jiarui Fang, Dec 05, 2022; committed via GitHub, Dec 05, 2022.
Parent: 4f21c9e8

Showing 7 changed files with 37 additions and 36 deletions (+37, -36):
colossalai/gemini/memory_tracer/runtime_mem_tracer.py    +3  -3
colossalai/gemini/ophooks/runtime_mem_tracer_hook.py     +2  -2
colossalai/nn/parallel/data_parallel.py                  +4  -4
colossalai/tensor/__init__.py                            +4  -3
colossalai/tensor/colo_parameter.py                      +4  -4
colossalai/tensor/param_op_hook.py                       +18 -18
colossalai/zero/utils/gemini_hook.py                     +2  -2
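The rename is mechanical: ParamOpHook becomes ColoParamOpHook and ParamOpHookManager becomes ColoParamOpHookManager in colossalai.tensor.param_op_hook (and in the re-exports from colossalai.tensor), with every call site in the files above updated. As a hedged sketch of what downstream caller code would change (hypothetical caller; no compatibility alias for the old names appears in this diff):

    # Before this commit:
    # from colossalai.tensor import ParamOpHook, ParamOpHookManager

    # After this commit:
    from colossalai.tensor import ColoParamOpHook, ColoParamOpHookManager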
colossalai/gemini/memory_tracer/runtime_mem_tracer.py

@@ -3,7 +3,7 @@ import torch.nn
 from colossalai.gemini.memory_tracer.model_data_memtracer import GLOBAL_CUDA_MEM_INFO
 from colossalai.gemini.ophooks.runtime_mem_tracer_hook import GradMemTracerHook, ParamMemTracerHook
 from colossalai.nn.parallel.data_parallel import _cast_float
-from colossalai.tensor.param_op_hook import ParamOpHookManager
+from colossalai.tensor.param_op_hook import ColoParamOpHookManager

 __all__ = ['RuntimeMemTracer']

@@ -53,12 +53,12 @@ class RuntimeMemTracer():
         args, kwargs = _cast_float(args, self.dtype), _cast_float(kwargs, self.dtype)
         self.module.zero_grad(set_to_none=True)
         self._pre_forward()
-        with ParamOpHookManager.use_hooks(self.param_op_hook):
+        with ColoParamOpHookManager.use_hooks(self.param_op_hook):
             outputs = self.module(*args, **kwargs)
         return outputs

     def backward(self, loss):
-        with self.param_op_hook.switch_to_backward(), ParamOpHookManager.use_hooks(self.param_op_hook):
+        with self.param_op_hook.switch_to_backward(), ColoParamOpHookManager.use_hooks(self.param_op_hook):
             loss.backward()
         self._post_backward()
colossalai/gemini/ophooks/runtime_mem_tracer_hook.py

@@ -8,7 +8,7 @@ import torch
 from colossalai.gemini.memory_tracer import SyncCudaMemoryMonitor
 from colossalai.gemini.memory_tracer.model_data_memtracer import GLOBAL_CUDA_MEM_INFO
 from colossalai.gemini.tensor_utils import alloc_storage, free_storage
-from colossalai.tensor.param_op_hook import ParamOpHook
+from colossalai.tensor.param_op_hook import ColoParamOpHook

 class TrainingPhase(Enum):

@@ -39,7 +39,7 @@ class GradMemTracerHook():
             hook.remove()

-class ParamMemTracerHook(ParamOpHook):
+class ParamMemTracerHook(ColoParamOpHook):

     def __init__(self) -> None:
         super().__init__()
colossalai/nn/parallel/data_parallel.py

@@ -12,7 +12,7 @@ from colossalai.logging import get_dist_logger
 from colossalai.nn.parallel.utils import get_temp_total_chunk_on_cuda
 from colossalai.tensor import ProcessGroup as ColoProcessGroup
 from colossalai.tensor.colo_parameter import ColoParameter, ColoTensor, ColoTensorSpec
-from colossalai.tensor.param_op_hook import ParamOpHookManager
+from colossalai.tensor.param_op_hook import ColoParamOpHookManager
 from colossalai.utils import get_current_device
 from colossalai.zero.utils.gemini_hook import GeminiZeROHook

@@ -259,7 +259,7 @@ class ZeroDDP(ColoDDP):
         args, kwargs = _cast_float(args, torch.half), _cast_float(kwargs, torch.half)
         self.module.zero_grad(set_to_none=True)
         self.gemini_manager.pre_iter(*args)
-        with ParamOpHookManager.use_hooks(self.param_op_hook):
+        with ColoParamOpHookManager.use_hooks(self.param_op_hook):
             outputs = self.module(*args, **kwargs)
         if self.force_outputs_fp32:
             return _cast_float(outputs, torch.float)

@@ -280,12 +280,12 @@ class ZeroDDP(ColoDDP):
         self.gemini_manager.post_iter()

     def backward(self, loss: torch.Tensor):
-        with self.param_op_hook.switch_to_backward(), ParamOpHookManager.use_hooks(self.param_op_hook):
+        with self.param_op_hook.switch_to_backward(), ColoParamOpHookManager.use_hooks(self.param_op_hook):
             loss.backward()
         self._post_backward()

     def backward_by_grad(self, tensor, grad):
-        with self.param_op_hook.switch_to_backward(), ParamOpHookManager.use_hooks(self.param_op_hook):
+        with self.param_op_hook.switch_to_backward(), ColoParamOpHookManager.use_hooks(self.param_op_hook):
             torch.autograd.backward(tensor, grad)
         self._post_backward()
colossalai/tensor/__init__.py

@@ -5,13 +5,14 @@ from .comm_spec import CollectiveCommPattern, CommSpec
 from .compute_spec import ComputePattern, ComputeSpec
 from .dist_spec_mgr import DistSpecManager
 from .distspec import ReplicaSpec, ShardSpec
-from .param_op_hook import ParamOpHook, ParamOpHookManager
+from .param_op_hook import ColoParamOpHook, ColoParamOpHookManager
 from .process_group import ProcessGroup
 from .tensor_spec import ColoTensorSpec
 from .utils import convert_dim_partition_dict, convert_parameter, merge_same_dim_mesh_list, named_params_with_colotensor

 __all__ = [
     'ColoTensor', 'convert_parameter', 'ComputePattern', 'ComputeSpec', 'named_params_with_colotensor', 'ColoParameter',
-    'distspec', 'DistSpecManager', 'ParamOpHook', 'ParamOpHookManager', 'ProcessGroup', 'ColoTensorSpec', 'ShardSpec',
-    'ReplicaSpec', 'CommSpec', 'CollectiveCommPattern', 'convert_dim_partition_dict', 'merge_same_dim_mesh_list'
+    'distspec', 'DistSpecManager', 'ColoParamOpHook', 'ColoParamOpHookManager', 'ProcessGroup', 'ColoTensorSpec',
+    'ShardSpec', 'ReplicaSpec', 'CommSpec', 'CollectiveCommPattern', 'convert_dim_partition_dict',
+    'merge_same_dim_mesh_list'
 ]
colossalai/tensor/colo_parameter.py

@@ -4,7 +4,7 @@ import torch
 from colossalai.tensor.colo_tensor import ColoTensor
 from colossalai.tensor.const import TensorType
-from colossalai.tensor.param_op_hook import ParamOpHookManager
+from colossalai.tensor.param_op_hook import ColoParamOpHookManager
 from colossalai.tensor.tensor_spec import ColoTensorSpec

@@ -58,18 +58,18 @@ class ColoParameter(ColoTensor, torch.nn.Parameter):
     @classmethod
     def __torch_function__(cls, func, types, args=..., kwargs=None):
-        if ParamOpHookManager.has_hook():
+        if ColoParamOpHookManager.has_hook():
             if not func.__name__.startswith('__'):
                 if kwargs is None:
                     kwargs = {}
                 params = filter_args(lambda arg: isinstance(arg, ColoParameter), *args, *kwargs.values())
                 if len(params) > 0:
                     with torch._C.DisableTorchFunction():
-                        new_args = ParamOpHookManager.pre_op(params, *args, *kwargs.values())
+                        new_args = ColoParamOpHookManager.pre_op(params, *args, *kwargs.values())
                     args, kwargs = replace_args(args, kwargs, new_args)
                     ret = super().__torch_function__(func, types, args, kwargs)
                     with torch._C.DisableTorchFunction():
-                        ret = ParamOpHookManager.post_op(params, ret)
+                        ret = ColoParamOpHookManager.post_op(params, ret)
                     return ret
         return super().__torch_function__(func, types, args, kwargs)
colossalai/tensor/param_op_hook.py

@@ -8,7 +8,7 @@ from colossalai.tensor.colo_tensor import ColoTensor
 from colossalai.tensor.tensor_spec import ColoTensorSpec

-class ParamOpHook(ABC):
+class ColoParamOpHook(ABC):
     """Hook which is triggered by each operation when operands contain ColoParameter.
     To customize it, you must inherit this abstract class, and implement ``pre_forward``,
     ``post_forward``, ``pre_backward`` and ``post_backward``. These four methods take a list

@@ -32,68 +32,68 @@ class ParamOpHook(ABC):
         pass

-class ParamOpHookManager:
+class ColoParamOpHookManager:
     """Manage your param op hooks. It only has static methods.
     The only static method you should call is ``use_hooks(*hooks)``.
     """
-    hooks: Tuple[ParamOpHook, ...] = tuple()
+    hooks: Tuple[ColoParamOpHook, ...] = tuple()

     @staticmethod
     @contextmanager
-    def use_hooks(*hooks: ParamOpHook):
+    def use_hooks(*hooks: ColoParamOpHook):
         """Change the param op hooks you use. Nested calling is allowed.

         Example:
-            >>> with ParamOpHookManager.use_hooks(*hooks):
+            >>> with ColoParamOpHookManager.use_hooks(*hooks):
             >>>     do_something()
-            >>>     with ParamOpHookManager.use_hooks():
+            >>>     with ColoParamOpHookManager.use_hooks():
             >>>         // clear hooks
             >>>         do_something()
         """
         try:
-            old_param_op_hooks = ParamOpHookManager.hooks
-            ParamOpHookManager.hooks = hooks
+            old_param_op_hooks = ColoParamOpHookManager.hooks
+            ColoParamOpHookManager.hooks = hooks
             yield
         finally:
-            ParamOpHookManager.hooks = old_param_op_hooks
+            ColoParamOpHookManager.hooks = old_param_op_hooks

     @staticmethod
     def _trigger_pre_forward(params: List[torch.Tensor]) -> None:
-        for hook in ParamOpHookManager.hooks:
+        for hook in ColoParamOpHookManager.hooks:
             hook.pre_forward(params)

     @staticmethod
     def _trigger_post_forward(params: List[torch.Tensor]) -> None:
-        for hook in ParamOpHookManager.hooks:
+        for hook in ColoParamOpHookManager.hooks:
             hook.post_forward(params)

     @staticmethod
     def _trigger_pre_backward(params: List[torch.Tensor]) -> None:
-        for hook in ParamOpHookManager.hooks:
+        for hook in ColoParamOpHookManager.hooks:
             hook.pre_backward(params)

     @staticmethod
     def _trigger_post_backward(params: List[torch.Tensor]) -> None:
-        for hook in ParamOpHookManager.hooks:
+        for hook in ColoParamOpHookManager.hooks:
             hook.post_backward(params)

     @staticmethod
     def pre_op(params: List[torch.Tensor], *args: Any) -> list:
-        ParamOpHookManager._trigger_pre_forward(params)
+        ColoParamOpHookManager._trigger_pre_forward(params)
         args_info = _get_colo_tensors_info(*args)
         rets = PreFwdPostBwd.apply(params, *args)
         return _update_colo_tensors(args_info, *rets)

     @staticmethod
     def post_op(params: List[torch.Tensor], arg: Any) -> Any:
-        ParamOpHookManager._trigger_post_forward(params)
+        ColoParamOpHookManager._trigger_post_forward(params)
         arg_info = _get_colo_tensors_info(arg)
         ret = PostFwdPreBwd.apply(params, arg)
         return _unpack_args(_update_colo_tensors(arg_info, ret))

     @staticmethod
     def has_hook() -> bool:
-        return len(ParamOpHookManager.hooks) > 0
+        return len(ColoParamOpHookManager.hooks) > 0

 class PreFwdPostBwd(torch.autograd.Function):

@@ -105,7 +105,7 @@ class PreFwdPostBwd(torch.autograd.Function):
     @staticmethod
     def backward(ctx, *grads):
-        ParamOpHookManager._trigger_post_backward(ctx.params)
+        ColoParamOpHookManager._trigger_post_backward(ctx.params)
         return (None,) + grads

@@ -118,7 +118,7 @@ class PostFwdPreBwd(torch.autograd.Function):
     @staticmethod
     def backward(ctx, *grads):
-        ParamOpHookManager._trigger_pre_backward(ctx.params)
+        ColoParamOpHookManager._trigger_pre_backward(ctx.params)
         return (None,) + grads
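The docstrings above spell out the extension contract that this file keeps after the rename: subclass ColoParamOpHook, implement the four phase methods (each receives the list of ColoParameters involved in the op), and install the hook with ColoParamOpHookManager.use_hooks. A minimal sketch under the renamed API; the CountingHook below is illustrative only and is not part of this commit:

    from typing import List

    import torch

    from colossalai.tensor import ColoParamOpHook, ColoParamOpHookManager


    class CountingHook(ColoParamOpHook):
        """Hypothetical hook: counts how often parameter ops hit each phase."""

        def __init__(self) -> None:
            super().__init__()
            self.counts = {"pre_fwd": 0, "post_fwd": 0, "pre_bwd": 0, "post_bwd": 0}

        def pre_forward(self, params: List[torch.Tensor]) -> None:
            self.counts["pre_fwd"] += 1

        def post_forward(self, params: List[torch.Tensor]) -> None:
            self.counts["post_fwd"] += 1

        def pre_backward(self, params: List[torch.Tensor]) -> None:
            self.counts["pre_bwd"] += 1

        def post_backward(self, params: List[torch.Tensor]) -> None:
            self.counts["post_bwd"] += 1


    # Usage mirrors the use_hooks docstring: hooks are only active inside the context.
    # hook = CountingHook()
    # with ColoParamOpHookManager.use_hooks(hook):
    #     loss = model(inputs).sum()   # ops on ColoParameter trigger pre/post_forward
    #     loss.backward()              # autograd triggers pre/post_backward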
colossalai/zero/utils/gemini_hook.py

@@ -7,7 +7,7 @@ import torch
 from colossalai.gemini import TensorState
 from colossalai.gemini.gemini_mgr import GeminiManager
-from colossalai.tensor.param_op_hook import ParamOpHook
+from colossalai.tensor.param_op_hook import ColoParamOpHook

 class TrainingPhase(Enum):

@@ -15,7 +15,7 @@ class TrainingPhase(Enum):
     BACKWARD = 1

-class GeminiZeROHook(ParamOpHook):
+class GeminiZeROHook(ColoParamOpHook):

     def __init__(self, gemini_manager: GeminiManager) -> None:
         super().__init__()