Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
705f5610
Unverified
Commit
705f5610
authored
Mar 28, 2022
by
Jiarui Fang
Committed by
GitHub
Mar 28, 2022
Browse files
[zero] refactor model data tracing (#537)
parent
a590ed0b
Changes
13
Show whitespace changes
Inline
Side-by-side
Showing
13 changed files
with
98 additions
and
132 deletions
+98
-132
colossalai/engine/ophooks/zero_hook.py
colossalai/engine/ophooks/zero_hook.py
+0
-2
colossalai/utils/memory_tracer/memstats_collector.py
colossalai/utils/memory_tracer/memstats_collector.py
+16
-1
colossalai/utils/memory_tracer/model_data_memtracer.py
colossalai/utils/memory_tracer/model_data_memtracer.py
+60
-41
colossalai/utils/memory_tracer/test_memstats_collector.py
colossalai/utils/memory_tracer/test_memstats_collector.py
+1
-4
colossalai/utils/memory_utils/utils.py
colossalai/utils/memory_utils/utils.py
+0
-10
colossalai/zero/init_ctx/init_context.py
colossalai/zero/init_ctx/init_context.py
+0
-12
colossalai/zero/shard_utils/bucket_tensor_shard_strategy.py
colossalai/zero/shard_utils/bucket_tensor_shard_strategy.py
+0
-6
colossalai/zero/shard_utils/tensor_shard_strategy.py
colossalai/zero/shard_utils/tensor_shard_strategy.py
+0
-5
colossalai/zero/sharded_model/sharded_model_v2.py
colossalai/zero/sharded_model/sharded_model_v2.py
+6
-2
colossalai/zero/sharded_optim/sharded_optim_v2.py
colossalai/zero/sharded_optim/sharded_optim_v2.py
+0
-4
tests/test_utils/test_commons.py
tests/test_utils/test_commons.py
+0
-12
tests/test_utils/test_tensor_move.py
tests/test_utils/test_tensor_move.py
+1
-25
tests/test_zero_data_parallel/test_init_context.py
tests/test_zero_data_parallel/test_init_context.py
+14
-8
No files found.
colossalai/engine/ophooks/zero_hook.py
View file @
705f5610
...
@@ -5,8 +5,6 @@ import torch.distributed as dist
...
@@ -5,8 +5,6 @@ import torch.distributed as dist
from
colossalai.registry
import
OPHOOKS
from
colossalai.registry
import
OPHOOKS
from
colossalai.utils
import
get_current_device
from
colossalai.utils
import
get_current_device
from
colossalai.utils.memory_tracer.memstats_collector
import
MemStatsCollector
from
colossalai.utils.memory_tracer.memstats_collector
import
MemStatsCollector
from
colossalai.utils.memory_tracer.model_data_memtracer
import
\
GLOBAL_MODEL_DATA_TRACER
from
colossalai.zero.shard_utils
import
BaseShardStrategy
from
colossalai.zero.shard_utils
import
BaseShardStrategy
from
._base_ophook
import
BaseOpHook
from
._base_ophook
import
BaseOpHook
...
...
colossalai/utils/memory_tracer/memstats_collector.py
View file @
705f5610
...
@@ -3,6 +3,7 @@ from colossalai.utils.memory_utils.memory_monitor import colo_cuda_memory_used
...
@@ -3,6 +3,7 @@ from colossalai.utils.memory_utils.memory_monitor import colo_cuda_memory_used
from
colossalai.utils
import
get_current_device
from
colossalai.utils
import
get_current_device
import
torch
import
torch
from
typing
import
Tuple
class
SamplingCounter
:
class
SamplingCounter
:
...
@@ -40,6 +41,20 @@ class MemStatsCollector:
...
@@ -40,6 +41,20 @@ class MemStatsCollector:
self
.
_start_flag
=
False
self
.
_start_flag
=
False
@
property
def
overall_cuda
(
self
):
return
self
.
_overall_cuda
@
property
def
model_data_cuda
(
self
):
return
self
.
_model_data_cuda
@
property
def
non_model_data_cuda
(
self
):
"""Non model data stats
"""
return
[(
v1
-
v2
)
for
v1
,
v2
in
zip
(
self
.
_overall_cuda
,
self
.
_model_data_cuda
)]
def
start_collection
(
self
):
def
start_collection
(
self
):
self
.
_start_flag
=
True
self
.
_start_flag
=
True
...
@@ -58,7 +73,7 @@ class MemStatsCollector:
...
@@ -58,7 +73,7 @@ class MemStatsCollector:
self
.
_overall_cuda
.
append
(
colo_cuda_memory_used
(
torch
.
device
(
f
'cuda:
{
get_current_device
()
}
'
)))
self
.
_overall_cuda
.
append
(
colo_cuda_memory_used
(
torch
.
device
(
f
'cuda:
{
get_current_device
()
}
'
)))
self
.
_sampling_cnter
.
advance
()
self
.
_sampling_cnter
.
advance
()
def
fetch_memstats
(
self
)
->
(
int
,
int
)
:
def
fetch_memstats
(
self
)
->
Tuple
[
int
,
int
]
:
"""
"""
returns cuda usage of model data and overall cuda usage.
returns cuda usage of model data and overall cuda usage.
"""
"""
...
...
colossalai/utils/memory_tracer/model_data_memtracer.py
View file @
705f5610
from
colossalai.context.singleton_meta
import
SingletonMeta
from
colossalai.context.singleton_meta
import
SingletonMeta
from
colossalai.zero.sharded_param.sharded_tensor
import
ShardedTensor
from
colossalai.zero.sharded_param.sharded_tensor
import
ShardedTensor
import
torch
import
torch
from
typing
import
Union
from
typing
import
Union
,
Tuple
,
Optional
from
colossalai.logging
import
DistributedLogger
def
_col_tensor_mem_usage
(
t
:
Union
[
torch
.
Tensor
,
ShardedTensor
])
->
int
:
def
_col_tensor_mem_usage
(
t
:
Union
[
torch
.
Tensor
,
ShardedTensor
])
->
int
:
...
@@ -12,60 +13,78 @@ def _col_tensor_mem_usage(t: Union[torch.Tensor, ShardedTensor]) -> int:
...
@@ -12,60 +13,78 @@ def _col_tensor_mem_usage(t: Union[torch.Tensor, ShardedTensor]) -> int:
return
target
.
numel
()
*
target
.
element_size
()
return
target
.
numel
()
*
target
.
element_size
()
class
M
odel
D
ata
Tracer
(
metaclass
=
SingletonMeta
)
:
def
col_m
odel
_d
ata
_mem_usage
(
model
:
torch
.
nn
.
Module
)
->
Tuple
[
int
,
int
]
:
"""
"""
A tracer singleton to trace model data usage during runtime.
Trace the model memory usage.
The tracer is designed to trace the memory layout change during model-data tensors allocation, releasing, and moving.
Args:
To achieve this goal, the developers have to call `ModelDataTracer` in the corresponding code explicitly.
model (torch.nn.Module): a torch model
NOTE() now the class only trace cuda memory usage
Returns:
Tuple[int, int]: cuda memory usage in Byte, cpu memory usage in Byte
"""
"""
def
__init__
(
self
)
->
None
:
def
_get_tensor_mem_use
(
t
:
Optional
[
torch
.
Tensor
]):
self
.
_cuda_usage
=
0
if
t
is
None
:
self
.
_cpu_usage
=
0
return
self
.
_start_flag
=
False
assert
isinstance
(
t
,
torch
.
Tensor
)
_cpu_mem_usage
,
_cuda_mem_usage
=
0
,
0
if
t
.
device
.
type
==
'cpu'
:
_cpu_mem_usage
+=
t
.
numel
()
*
t
.
element_size
()
elif
t
.
device
.
type
==
'cuda'
:
_cuda_mem_usages
+=
t
.
numel
()
*
t
.
element_size
()
return
_cuda_mem_usage
,
_cpu_mem_usage
def
start
(
self
)
->
None
:
cuda_mem_usage
=
0
self
.
_start_flag
=
True
cpu_mem_usage
=
0
for
param
in
model
.
parameters
():
if
hasattr
(
param
,
'col_attr'
):
para_cuda
,
param_cpu
=
param
.
col_attr
.
get_memory_usage
()
cuda_mem_usage
+=
para_cuda
cpu_mem_usage
+=
param_cpu
else
:
t_cuda
,
t_cpu
=
_get_tensor_mem_use
(
param
.
data
)
cuda_mem_usage
+=
t_cuda
cpu_mem_usage
+=
t_cpu
t_cuda
,
t_cpu
=
_get_tensor_mem_use
(
param
.
grad
)
cuda_mem_usage
+=
t_cuda
cpu_mem_usage
+=
t_cpu
def
close
(
self
)
->
None
:
return
cuda_mem_usage
,
cpu_mem_usage
self
.
_start_flag
=
False
def
add_tensor
(
self
,
t
:
Union
[
torch
.
Tensor
,
ShardedTensor
])
->
None
:
if
not
self
.
_start_flag
:
return
t_payload
=
t
.
payload
if
isinstance
(
t
,
ShardedTensor
)
else
t
mem_use
=
_col_tensor_mem_usage
(
t_payload
)
if
t_payload
.
device
.
type
==
'cuda'
:
self
.
_cuda_usage
+=
mem_use
elif
t_payload
.
device
.
type
==
'cpu'
:
self
.
_cpu_usage
+=
mem_use
else
:
raise
TypeError
def
delete_tensor
(
self
,
t
:
Union
[
torch
.
Tensor
,
ShardedTensor
])
->
None
:
class
ModelDataTracer
(
metaclass
=
SingletonMeta
):
if
not
self
.
_start_flag
:
"""
return
A tracer singleton to trace model data usage during runtime.
t_payload
=
t
.
payload
if
isinstance
(
t
,
ShardedTensor
)
else
t
You have to register a model on the singleton first.
mem_use
=
_col_tensor_mem_usage
(
t_payload
)
"""
if
t_payload
.
device
.
type
==
'cuda'
:
self
.
_cuda_usage
-=
mem_use
def
__init__
(
self
)
->
None
:
elif
t_payload
.
device
.
type
==
'cpu'
:
self
.
_logger
=
DistributedLogger
(
"ModelDataTracer"
)
self
.
_cpu_usage
-=
mem_use
self
.
_model
=
None
else
:
raise
TypeError
def
_get_mem_usage
(
self
)
->
Tuple
[
int
,
int
]:
"""
get the memory usage of the model registered.
Returns:
Tuple[int, int]: cuda, cpu mem usage
"""
if
self
.
_model
is
None
:
self
.
_logger
.
warning
(
"The Global ModelDataTracer is using, but no model is registered on it."
)
return
0
,
0
return
col_model_data_mem_usage
(
self
.
_model
)
def
clear
(
self
)
->
None
:
def
register_model
(
self
,
model
)
->
None
:
self
.
_cuda_usage
=
0
self
.
_model
=
model
self
.
_cpu_usage
=
0
@
property
@
property
def
cpu_usage
(
self
):
def
cpu_usage
(
self
):
return
self
.
_cpu_usage
_
,
cpu_usage
=
self
.
_get_mem_usage
()
return
cpu_usage
@
property
@
property
def
cuda_usage
(
self
):
def
cuda_usage
(
self
):
return
self
.
_cuda_usage
cuda_usage
,
_
=
self
.
_get_mem_usage
()
return
cuda_usage
GLOBAL_MODEL_DATA_TRACER
=
ModelDataTracer
()
GLOBAL_MODEL_DATA_TRACER
=
ModelDataTracer
()
colossalai/utils/memory_tracer/test_memstats_collector.py
View file @
705f5610
from
colossalai.utils.memory_tracer.memstats_collector
import
MemStatsCollector
from
colossalai.utils.memory_tracer.memstats_collector
import
MemStatsCollector
from
colossalai.utils.memory_tracer.model_data_memtracer
import
GLOBAL_MODEL_DATA_TRACER
import
torch
import
torch
...
@@ -14,7 +13,6 @@ def test_mem_collector():
...
@@ -14,7 +13,6 @@ def test_mem_collector():
collector
.
sample_memstats
()
collector
.
sample_memstats
()
m_a
=
torch
.
randn
(
10
).
cuda
()
m_a
=
torch
.
randn
(
10
).
cuda
()
GLOBAL_MODEL_DATA_TRACER
.
add_tensor
(
m_a
)
b
=
torch
.
randn
(
10
).
cuda
()
b
=
torch
.
randn
(
10
).
cuda
()
# sampling at time 1
# sampling at time 1
...
@@ -35,8 +33,7 @@ def test_mem_collector():
...
@@ -35,8 +33,7 @@ def test_mem_collector():
cuda_use
,
overall_use
=
collector
.
fetch_memstats
()
cuda_use
,
overall_use
=
collector
.
fetch_memstats
()
print
(
cuda_use
,
overall_use
)
print
(
cuda_use
,
overall_use
)
print
(
collector
.
_model_data_cuda
)
print
(
collector
.
overall_cuda
)
print
(
collector
.
_overall_cuda
)
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
...
...
colossalai/utils/memory_utils/utils.py
View file @
705f5610
import
torch
import
torch
from
colossalai.utils
import
get_current_device
from
colossalai.utils
import
get_current_device
from
colossalai.zero.sharded_param.sharded_tensor
import
ShardedTensor
from
colossalai.zero.sharded_param.sharded_tensor
import
ShardedTensor
from
colossalai.utils.memory_tracer.model_data_memtracer
import
GLOBAL_MODEL_DATA_TRACER
from
typing
import
Union
from
typing
import
Union
...
@@ -52,9 +51,7 @@ def colo_model_data_tensor_move(src_t: Union[ShardedTensor, torch.Tensor], tgt_t
...
@@ -52,9 +51,7 @@ def colo_model_data_tensor_move(src_t: Union[ShardedTensor, torch.Tensor], tgt_t
tgt_t_payload
=
tgt_t
.
data
tgt_t_payload
=
tgt_t
.
data
tgt_dev
=
tgt_t_payload
.
device
tgt_dev
=
tgt_t_payload
.
device
GLOBAL_MODEL_DATA_TRACER
.
delete_tensor
(
src_t_payload
)
tgt_t_payload
.
copy_
(
src_t_payload
)
tgt_t_payload
.
copy_
(
src_t_payload
)
GLOBAL_MODEL_DATA_TRACER
.
add_tensor
(
tgt_t_payload
)
# remove payload of src_t
# remove payload of src_t
if
isinstance
(
src_t
,
ShardedTensor
):
if
isinstance
(
src_t
,
ShardedTensor
):
...
@@ -84,11 +81,7 @@ def colo_model_data_tensor_move_inline(t: Union[ShardedTensor, torch.Tensor],
...
@@ -84,11 +81,7 @@ def colo_model_data_tensor_move_inline(t: Union[ShardedTensor, torch.Tensor],
# deal with torch.device('cpu') and torch.device('cpu:0)
# deal with torch.device('cpu') and torch.device('cpu:0)
if
t_payload
.
device
.
type
==
target_device
.
type
:
if
t_payload
.
device
.
type
==
target_device
.
type
:
return
return
if
use_tracer
:
GLOBAL_MODEL_DATA_TRACER
.
delete_tensor
(
t_payload
)
t_payload
.
data
=
t_payload
.
data
.
to
(
target_device
)
t_payload
.
data
=
t_payload
.
data
.
to
(
target_device
)
if
use_tracer
:
GLOBAL_MODEL_DATA_TRACER
.
add_tensor
(
t_payload
)
def
colo_model_data_move_to_cpu
(
t
:
Union
[
ShardedTensor
,
torch
.
Tensor
])
->
None
:
def
colo_model_data_move_to_cpu
(
t
:
Union
[
ShardedTensor
,
torch
.
Tensor
])
->
None
:
...
@@ -111,9 +104,7 @@ def colo_model_data_move_to_cpu(t: Union[ShardedTensor, torch.Tensor]) -> None:
...
@@ -111,9 +104,7 @@ def colo_model_data_move_to_cpu(t: Union[ShardedTensor, torch.Tensor]) -> None:
return
return
# TODO() optimize the tensor moving with non-blocking
# TODO() optimize the tensor moving with non-blocking
GLOBAL_MODEL_DATA_TRACER
.
delete_tensor
(
t_payload
)
t_payload
.
data
=
t_payload
.
data
.
cpu
()
t_payload
.
data
=
t_payload
.
data
.
cpu
()
GLOBAL_MODEL_DATA_TRACER
.
add_tensor
(
t_payload
)
def
colo_model_tensor_clone
(
t
:
Union
[
ShardedTensor
,
torch
.
Tensor
],
target_device
:
torch
.
device
)
->
torch
.
Tensor
:
def
colo_model_tensor_clone
(
t
:
Union
[
ShardedTensor
,
torch
.
Tensor
],
target_device
:
torch
.
device
)
->
torch
.
Tensor
:
...
@@ -129,5 +120,4 @@ def colo_model_tensor_clone(t: Union[ShardedTensor, torch.Tensor], target_device
...
@@ -129,5 +120,4 @@ def colo_model_tensor_clone(t: Union[ShardedTensor, torch.Tensor], target_device
t_payload
=
t
.
payload
if
isinstance
(
t
,
ShardedTensor
)
else
t
t_payload
=
t
.
payload
if
isinstance
(
t
,
ShardedTensor
)
else
t
ret
=
t_payload
.
to
(
target_device
)
ret
=
t_payload
.
to
(
target_device
)
GLOBAL_MODEL_DATA_TRACER
.
add_tensor
(
ret
)
return
ret
return
ret
colossalai/zero/init_ctx/init_context.py
View file @
705f5610
...
@@ -4,8 +4,6 @@ from typing import Optional
...
@@ -4,8 +4,6 @@ from typing import Optional
import
torch
import
torch
from
colossalai.context.parallel_mode
import
ParallelMode
from
colossalai.context.parallel_mode
import
ParallelMode
from
colossalai.core
import
global_context
as
gpc
from
colossalai.core
import
global_context
as
gpc
from
colossalai.utils.memory_tracer.model_data_memtracer
import
\
GLOBAL_MODEL_DATA_TRACER
from
colossalai.utils.memory_utils.memory_monitor
import
colo_cuda_memory_used
from
colossalai.utils.memory_utils.memory_monitor
import
colo_cuda_memory_used
from
colossalai.zero.shard_utils
import
BaseShardStrategy
from
colossalai.zero.shard_utils
import
BaseShardStrategy
from
colossalai.zero.sharded_model._utils
import
cast_tensor_to_fp16
from
colossalai.zero.sharded_model._utils
import
cast_tensor_to_fp16
...
@@ -130,7 +128,6 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
...
@@ -130,7 +128,6 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
The Callback function when entering the context
The Callback function when entering the context
"""
"""
self
.
logger
=
get_dist_logger
(
"ZeroInitContext"
)
self
.
logger
=
get_dist_logger
(
"ZeroInitContext"
)
GLOBAL_MODEL_DATA_TRACER
.
start
()
def
_post_context_exec
(
self
):
def
_post_context_exec
(
self
):
"""The callback function when exiting context.
"""The callback function when exiting context.
...
@@ -141,12 +138,6 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
...
@@ -141,12 +138,6 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
param
.
col_attr
.
remove_torch_payload
()
param
.
col_attr
.
remove_torch_payload
()
del
self
.
initialized_param_list
del
self
.
initialized_param_list
GLOBAL_MODEL_DATA_TRACER
.
close
()
model_data_cuda_mem_MB
=
GLOBAL_MODEL_DATA_TRACER
.
cuda_usage
/
1e6
self
.
logger
.
info
(
f
"Existing ZeRO Context.
\n
Model Data CUDA Memory
{
model_data_cuda_mem_MB
}
MB"
,
ranks
=
[
0
])
sys_cuda_mem_MB
=
colo_cuda_memory_used
()
/
1e6
self
.
logger
.
info
(
f
"System CUDA Memory Usage
{
sys_cuda_mem_MB
}
MB"
,
ranks
=
[
0
])
self
.
logger
.
info
(
f
"Model Number Parameter
{
self
.
model_numel_tensor
.
numpy
()[
0
]
/
1e6
}
M"
,
ranks
=
[
0
])
def
_post_init_method
(
self
,
module
:
torch
.
nn
.
Module
):
def
_post_init_method
(
self
,
module
:
torch
.
nn
.
Module
):
"""
"""
...
@@ -176,9 +167,6 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
...
@@ -176,9 +167,6 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
param
.
col_attr
=
ShardedParamV2
(
param
,
rm_torch_payload
=
self
.
rm_torch_payload_on_the_fly
)
param
.
col_attr
=
ShardedParamV2
(
param
,
rm_torch_payload
=
self
.
rm_torch_payload_on_the_fly
)
self
.
initialized_param_list
.
append
(
param
)
self
.
initialized_param_list
.
append
(
param
)
GLOBAL_MODEL_DATA_TRACER
.
add_tensor
(
param
.
col_attr
.
sharded_data_tensor
)
if
self
.
shard_param
:
if
self
.
shard_param
:
self
.
shard_strategy
.
shard
([
param
.
col_attr
.
sharded_data_tensor
],
self
.
dp_process_group
)
self
.
shard_strategy
.
shard
([
param
.
col_attr
.
sharded_data_tensor
],
self
.
dp_process_group
)
...
...
colossalai/zero/shard_utils/bucket_tensor_shard_strategy.py
View file @
705f5610
...
@@ -7,7 +7,6 @@ from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor
...
@@ -7,7 +7,6 @@ from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor
from
torch._utils
import
_flatten_dense_tensors
as
flatten
from
torch._utils
import
_flatten_dense_tensors
as
flatten
from
.tensor_shard_strategy
import
TensorShardStrategy
from
.tensor_shard_strategy
import
TensorShardStrategy
from
colossalai.utils.memory_tracer.model_data_memtracer
import
GLOBAL_MODEL_DATA_TRACER
class
BucketTensorShardStrategy
(
TensorShardStrategy
):
class
BucketTensorShardStrategy
(
TensorShardStrategy
):
...
@@ -18,8 +17,6 @@ class BucketTensorShardStrategy(TensorShardStrategy):
...
@@ -18,8 +17,6 @@ class BucketTensorShardStrategy(TensorShardStrategy):
"""
"""
def
gather
(
self
,
tensor_list
:
List
[
ShardedTensor
],
process_group
:
Optional
[
dist
.
ProcessGroup
]
=
None
):
def
gather
(
self
,
tensor_list
:
List
[
ShardedTensor
],
process_group
:
Optional
[
dist
.
ProcessGroup
]
=
None
):
for
t
in
tensor_list
:
GLOBAL_MODEL_DATA_TRACER
.
delete_tensor
(
t
)
tensor_list
:
List
[
ShardedTensor
]
=
[
t
for
t
in
tensor_list
if
t
.
is_sharded
]
tensor_list
:
List
[
ShardedTensor
]
=
[
t
for
t
in
tensor_list
if
t
.
is_sharded
]
if
len
(
tensor_list
)
==
0
:
if
len
(
tensor_list
)
==
0
:
...
@@ -50,6 +47,3 @@ class BucketTensorShardStrategy(TensorShardStrategy):
...
@@ -50,6 +47,3 @@ class BucketTensorShardStrategy(TensorShardStrategy):
t
.
reset_payload
(
gathered_payload
)
t
.
reset_payload
(
gathered_payload
)
t
.
is_sharded
=
False
t
.
is_sharded
=
False
offset
+=
tensor_numels
[
i
]
offset
+=
tensor_numels
[
i
]
for
t
in
tensor_list
:
GLOBAL_MODEL_DATA_TRACER
.
add_tensor
(
t
)
colossalai/zero/shard_utils/tensor_shard_strategy.py
View file @
705f5610
...
@@ -7,7 +7,6 @@ from colossalai.utils.memory_utils.utils import colo_model_data_tensor_move, col
...
@@ -7,7 +7,6 @@ from colossalai.utils.memory_utils.utils import colo_model_data_tensor_move, col
from
colossalai.zero.shard_utils
import
BaseShardStrategy
from
colossalai.zero.shard_utils
import
BaseShardStrategy
from
colossalai.zero.shard_utils.commons
import
get_shard
from
colossalai.zero.shard_utils.commons
import
get_shard
from
colossalai.zero.sharded_param.sharded_tensor
import
ShardedTensor
from
colossalai.zero.sharded_param.sharded_tensor
import
ShardedTensor
from
colossalai.utils.memory_tracer.model_data_memtracer
import
GLOBAL_MODEL_DATA_TRACER
class
TensorShardStrategy
(
BaseShardStrategy
):
class
TensorShardStrategy
(
BaseShardStrategy
):
...
@@ -36,10 +35,8 @@ class TensorShardStrategy(BaseShardStrategy):
...
@@ -36,10 +35,8 @@ class TensorShardStrategy(BaseShardStrategy):
if
t
.
payload
.
device
.
type
==
'cuda'
:
if
t
.
payload
.
device
.
type
==
'cuda'
:
assert
t
.
payload
.
device
.
index
==
get_current_device
(),
f
"shard tensor on cuda device index
{
t
.
payload
.
device
.
index
}
,"
\
assert
t
.
payload
.
device
.
index
==
get_current_device
(),
f
"shard tensor on cuda device index
{
t
.
payload
.
device
.
index
}
,"
\
f
" but current cuda device is
{
get_current_device
()
}
"
f
" but current cuda device is
{
get_current_device
()
}
"
GLOBAL_MODEL_DATA_TRACER
.
delete_tensor
(
t
.
payload
)
sharded_payload
,
_
=
get_shard
(
t
.
payload
,
dist
.
get_rank
(
process_group
),
dist
.
get_world_size
(
process_group
))
sharded_payload
,
_
=
get_shard
(
t
.
payload
,
dist
.
get_rank
(
process_group
),
dist
.
get_world_size
(
process_group
))
t
.
reset_payload
(
sharded_payload
)
t
.
reset_payload
(
sharded_payload
)
GLOBAL_MODEL_DATA_TRACER
.
add_tensor
(
t
.
payload
)
t
.
is_sharded
=
True
t
.
is_sharded
=
True
def
_gather_tensor
(
self
,
t
:
ShardedTensor
,
process_group
:
Optional
[
dist
.
ProcessGroup
]
=
None
):
def
_gather_tensor
(
self
,
t
:
ShardedTensor
,
process_group
:
Optional
[
dist
.
ProcessGroup
]
=
None
):
...
@@ -56,10 +53,8 @@ class TensorShardStrategy(BaseShardStrategy):
...
@@ -56,10 +53,8 @@ class TensorShardStrategy(BaseShardStrategy):
else
:
else
:
buffer_list
.
append
(
torch
.
zeros
(
payload_numel
,
dtype
=
t
.
dtype
,
device
=
get_current_device
()))
buffer_list
.
append
(
torch
.
zeros
(
payload_numel
,
dtype
=
t
.
dtype
,
device
=
get_current_device
()))
GLOBAL_MODEL_DATA_TRACER
.
delete_tensor
(
t
.
payload
)
dist
.
all_gather
(
buffer_list
,
buffer_list
[
rank
],
group
=
process_group
,
async_op
=
False
)
dist
.
all_gather
(
buffer_list
,
buffer_list
[
rank
],
group
=
process_group
,
async_op
=
False
)
gathered_payload
=
torch
.
narrow
(
torch
.
cat
(
buffer_list
),
0
,
0
,
t
.
origin_numel
).
reshape
(
t
.
origin_shape
)
gathered_payload
=
torch
.
narrow
(
torch
.
cat
(
buffer_list
),
0
,
0
,
t
.
origin_numel
).
reshape
(
t
.
origin_shape
)
t
.
reset_payload
(
gathered_payload
)
t
.
reset_payload
(
gathered_payload
)
colo_model_data_tensor_move_inline
(
t
,
target_device
,
use_tracer
=
False
)
colo_model_data_tensor_move_inline
(
t
,
target_device
,
use_tracer
=
False
)
GLOBAL_MODEL_DATA_TRACER
.
delete_tensor
(
t
.
payload
)
t
.
is_sharded
=
False
t
.
is_sharded
=
False
colossalai/zero/sharded_model/sharded_model_v2.py
View file @
705f5610
...
@@ -11,6 +11,7 @@ from colossalai.engine.ophooks import register_ophooks_recursively
...
@@ -11,6 +11,7 @@ from colossalai.engine.ophooks import register_ophooks_recursively
from
colossalai.engine.ophooks.zero_hook
import
ZeroHook
from
colossalai.engine.ophooks.zero_hook
import
ZeroHook
from
colossalai.engine.paramhooks
import
BaseParamHookMgr
from
colossalai.engine.paramhooks
import
BaseParamHookMgr
from
colossalai.logging
import
get_dist_logger
from
colossalai.logging
import
get_dist_logger
from
colossalai.utils.memory_tracer.model_data_memtracer
import
GLOBAL_MODEL_DATA_TRACER
from
colossalai.utils.memory_utils.utils
import
colo_model_data_move_to_cpu
,
colo_cuda_memory_capacity
,
colo_model_tensor_clone
from
colossalai.utils.memory_utils.utils
import
colo_model_data_move_to_cpu
,
colo_cuda_memory_capacity
,
colo_model_tensor_clone
from
colossalai.utils.memory_tracer.memstats_collector
import
MemStatsCollector
from
colossalai.utils.memory_tracer.memstats_collector
import
MemStatsCollector
from
colossalai.zero.shard_utils
import
BaseShardStrategy
from
colossalai.zero.shard_utils
import
BaseShardStrategy
...
@@ -83,6 +84,7 @@ class ShardedModelV2(nn.Module):
...
@@ -83,6 +84,7 @@ class ShardedModelV2(nn.Module):
# Init Memory Statistics Collector
# Init Memory Statistics Collector
self
.
_use_memory_tracer
=
use_memory_tracer
self
.
_use_memory_tracer
=
use_memory_tracer
if
self
.
_use_memory_tracer
:
if
self
.
_use_memory_tracer
:
GLOBAL_MODEL_DATA_TRACER
.
register_model
(
self
)
self
.
_memstats_collector
=
MemStatsCollector
()
self
.
_memstats_collector
=
MemStatsCollector
()
else
:
else
:
self
.
_memstats_collector
=
None
self
.
_memstats_collector
=
None
...
@@ -147,14 +149,16 @@ class ShardedModelV2(nn.Module):
...
@@ -147,14 +149,16 @@ class ShardedModelV2(nn.Module):
def
_update_memstats
(
self
):
def
_update_memstats
(
self
):
if
self
.
_iter_cnter
==
0
and
self
.
_memstats_collector
:
if
self
.
_iter_cnter
==
0
and
self
.
_memstats_collector
:
self
.
_memstats_collector
.
finish_collection
()
self
.
_memstats_collector
.
finish_collection
()
self
.
logger
.
info
(
f
'model data cuda,
{
self
.
_memstats_collector
.
model_data_cuda
}
'
)
self
.
logger
.
info
(
f
'non-model data cuda,
{
self
.
_memstats_collector
.
non_model_data_cuda
}
'
)
if
self
.
_memstats_collector
:
if
self
.
_memstats_collector
:
self
.
_memstats_collector
.
reset_sampling_cnter
()
self
.
_memstats_collector
.
reset_sampling_cnter
()
# cuda margin space = cuda mem capacity - max fwd/bwd cuda mem used.
# cuda margin space = cuda mem capacity - max fwd/bwd cuda mem used.
# the way to calculate margin space is based on the assumption that
# the way to calculate margin space is based on the assumption that
# model data is fixed in cuda during training.
# model data is fixed in cuda during training.
# cuda margin space can be used to store OS.
# cuda margin space can be used to store OS.
self
.
_cuda_margin_space
=
colo_cuda_memory_capacity
()
-
max
(
self
.
_memstats_collector
.
_overall_cuda
)
self
.
_cuda_margin_space
=
colo_cuda_memory_capacity
()
-
max
(
self
.
_memstats_collector
.
overall_cuda
)
self
.
_iter_cnter
+=
1
self
.
_iter_cnter
+=
1
@
torch
.
no_grad
()
@
torch
.
no_grad
()
...
...
colossalai/zero/sharded_optim/sharded_optim_v2.py
View file @
705f5610
...
@@ -9,7 +9,6 @@ from colossalai.context.parallel_mode import ParallelMode
...
@@ -9,7 +9,6 @@ from colossalai.context.parallel_mode import ParallelMode
from
colossalai.core
import
global_context
as
gpc
from
colossalai.core
import
global_context
as
gpc
from
colossalai.logging
import
get_dist_logger
from
colossalai.logging
import
get_dist_logger
from
colossalai.nn.optimizer
import
ColossalaiOptimizer
from
colossalai.nn.optimizer
import
ColossalaiOptimizer
from
colossalai.utils.memory_tracer.model_data_memtracer
import
GLOBAL_MODEL_DATA_TRACER
from
colossalai.zero.sharded_model
import
ShardedModelV2
from
colossalai.zero.sharded_model
import
ShardedModelV2
from
colossalai.zero.sharded_model._utils
import
cast_tensor_to_fp32
from
colossalai.zero.sharded_model._utils
import
cast_tensor_to_fp32
from
torch
import
Tensor
from
torch
import
Tensor
...
@@ -218,9 +217,6 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
...
@@ -218,9 +217,6 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
# We must set grad to None
# We must set grad to None
# Because we will judge whether local grad accumulation
# Because we will judge whether local grad accumulation
# is enabled by wheter grad is None
# is enabled by wheter grad is None
for
group
in
self
.
param_groups
:
for
p
in
group
[
'params'
]:
GLOBAL_MODEL_DATA_TRACER
.
delete_tensor
(
p
.
grad
)
self
.
optim
.
zero_grad
(
set_to_none
=
True
)
self
.
optim
.
zero_grad
(
set_to_none
=
True
)
def
sync_grad
(
self
):
def
sync_grad
(
self
):
...
...
tests/test_utils/test_commons.py
View file @
705f5610
from
colossalai.utils.memory_tracer.model_data_memtracer
import
GLOBAL_MODEL_DATA_TRACER
from
colossalai.utils.memory_utils.utils
import
colo_model_data_tensor_move
,
colo_model_data_tensor_move_inline
from
colossalai.utils.memory_utils.utils
import
colo_model_data_tensor_move
,
colo_model_data_tensor_move_inline
from
colossalai.utils
import
free_port
from
colossalai.utils
import
free_port
from
colossalai.testing
import
rerun_on_exception
from
colossalai.testing
import
rerun_on_exception
...
@@ -13,22 +12,15 @@ import torch.multiprocessing as mp
...
@@ -13,22 +12,15 @@ import torch.multiprocessing as mp
def
run_tensor_move
(
rank
):
def
run_tensor_move
(
rank
):
colossalai
.
launch
(
config
=
{},
rank
=
0
,
world_size
=
1
,
host
=
'localhost'
,
port
=
free_port
(),
backend
=
'nccl'
)
colossalai
.
launch
(
config
=
{},
rank
=
0
,
world_size
=
1
,
host
=
'localhost'
,
port
=
free_port
(),
backend
=
'nccl'
)
assert
(
GLOBAL_MODEL_DATA_TRACER
.
cuda_usage
==
0
)
GLOBAL_MODEL_DATA_TRACER
.
start
()
src_t
=
torch
.
ones
(
2
,
3
).
cuda
()
src_t
=
torch
.
ones
(
2
,
3
).
cuda
()
GLOBAL_MODEL_DATA_TRACER
.
add_tensor
(
src_t
)
assert
(
GLOBAL_MODEL_DATA_TRACER
.
cuda_usage
==
24
)
tgt_t
=
torch
.
zeros
(
2
,
3
)
tgt_t
=
torch
.
zeros
(
2
,
3
)
colo_model_data_tensor_move
(
src_t
,
tgt_t
)
colo_model_data_tensor_move
(
src_t
,
tgt_t
)
assert
(
GLOBAL_MODEL_DATA_TRACER
.
cuda_usage
==
0
)
assert
(
torch
.
sum
(
tgt_t
)
==
6.0
),
f
"
{
torch
.
sum
(
tgt_t
.
payload
)
}
vs. 6.0"
assert
(
torch
.
sum
(
tgt_t
)
==
6.0
),
f
"
{
torch
.
sum
(
tgt_t
.
payload
)
}
vs. 6.0"
src_t
=
torch
.
ones
(
2
,
3
)
src_t
=
torch
.
ones
(
2
,
3
)
tgt_t
=
torch
.
zeros
(
2
,
3
).
cuda
().
half
()
tgt_t
=
torch
.
zeros
(
2
,
3
).
cuda
().
half
()
colo_model_data_tensor_move
(
src_t
,
tgt_t
)
colo_model_data_tensor_move
(
src_t
,
tgt_t
)
assert
(
GLOBAL_MODEL_DATA_TRACER
.
cuda_usage
==
12
),
f
"cuda usage
{
GLOBAL_MODEL_DATA_TRACER
.
cuda_usage
}
"
# the src_t has been removed
# the src_t has been removed
assert
(
src_t
.
numel
()
==
0
)
assert
(
src_t
.
numel
()
==
0
)
assert
(
torch
.
sum
(
tgt_t
)
==
6.0
),
f
"
{
torch
.
sum
(
tgt_t
.
payload
)
}
vs. 6.0"
assert
(
torch
.
sum
(
tgt_t
)
==
6.0
),
f
"
{
torch
.
sum
(
tgt_t
.
payload
)
}
vs. 6.0"
...
@@ -36,15 +28,11 @@ def run_tensor_move(rank):
...
@@ -36,15 +28,11 @@ def run_tensor_move(rank):
src_t
=
ShardedTensor
(
torch
.
ones
(
2
,
3
))
src_t
=
ShardedTensor
(
torch
.
ones
(
2
,
3
))
tgt_t
=
ShardedTensor
(
torch
.
zeros
(
2
,
3
).
cuda
().
half
())
tgt_t
=
ShardedTensor
(
torch
.
zeros
(
2
,
3
).
cuda
().
half
())
colo_model_data_tensor_move
(
src_t
,
tgt_t
)
colo_model_data_tensor_move
(
src_t
,
tgt_t
)
assert
(
GLOBAL_MODEL_DATA_TRACER
.
cuda_usage
==
24
),
f
"cuda usage
{
GLOBAL_MODEL_DATA_TRACER
.
cuda_usage
}
"
assert
(
torch
.
sum
(
tgt_t
.
payload
)
==
6.0
),
f
"
{
torch
.
sum
(
tgt_t
.
payload
)
}
vs. 6.0"
assert
(
torch
.
sum
(
tgt_t
.
payload
)
==
6.0
),
f
"
{
torch
.
sum
(
tgt_t
.
payload
)
}
vs. 6.0"
assert
(
tgt_t
.
device
.
type
==
'cuda'
)
assert
(
tgt_t
.
device
.
type
==
'cuda'
)
colo_model_data_tensor_move_inline
(
tgt_t
,
torch
.
device
(
'cpu'
))
colo_model_data_tensor_move_inline
(
tgt_t
,
torch
.
device
(
'cpu'
))
assert
(
tgt_t
.
device
.
type
==
'cpu'
)
assert
(
tgt_t
.
device
.
type
==
'cpu'
)
assert
(
GLOBAL_MODEL_DATA_TRACER
.
cuda_usage
==
12
),
f
"cuda usage
{
GLOBAL_MODEL_DATA_TRACER
.
cuda_usage
}
"
GLOBAL_MODEL_DATA_TRACER
.
close
()
@
rerun_on_exception
(
exception_type
=
mp
.
ProcessRaisedException
,
pattern
=
".*Address already in use.*"
)
@
rerun_on_exception
(
exception_type
=
mp
.
ProcessRaisedException
,
pattern
=
".*Address already in use.*"
)
...
...
tests/test_utils/test_tensor_move.py
View file @
705f5610
import
pytest
import
pytest
from
colossalai.utils.cuda
import
get_current_device
from
colossalai.utils.cuda
import
get_current_device
from
colossalai.utils.memory_tracer.model_data_memtracer
import
GLOBAL_MODEL_DATA_TRACER
from
colossalai.utils.memory_utils.utils
import
colo_model_data_tensor_move
,
colo_model_data_tensor_move_inline
from
colossalai.utils.memory_utils.utils
import
colo_model_data_tensor_move
,
colo_model_data_tensor_move_inline
from
colossalai.utils
import
free_port
from
colossalai.zero.sharded_param
import
ShardedTensor
from
colossalai.zero.sharded_param
import
ShardedTensor
import
colossalai
import
colossalai
import
torch
import
torch
from
functools
import
partial
from
functools
import
partial
import
torch.multiprocessing
as
mp
import
torch.multiprocessing
as
mp
from
colossalai.utils
import
free_port
def _run_colo_model_data_tensor_move_inline():
    """Exercise in-place device moves under the global model-data tracer.

    Moves a plain ``torch.Tensor`` and a ``ShardedTensor`` (each 2x3 float32)
    from CPU to the current CUDA device with
    ``colo_model_data_tensor_move_inline`` and checks that the tracer's
    ``cpu_usage`` / ``cuda_usage`` counters follow the move.
    NOTE(review): the expected value 2 * 3 * 4 suggests the tracer counts
    bytes (6 float32 elements * 4 bytes) — confirm against the tracer impl.
    """
    # Tracer must be idle before the run.
    assert GLOBAL_MODEL_DATA_TRACER.cuda_usage == 0
    GLOBAL_MODEL_DATA_TRACER.start()

    for tensor in (torch.randn(2, 3), ShardedTensor(torch.randn(2, 3))):
        GLOBAL_MODEL_DATA_TRACER.add_tensor(tensor)
        # A freshly created tensor lives on the CPU, so all usage is CPU-side.
        assert GLOBAL_MODEL_DATA_TRACER.cpu_usage == 2 * 3 * 4
        assert GLOBAL_MODEL_DATA_TRACER.cuda_usage == 0

        target_device = torch.device(f"cuda:{get_current_device()}")
        colo_model_data_tensor_move_inline(tensor, target_device)
        assert tensor.device == target_device

        # After the inline move every traced byte is accounted on CUDA.
        assert GLOBAL_MODEL_DATA_TRACER.cpu_usage == 0
        assert GLOBAL_MODEL_DATA_TRACER.cuda_usage == 2 * 3 * 4
        GLOBAL_MODEL_DATA_TRACER.clear()
    GLOBAL_MODEL_DATA_TRACER.close()
def _run_colo_model_data_tensor_move():
    """Exercise cross-tensor payload moves under the global model-data tracer.

    For both raw ``torch.Tensor`` pairs and ``ShardedTensor`` pairs, a CPU
    source is moved into a CUDA destination via ``colo_model_data_tensor_move``
    and the tracer's CPU/CUDA counters are checked before and after the move.
    NOTE(review): 2 * 3 * 4 reads as bytes of a 2x3 float32 tensor — confirm
    the tracer's unit against its implementation.
    """
    # Tracer must be idle before the run.
    assert GLOBAL_MODEL_DATA_TRACER.cuda_usage == 0
    GLOBAL_MODEL_DATA_TRACER.start()

    device = get_current_device()
    pairs = [
        (torch.ones(2, 3), torch.zeros(2, 3).cuda(device)),
        (ShardedTensor(torch.ones(2, 3)),
         ShardedTensor(torch.zeros(2, 3).cuda(device))),
    ]
    for cpu_side, cuda_side in pairs:
        GLOBAL_MODEL_DATA_TRACER.add_tensor(cpu_side)
        # Only the CPU-resident source was registered with the tracer.
        assert GLOBAL_MODEL_DATA_TRACER.cpu_usage == 2 * 3 * 4
        assert GLOBAL_MODEL_DATA_TRACER.cuda_usage == 0

        colo_model_data_tensor_move(cpu_side, cuda_side)

        # The move shifts the traced bytes from the CPU to the CUDA counter.
        assert GLOBAL_MODEL_DATA_TRACER.cpu_usage == 0
        assert GLOBAL_MODEL_DATA_TRACER.cuda_usage == 2 * 3 * 4
        GLOBAL_MODEL_DATA_TRACER.clear()
    GLOBAL_MODEL_DATA_TRACER.close()
def
run_dist
(
rank
,
world_size
,
port
):
def
run_dist
(
rank
,
world_size
,
port
):
...
...
tests/test_zero_data_parallel/test_init_context.py
View file @
705f5610
...
@@ -10,19 +10,21 @@ import torch.multiprocessing as mp
...
@@ -10,19 +10,21 @@ import torch.multiprocessing as mp
from
colossalai.testing
import
parameterize
from
colossalai.testing
import
parameterize
from
colossalai.utils
import
free_port
from
colossalai.utils
import
free_port
from
colossalai.utils.cuda
import
get_current_device
from
colossalai.utils.cuda
import
get_current_device
from
colossalai.utils.memory_tracer.model_data_memtracer
import
\
from
colossalai.utils.memory_tracer.model_data_memtracer
import
col_model_data_mem_usage
GLOBAL_MODEL_DATA_TRACER
from
colossalai.zero.init_ctx
import
ZeroInitContext
from
colossalai.zero.init_ctx
import
ZeroInitContext
from
colossalai.utils.memory_utils.memory_monitor
import
colo_cuda_memory_used
from
colossalai.zero.shard_utils
import
(
BucketTensorShardStrategy
,
TensorShardStrategy
)
from
colossalai.zero.shard_utils
import
(
BucketTensorShardStrategy
,
TensorShardStrategy
)
from
colossalai.testing
import
rerun_on_exception
from
colossalai.testing
import
rerun_on_exception
from
tests.components_to_test.registry
import
non_distributed_component_funcs
from
tests.components_to_test.registry
import
non_distributed_component_funcs
from
colossalai.logging
import
get_dist_logger
from
common
import
CONFIG
from
common
import
CONFIG
@
parameterize
(
"init_device_type"
,
[
'cpu'
,
'cuda'
])
@
parameterize
(
"init_device_type"
,
[
'cpu'
,
'cuda'
])
@
parameterize
(
"shard_strategy_class"
,
[
TensorShardStrategy
,
BucketTensorShardStrategy
])
@
parameterize
(
"shard_strategy_class"
,
[
TensorShardStrategy
,
BucketTensorShardStrategy
])
def
run_model_test
(
init_device_type
,
shard_strategy_class
):
def
run_model_test
(
init_device_type
,
shard_strategy_class
):
logger
=
get_dist_logger
(
"test_zero_init"
)
for
get_components_func
in
non_distributed_component_funcs
:
for
get_components_func
in
non_distributed_component_funcs
:
model_builder
,
_
,
_
,
_
,
_
=
get_components_func
()
model_builder
,
_
,
_
,
_
,
_
=
get_components_func
()
model_numel_tensor
=
torch
.
zeros
(
1
,
dtype
=
torch
.
int
)
model_numel_tensor
=
torch
.
zeros
(
1
,
dtype
=
torch
.
int
)
...
@@ -32,6 +34,8 @@ def run_model_test(init_device_type, shard_strategy_class):
...
@@ -32,6 +34,8 @@ def run_model_test(init_device_type, shard_strategy_class):
init_device
=
torch
.
device
(
"cpu"
)
init_device
=
torch
.
device
(
"cpu"
)
else
:
else
:
continue
continue
model_numel_tensor
=
torch
.
zeros
(
1
,
dtype
=
torch
.
int
)
with
ZeroInitContext
(
convert_fp16
=
True
,
with
ZeroInitContext
(
convert_fp16
=
True
,
target_device
=
init_device
,
target_device
=
init_device
,
shard_strategy
=
shard_strategy_class
(),
shard_strategy
=
shard_strategy_class
(),
...
@@ -46,11 +50,13 @@ def run_model_test(init_device_type, shard_strategy_class):
...
@@ -46,11 +50,13 @@ def run_model_test(init_device_type, shard_strategy_class):
assert
param
.
col_attr
.
sharded_data_tensor
.
is_sharded
assert
param
.
col_attr
.
sharded_data_tensor
.
is_sharded
assert
param
.
col_attr
.
sharded_data_tensor
.
payload
.
device
.
type
==
init_device
.
type
,
\
assert
param
.
col_attr
.
sharded_data_tensor
.
payload
.
device
.
type
==
init_device
.
type
,
\
f
'
{
param
.
col_attr
.
sharded_data_tensor
.
payload
.
device
.
type
}
vs.
{
init_device
.
type
}
'
f
'
{
param
.
col_attr
.
sharded_data_tensor
.
payload
.
device
.
type
}
vs.
{
init_device
.
type
}
'
if
init_device
.
type
==
'cuda'
:
assert
(
GLOBAL_MODEL_DATA_TRACER
.
cuda_usage
>
0
)
cuda_mem_use
,
cpu_mem_use
=
col_model_data_mem_usage
(
model
)
else
:
model_data_cuda_mem_MB
=
cuda_mem_use
/
1e6
assert
(
GLOBAL_MODEL_DATA_TRACER
.
cpu_usage
>
0
)
logger
.
info
(
f
"Existing ZeRO Context.
\n
Model Data CUDA Memory
{
model_data_cuda_mem_MB
}
MB"
,
ranks
=
[
0
])
GLOBAL_MODEL_DATA_TRACER
.
clear
()
sys_cuda_mem_MB
=
colo_cuda_memory_used
()
/
1e6
logger
.
info
(
f
"System CUDA Memory Usage
{
sys_cuda_mem_MB
}
MB"
,
ranks
=
[
0
])
logger
.
info
(
f
"Model Number Parameter
{
model_numel_tensor
.
numpy
()[
0
]
/
1e6
}
M"
,
ranks
=
[
0
])
def
run_dist
(
rank
,
world_size
,
port
):
def
run_dist
(
rank
,
world_size
,
port
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment