OpenDAS / ColossalAI / Commits

Unverified commit 425b4a96, authored Apr 26, 2022 by HELSON, committed by GitHub on Apr 26, 2022

[gemini] polish stateful_tensor_mgr (#876)

Parent: e43f83aa
Showing 3 changed files with 22 additions and 21 deletions:

- colossalai/gemini/stateful_tensor_mgr.py (+16, -17)
- colossalai/zero/sharded_model/sharded_model_v2.py (+5, -3)
- colossalai/zero/utils/zero_hook.py (+1, -1)
colossalai/gemini/stateful_tensor_mgr.py (view file @ 425b4a96)
```diff
@@ -6,7 +6,6 @@ from colossalai.gemini.tensor_utils import colo_model_data_tensor_move_inline, c
 from colossalai.gemini.stateful_tensor import StatefulTensor, TensorState
 from colossalai.gemini.tensor_placement_policy import TensorPlacementPolicy
 from typing import List
-from colossalai.logging import get_dist_logger
 
 
 class StatefulTensorMgr(object):
```
```diff
@@ -20,23 +19,30 @@ class StatefulTensorMgr(object):
     def __init__(self, tensor_placement_policy: TensorPlacementPolicy) -> None:
         self._tensor_placement_policy: TensorPlacementPolicy = tensor_placement_policy
         self._stateful_tensor_list: List[StatefulTensor] = []
-        self._logger = get_dist_logger("StatefulTensorMgr")
-
-        self._warmup = True
-
         self._compute_list: List[StatefulTensor] = []
         self._compute_idx: int = -1
+
         self._cpu_gpu_move_volume = 0
+        self._warmup = True
 
-    def register_stateful_param(self, param) -> None:
-        from colossalai.zero.sharded_param.sharded_param import ShardedParamV2
-        assert isinstance(param, ShardedParamV2)
-        for t in param.get_payload_tensors():
+    def register_stateful_tensor_list(self, tensor_list: List[StatefulTensor]) -> None:
+        assert self._stateful_tensor_list == [], "Can't register stateful tensors for manager twice"
+        self._stateful_tensor_list = tensor_list
+        for t in self._stateful_tensor_list:
             assert isinstance(t, StatefulTensor)
-            self._stateful_tensor_list.append(t)
             t.trans_state = types.MethodType(functools.partial(self._trans_state, t.trans_state), t)
 
+    def start_iter(self):
+        pass
+
+    def finish_iter(self):
+        """This function must be called when each iteration finishes
+        """
+        self._warmup = False
+        self._compute_idx = -1
+        self._cpu_gpu_move_volume = 0
+
     def adjust_layout(self) -> None:
         """ Adjust the layout of statefuil tensor according to the information provided
         by mem_stats_collector, which should belongs to a Sharded Model.
```
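Note: the `types.MethodType(functools.partial(...))` line above survives the rename and is the core trick in this hunk. Each tensor's `trans_state` is replaced by a wrapper that notifies the manager before delegating to the original bound method. Below is a minimal, self-contained sketch of that mechanism; the stub classes are hypothetical stand-ins, not ColossalAI API:

```python
import functools
import types


class StubTensor:
    """Hypothetical stand-in for StatefulTensor."""

    def __init__(self, name: str) -> None:
        self.name = name

    def trans_state(self, state: str) -> None:
        print(f"{self.name}: state -> {state}")


class StubMgr:
    """Sketch of the trans_state interception used by StatefulTensorMgr."""

    def register_stateful_tensor_list(self, tensor_list) -> None:
        for t in tensor_list:
            # Capture the original bound method via functools.partial, then
            # rebind trans_state so every later call routes through the
            # manager first -- the same pattern as the diff above.
            t.trans_state = types.MethodType(functools.partial(self._trans_state, t.trans_state), t)

    def _trans_state(self, trans_state_func, stateful_tensor, state) -> None:
        print(f"manager observed {stateful_tensor.name} -> {state}")
        trans_state_func(state)  # delegate to the original method


t = StubTensor("param0")
StubMgr().register_stateful_tensor_list([t])
t.trans_state("COMPUTE")  # prints the manager line, then the tensor line
```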
```diff
@@ -63,21 +69,14 @@ class StatefulTensorMgr(object):
             compute_list=self._compute_list,
             compute_idx=self._compute_idx)
         # move COMPUTE tensors to CUDA
+        self._cpu_gpu_move_volume += cuda_demand
         for t in move_to_cuda_tensor_list:
             colo_model_data_tensor_move_inline(t, get_current_device())
-            self._cpu_gpu_move_volume += t.payload_size
 
     @property
     def cpu_gpu_move_volume(self):
         return self._cpu_gpu_move_volume
 
-    def reset(self):
-        """This function must be called when each iteration finishes
-        """
-        self._warmup = False
-        self._compute_idx = -1
-        self._cpu_gpu_move_volume = 0
-
     def _trans_state(self, trans_state_func, stateful_tensor, state):
         trans_state_func(state)
         if state == TensorState.COMPUTE:
```
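Reviewer note on the accounting change in this hunk: the moved volume is now credited once from `cuda_demand`, before the loop, instead of summing `t.payload_size` per moved tensor. Assuming the placement policy computes `cuda_demand` as the total payload of `move_to_cuda_tensor_list` (an assumption, not verified here), the two are equivalent; a toy check:

```python
# Hypothetical payload sizes in bytes; cuda_demand is assumed to equal the
# total payload of the tensors scheduled for the CPU -> CUDA move.
payload_sizes = [4 * 1024, 8 * 1024, 2 * 1024]
cuda_demand = sum(payload_sizes)

old_volume = sum(payload_sizes)  # old code: accumulate per moved tensor
new_volume = cuda_demand         # new code: add the demand once, up front
assert old_volume == new_volume == 14 * 1024
```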
colossalai/zero/sharded_model/sharded_model_v2.py (view file @ 425b4a96)
```diff
@@ -111,10 +111,10 @@ class ShardedModelV2(nn.Module):
             self._memstats_collector = None
         self._tensor_placement_policy: TensorPlacementPolicy = TensorPlacementPolicyFactory.create(
             tensor_placement_policy)(mem_stats_collector=self._memstats_collector)
         self._stateful_tensor_mgr = StatefulTensorMgr(self._tensor_placement_policy)
-        for param in module.parameters():
-            if hasattr(param, 'colo_attr'):
-                self._stateful_tensor_mgr.register_stateful_param(param.colo_attr)
+        param_tensor_list = [p.colo_attr.sharded_data_tensor for p in module.parameters() if hasattr(p, 'colo_attr')]
+        self._stateful_tensor_mgr.register_stateful_tensor_list(param_tensor_list)
 
         # Register hooks
         self._ophook_list = [
```
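The three-line registration loop collapses into a filtered list comprehension plus a single call to the new `register_stateful_tensor_list`. A self-contained sketch of the filtering pattern, with hypothetical stand-ins for parameters and `colo_attr`:

```python
class _ColoAttr:
    """Hypothetical stand-in for a parameter's colo_attr."""

    def __init__(self, tensor):
        self.sharded_data_tensor = tensor


class _Param:
    """Hypothetical parameter; only some instances carry colo_attr."""

    def __init__(self, tensor=None):
        if tensor is not None:
            self.colo_attr = _ColoAttr(tensor)


params = [_Param("t0"), _Param(), _Param("t1")]  # middle one has no colo_attr
param_tensor_list = [p.colo_attr.sharded_data_tensor for p in params if hasattr(p, 'colo_attr')]
assert param_tensor_list == ["t0", "t1"]  # the attribute check filters in one pass
```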
@@ -198,6 +198,8 @@ class ShardedModelV2(nn.Module):
...
@@ -198,6 +198,8 @@ class ShardedModelV2(nn.Module):
if
hasattr
(
p
,
'colo_attr'
):
if
hasattr
(
p
,
'colo_attr'
):
p
.
colo_attr
.
sharded_data_tensor
.
trans_state
(
TensorState
.
HOLD
)
p
.
colo_attr
.
sharded_data_tensor
.
trans_state
(
TensorState
.
HOLD
)
self
.
_stateful_tensor_mgr
.
start_iter
()
def
_post_forward_operations
(
self
):
def
_post_forward_operations
(
self
):
for
p
in
self
.
module
.
parameters
():
for
p
in
self
.
module
.
parameters
():
if
hasattr
(
p
,
'colo_attr'
):
if
hasattr
(
p
,
'colo_attr'
):
...
...
colossalai/zero/utils/zero_hook.py (view file @ 425b4a96)
```diff
@@ -115,4 +115,4 @@ class ZeroHook(BaseOpHook):
         if self._stateful_tensor_mgr:
             self.logger.info(
                 f"CPU-GPU data moving this iteration {self._stateful_tensor_mgr.cpu_gpu_move_volume / 1e9} GB", ranks=[0])
-            self._stateful_tensor_mgr.reset()
+            self._stateful_tensor_mgr.finish_iter()
```
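Taken together, the three files rename the per-iteration cleanup: `reset()` becomes `finish_iter()`, with a matching `start_iter()` call added in `_pre_forward_operations`. A hedged sketch of the resulting lifecycle, using a stub manager only (the real `adjust_layout` work is elided):

```python
class StubStatefulTensorMgr:
    """Stub mirroring the iteration hooks from this diff; not ColossalAI API."""

    def __init__(self) -> None:
        self._warmup = True
        self._compute_idx = -1
        self._cpu_gpu_move_volume = 0

    def start_iter(self) -> None:
        pass  # reserved hook, called before the forward pass

    def finish_iter(self) -> None:
        # Formerly reset(): leave warmup after the first full iteration
        # and clear the per-iteration counters.
        self._warmup = False
        self._compute_idx = -1
        self._cpu_gpu_move_volume = 0

    @property
    def cpu_gpu_move_volume(self):
        return self._cpu_gpu_move_volume


mgr = StubStatefulTensorMgr()
for _ in range(2):
    mgr.start_iter()
    mgr._cpu_gpu_move_volume += 1024**3  # pretend adjust_layout moved 1 GiB
    print(f"CPU-GPU data moving this iteration {mgr.cpu_gpu_move_volume / 1e9} GB")
    mgr.finish_iter()
assert mgr._warmup is False
```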