Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
715b86ea
Unverified
Commit
715b86ea
authored
Apr 11, 2022
by
ver217
Committed by
GitHub
Apr 11, 2022
Browse files
[hotfix] fix stm cuda model data size (#710)
parent
140263a3
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
2 additions
and
2 deletions
+2
-2
colossalai/zero/shard_utils/stateful_tensor_mgr.py
colossalai/zero/shard_utils/stateful_tensor_mgr.py
+2
-2
No files found.
colossalai/zero/shard_utils/stateful_tensor_mgr.py
View file @
715b86ea
...
@@ -6,6 +6,7 @@ from colossalai.zero.sharded_param.sharded_param import ShardedParamV2
...
@@ -6,6 +6,7 @@ from colossalai.zero.sharded_param.sharded_param import ShardedParamV2
from
colossalai.zero.sharded_param.tensorful_state
import
StatefulTensor
,
TensorState
from
colossalai.zero.sharded_param.tensorful_state
import
StatefulTensor
,
TensorState
from
colossalai.zero.shard_utils.tensor_utils
import
colo_model_data_tensor_move_inline
,
colo_tensor_mem_usage
from
colossalai.zero.shard_utils.tensor_utils
import
colo_model_data_tensor_move_inline
,
colo_tensor_mem_usage
from
colossalai.utils.memory_utils.utils
import
colo_cuda_memory_capacity
from
colossalai.utils.memory_utils.utils
import
colo_cuda_memory_capacity
from
colossalai.utils.memory_tracer.model_data_memtracer
import
GLOBAL_MODEL_DATA_TRACER
from
typing
import
Dict
,
List
from
typing
import
Dict
,
List
from
colossalai.utils.memory_tracer
import
MemStatsCollector
from
colossalai.utils.memory_tracer
import
MemStatsCollector
from
colossalai.logging
import
get_dist_logger
from
colossalai.logging
import
get_dist_logger
...
@@ -48,14 +49,13 @@ class StatefulTensorMgr(object):
...
@@ -48,14 +49,13 @@ class StatefulTensorMgr(object):
# find stateful tensor in state COMPUTE
# find stateful tensor in state COMPUTE
move_to_cuda_tensor_list
=
[]
move_to_cuda_tensor_list
=
[]
cuda_demand
=
0
cuda_demand
=
0
used_cuda_model_data
=
0
used_cuda_model_data
=
GLOBAL_MODEL_DATA_TRACER
.
cuda_usage
hold_cuda_tensor_list
=
[]
hold_cuda_tensor_list
=
[]
for
tensor
in
self
.
_stateful_tensor_list
:
for
tensor
in
self
.
_stateful_tensor_list
:
if
tensor
.
state
==
TensorState
.
FREE
:
if
tensor
.
state
==
TensorState
.
FREE
:
continue
continue
if
tensor
.
device
.
type
==
'cuda'
:
if
tensor
.
device
.
type
==
'cuda'
:
used_cuda_model_data
+=
colo_tensor_mem_usage
(
tensor
.
payload
)[
0
]
if
tensor
.
state
in
[
TensorState
.
HOLD
,
TensorState
.
HOLD_AFTER_BWD
,
TensorState
.
HOLD_AFTER_FWD
]:
if
tensor
.
state
in
[
TensorState
.
HOLD
,
TensorState
.
HOLD_AFTER_BWD
,
TensorState
.
HOLD_AFTER_FWD
]:
hold_cuda_tensor_list
.
append
(
tensor
)
hold_cuda_tensor_list
.
append
(
tensor
)
elif
tensor
.
device
.
type
==
'cpu'
:
elif
tensor
.
device
.
type
==
'cpu'
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment