Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
705f5610
Unverified
Commit
705f5610
authored
Mar 28, 2022
by
Jiarui Fang
Committed by
GitHub
Mar 28, 2022
Browse files
[zero] refactor model data tracing (#537)
parent
a590ed0b
Changes
13
Hide whitespace changes
Inline
Side-by-side
Showing
13 changed files
with
98 additions
and
132 deletions
+98
-132
colossalai/engine/ophooks/zero_hook.py
colossalai/engine/ophooks/zero_hook.py
+0
-2
colossalai/utils/memory_tracer/memstats_collector.py
colossalai/utils/memory_tracer/memstats_collector.py
+16
-1
colossalai/utils/memory_tracer/model_data_memtracer.py
colossalai/utils/memory_tracer/model_data_memtracer.py
+60
-41
colossalai/utils/memory_tracer/test_memstats_collector.py
colossalai/utils/memory_tracer/test_memstats_collector.py
+1
-4
colossalai/utils/memory_utils/utils.py
colossalai/utils/memory_utils/utils.py
+0
-10
colossalai/zero/init_ctx/init_context.py
colossalai/zero/init_ctx/init_context.py
+0
-12
colossalai/zero/shard_utils/bucket_tensor_shard_strategy.py
colossalai/zero/shard_utils/bucket_tensor_shard_strategy.py
+0
-6
colossalai/zero/shard_utils/tensor_shard_strategy.py
colossalai/zero/shard_utils/tensor_shard_strategy.py
+0
-5
colossalai/zero/sharded_model/sharded_model_v2.py
colossalai/zero/sharded_model/sharded_model_v2.py
+6
-2
colossalai/zero/sharded_optim/sharded_optim_v2.py
colossalai/zero/sharded_optim/sharded_optim_v2.py
+0
-4
tests/test_utils/test_commons.py
tests/test_utils/test_commons.py
+0
-12
tests/test_utils/test_tensor_move.py
tests/test_utils/test_tensor_move.py
+1
-25
tests/test_zero_data_parallel/test_init_context.py
tests/test_zero_data_parallel/test_init_context.py
+14
-8
No files found.
colossalai/engine/ophooks/zero_hook.py
View file @
705f5610
...
...
@@ -5,8 +5,6 @@ import torch.distributed as dist
from
colossalai.registry
import
OPHOOKS
from
colossalai.utils
import
get_current_device
from
colossalai.utils.memory_tracer.memstats_collector
import
MemStatsCollector
from
colossalai.utils.memory_tracer.model_data_memtracer
import
\
GLOBAL_MODEL_DATA_TRACER
from
colossalai.zero.shard_utils
import
BaseShardStrategy
from
._base_ophook
import
BaseOpHook
...
...
colossalai/utils/memory_tracer/memstats_collector.py
View file @
705f5610
...
...
@@ -3,6 +3,7 @@ from colossalai.utils.memory_utils.memory_monitor import colo_cuda_memory_used
from
colossalai.utils
import
get_current_device
import
torch
from
typing
import
Tuple
class
SamplingCounter
:
...
...
@@ -40,6 +41,20 @@ class MemStatsCollector:
self
.
_start_flag
=
False
@
property
def
overall_cuda
(
self
):
return
self
.
_overall_cuda
@
property
def
model_data_cuda
(
self
):
return
self
.
_model_data_cuda
@
property
def
non_model_data_cuda
(
self
):
"""Non model data stats
"""
return
[(
v1
-
v2
)
for
v1
,
v2
in
zip
(
self
.
_overall_cuda
,
self
.
_model_data_cuda
)]
def
start_collection
(
self
):
self
.
_start_flag
=
True
...
...
@@ -58,7 +73,7 @@ class MemStatsCollector:
self
.
_overall_cuda
.
append
(
colo_cuda_memory_used
(
torch
.
device
(
f
'cuda:
{
get_current_device
()
}
'
)))
self
.
_sampling_cnter
.
advance
()
def
fetch_memstats
(
self
)
->
(
int
,
int
)
:
def
fetch_memstats
(
self
)
->
Tuple
[
int
,
int
]
:
"""
returns cuda usage of model data and overall cuda usage.
"""
...
...
colossalai/utils/memory_tracer/model_data_memtracer.py
View file @
705f5610
from
colossalai.context.singleton_meta
import
SingletonMeta
from
colossalai.zero.sharded_param.sharded_tensor
import
ShardedTensor
import
torch
from
typing
import
Union
from
typing
import
Union
,
Tuple
,
Optional
from
colossalai.logging
import
DistributedLogger
def
_col_tensor_mem_usage
(
t
:
Union
[
torch
.
Tensor
,
ShardedTensor
])
->
int
:
...
...
@@ -12,60 +13,78 @@ def _col_tensor_mem_usage(t: Union[torch.Tensor, ShardedTensor]) -> int:
return
target
.
numel
()
*
target
.
element_size
()
def
col_model_data_mem_usage
(
model
:
torch
.
nn
.
Module
)
->
Tuple
[
int
,
int
]:
"""
Trace the model memory usage.
Args:
model (torch.nn.Module): a torch model
Returns:
Tuple[int, int]: cuda memory usage in Byte, cpu memory usage in Byte
"""
def
_get_tensor_mem_use
(
t
:
Optional
[
torch
.
Tensor
]):
if
t
is
None
:
return
assert
isinstance
(
t
,
torch
.
Tensor
)
_cpu_mem_usage
,
_cuda_mem_usage
=
0
,
0
if
t
.
device
.
type
==
'cpu'
:
_cpu_mem_usage
+=
t
.
numel
()
*
t
.
element_size
()
elif
t
.
device
.
type
==
'cuda'
:
_cuda_mem_usages
+=
t
.
numel
()
*
t
.
element_size
()
return
_cuda_mem_usage
,
_cpu_mem_usage
cuda_mem_usage
=
0
cpu_mem_usage
=
0
for
param
in
model
.
parameters
():
if
hasattr
(
param
,
'col_attr'
):
para_cuda
,
param_cpu
=
param
.
col_attr
.
get_memory_usage
()
cuda_mem_usage
+=
para_cuda
cpu_mem_usage
+=
param_cpu
else
:
t_cuda
,
t_cpu
=
_get_tensor_mem_use
(
param
.
data
)
cuda_mem_usage
+=
t_cuda
cpu_mem_usage
+=
t_cpu
t_cuda
,
t_cpu
=
_get_tensor_mem_use
(
param
.
grad
)
cuda_mem_usage
+=
t_cuda
cpu_mem_usage
+=
t_cpu
return
cuda_mem_usage
,
cpu_mem_usage
class
ModelDataTracer
(
metaclass
=
SingletonMeta
):
"""
A tracer singleton to trace model data usage during runtime.
The tracer is designed to trace the memory layout change during model-data tensors allocation, releasing, and moving.
To achieve this goal, the developers have to call `ModelDataTracer` in the corresponding code explicitly.
NOTE() now the class only trace cuda memory usage
You have to register a model on the singleton first.
"""
def
__init__
(
self
)
->
None
:
self
.
_cuda_usage
=
0
self
.
_cpu_usage
=
0
self
.
_start_flag
=
False
def
start
(
self
)
->
None
:
self
.
_start_flag
=
True
def
close
(
self
)
->
None
:
self
.
_start_flag
=
False
def
add_tensor
(
self
,
t
:
Union
[
torch
.
Tensor
,
ShardedTensor
])
->
None
:
if
not
self
.
_start_flag
:
return
t_payload
=
t
.
payload
if
isinstance
(
t
,
ShardedTensor
)
else
t
mem_use
=
_col_tensor_mem_usage
(
t_payload
)
if
t_payload
.
device
.
type
==
'cuda'
:
self
.
_cuda_usage
+=
mem_use
elif
t_payload
.
device
.
type
==
'cpu'
:
self
.
_cpu_usage
+=
mem_use
else
:
raise
TypeError
self
.
_logger
=
DistributedLogger
(
"ModelDataTracer"
)
self
.
_model
=
None
def
delete_tensor
(
self
,
t
:
Union
[
torch
.
Tensor
,
ShardedTensor
])
->
None
:
if
not
self
.
_start_flag
:
return
t_payload
=
t
.
payload
if
isinstance
(
t
,
ShardedTensor
)
else
t
mem_use
=
_col_tensor_mem_usage
(
t_payload
)
if
t_payload
.
device
.
type
==
'cuda'
:
self
.
_cuda_usage
-=
mem_use
elif
t_payload
.
device
.
type
==
'cpu'
:
self
.
_cpu_usage
-=
mem_use
else
:
raise
TypeError
def
_get_mem_usage
(
self
)
->
Tuple
[
int
,
int
]:
"""
get the memory usage of the model registered.
Returns:
Tuple[int, int]: cuda, cpu mem usage
"""
if
self
.
_model
is
None
:
self
.
_logger
.
warning
(
"The Global ModelDataTracer is using, but no model is registered on it."
)
return
0
,
0
return
col_model_data_mem_usage
(
self
.
_model
)
def
clear
(
self
)
->
None
:
self
.
_cuda_usage
=
0
self
.
_cpu_usage
=
0
def
register_model
(
self
,
model
)
->
None
:
self
.
_model
=
model
@
property
def
cpu_usage
(
self
):
return
self
.
_cpu_usage
_
,
cpu_usage
=
self
.
_get_mem_usage
()
return
cpu_usage
@
property
def
cuda_usage
(
self
):
return
self
.
_cuda_usage
cuda_usage
,
_
=
self
.
_get_mem_usage
()
return
cuda_usage
GLOBAL_MODEL_DATA_TRACER
=
ModelDataTracer
()
colossalai/utils/memory_tracer/test_memstats_collector.py
View file @
705f5610
from
colossalai.utils.memory_tracer.memstats_collector
import
MemStatsCollector
from
colossalai.utils.memory_tracer.model_data_memtracer
import
GLOBAL_MODEL_DATA_TRACER
import
torch
...
...
@@ -14,7 +13,6 @@ def test_mem_collector():
collector
.
sample_memstats
()
m_a
=
torch
.
randn
(
10
).
cuda
()
GLOBAL_MODEL_DATA_TRACER
.
add_tensor
(
m_a
)
b
=
torch
.
randn
(
10
).
cuda
()
# sampling at time 1
...
...
@@ -35,8 +33,7 @@ def test_mem_collector():
cuda_use
,
overall_use
=
collector
.
fetch_memstats
()
print
(
cuda_use
,
overall_use
)
print
(
collector
.
_model_data_cuda
)
print
(
collector
.
_overall_cuda
)
print
(
collector
.
overall_cuda
)
if
__name__
==
'__main__'
:
...
...
colossalai/utils/memory_utils/utils.py
View file @
705f5610
import
torch
from
colossalai.utils
import
get_current_device
from
colossalai.zero.sharded_param.sharded_tensor
import
ShardedTensor
from
colossalai.utils.memory_tracer.model_data_memtracer
import
GLOBAL_MODEL_DATA_TRACER
from
typing
import
Union
...
...
@@ -52,9 +51,7 @@ def colo_model_data_tensor_move(src_t: Union[ShardedTensor, torch.Tensor], tgt_t
tgt_t_payload
=
tgt_t
.
data
tgt_dev
=
tgt_t_payload
.
device
GLOBAL_MODEL_DATA_TRACER
.
delete_tensor
(
src_t_payload
)
tgt_t_payload
.
copy_
(
src_t_payload
)
GLOBAL_MODEL_DATA_TRACER
.
add_tensor
(
tgt_t_payload
)
# remove payload of src_t
if
isinstance
(
src_t
,
ShardedTensor
):
...
...
@@ -84,11 +81,7 @@ def colo_model_data_tensor_move_inline(t: Union[ShardedTensor, torch.Tensor],
# deal with torch.device('cpu') and torch.device('cpu:0)
if
t_payload
.
device
.
type
==
target_device
.
type
:
return
if
use_tracer
:
GLOBAL_MODEL_DATA_TRACER
.
delete_tensor
(
t_payload
)
t_payload
.
data
=
t_payload
.
data
.
to
(
target_device
)
if
use_tracer
:
GLOBAL_MODEL_DATA_TRACER
.
add_tensor
(
t_payload
)
def
colo_model_data_move_to_cpu
(
t
:
Union
[
ShardedTensor
,
torch
.
Tensor
])
->
None
:
...
...
@@ -111,9 +104,7 @@ def colo_model_data_move_to_cpu(t: Union[ShardedTensor, torch.Tensor]) -> None:
return
# TODO() optimize the tensor moving with non-blocking
GLOBAL_MODEL_DATA_TRACER
.
delete_tensor
(
t_payload
)
t_payload
.
data
=
t_payload
.
data
.
cpu
()
GLOBAL_MODEL_DATA_TRACER
.
add_tensor
(
t_payload
)
def
colo_model_tensor_clone
(
t
:
Union
[
ShardedTensor
,
torch
.
Tensor
],
target_device
:
torch
.
device
)
->
torch
.
Tensor
:
...
...
@@ -129,5 +120,4 @@ def colo_model_tensor_clone(t: Union[ShardedTensor, torch.Tensor], target_device
t_payload
=
t
.
payload
if
isinstance
(
t
,
ShardedTensor
)
else
t
ret
=
t_payload
.
to
(
target_device
)
GLOBAL_MODEL_DATA_TRACER
.
add_tensor
(
ret
)
return
ret
colossalai/zero/init_ctx/init_context.py
View file @
705f5610
...
...
@@ -4,8 +4,6 @@ from typing import Optional
import
torch
from
colossalai.context.parallel_mode
import
ParallelMode
from
colossalai.core
import
global_context
as
gpc
from
colossalai.utils.memory_tracer.model_data_memtracer
import
\
GLOBAL_MODEL_DATA_TRACER
from
colossalai.utils.memory_utils.memory_monitor
import
colo_cuda_memory_used
from
colossalai.zero.shard_utils
import
BaseShardStrategy
from
colossalai.zero.sharded_model._utils
import
cast_tensor_to_fp16
...
...
@@ -130,7 +128,6 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
The Callback function when entering the context
"""
self
.
logger
=
get_dist_logger
(
"ZeroInitContext"
)
GLOBAL_MODEL_DATA_TRACER
.
start
()
def
_post_context_exec
(
self
):
"""The callback function when exiting context.
...
...
@@ -141,12 +138,6 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
param
.
col_attr
.
remove_torch_payload
()
del
self
.
initialized_param_list
GLOBAL_MODEL_DATA_TRACER
.
close
()
model_data_cuda_mem_MB
=
GLOBAL_MODEL_DATA_TRACER
.
cuda_usage
/
1e6
self
.
logger
.
info
(
f
"Existing ZeRO Context.
\n
Model Data CUDA Memory
{
model_data_cuda_mem_MB
}
MB"
,
ranks
=
[
0
])
sys_cuda_mem_MB
=
colo_cuda_memory_used
()
/
1e6
self
.
logger
.
info
(
f
"System CUDA Memory Usage
{
sys_cuda_mem_MB
}
MB"
,
ranks
=
[
0
])
self
.
logger
.
info
(
f
"Model Number Parameter
{
self
.
model_numel_tensor
.
numpy
()[
0
]
/
1e6
}
M"
,
ranks
=
[
0
])
def
_post_init_method
(
self
,
module
:
torch
.
nn
.
Module
):
"""
...
...
@@ -176,9 +167,6 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
param
.
col_attr
=
ShardedParamV2
(
param
,
rm_torch_payload
=
self
.
rm_torch_payload_on_the_fly
)
self
.
initialized_param_list
.
append
(
param
)
GLOBAL_MODEL_DATA_TRACER
.
add_tensor
(
param
.
col_attr
.
sharded_data_tensor
)
if
self
.
shard_param
:
self
.
shard_strategy
.
shard
([
param
.
col_attr
.
sharded_data_tensor
],
self
.
dp_process_group
)
...
...
colossalai/zero/shard_utils/bucket_tensor_shard_strategy.py
View file @
705f5610
...
...
@@ -7,7 +7,6 @@ from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor
from
torch._utils
import
_flatten_dense_tensors
as
flatten
from
.tensor_shard_strategy
import
TensorShardStrategy
from
colossalai.utils.memory_tracer.model_data_memtracer
import
GLOBAL_MODEL_DATA_TRACER
class
BucketTensorShardStrategy
(
TensorShardStrategy
):
...
...
@@ -18,8 +17,6 @@ class BucketTensorShardStrategy(TensorShardStrategy):
"""
def
gather
(
self
,
tensor_list
:
List
[
ShardedTensor
],
process_group
:
Optional
[
dist
.
ProcessGroup
]
=
None
):
for
t
in
tensor_list
:
GLOBAL_MODEL_DATA_TRACER
.
delete_tensor
(
t
)
tensor_list
:
List
[
ShardedTensor
]
=
[
t
for
t
in
tensor_list
if
t
.
is_sharded
]
if
len
(
tensor_list
)
==
0
:
...
...
@@ -50,6 +47,3 @@ class BucketTensorShardStrategy(TensorShardStrategy):
t
.
reset_payload
(
gathered_payload
)
t
.
is_sharded
=
False
offset
+=
tensor_numels
[
i
]
for
t
in
tensor_list
:
GLOBAL_MODEL_DATA_TRACER
.
add_tensor
(
t
)
colossalai/zero/shard_utils/tensor_shard_strategy.py
View file @
705f5610
...
...
@@ -7,7 +7,6 @@ from colossalai.utils.memory_utils.utils import colo_model_data_tensor_move, col
from
colossalai.zero.shard_utils
import
BaseShardStrategy
from
colossalai.zero.shard_utils.commons
import
get_shard
from
colossalai.zero.sharded_param.sharded_tensor
import
ShardedTensor
from
colossalai.utils.memory_tracer.model_data_memtracer
import
GLOBAL_MODEL_DATA_TRACER
class
TensorShardStrategy
(
BaseShardStrategy
):
...
...
@@ -36,10 +35,8 @@ class TensorShardStrategy(BaseShardStrategy):
if
t
.
payload
.
device
.
type
==
'cuda'
:
assert
t
.
payload
.
device
.
index
==
get_current_device
(),
f
"shard tensor on cuda device index
{
t
.
payload
.
device
.
index
}
,"
\
f
" but current cuda device is
{
get_current_device
()
}
"
GLOBAL_MODEL_DATA_TRACER
.
delete_tensor
(
t
.
payload
)
sharded_payload
,
_
=
get_shard
(
t
.
payload
,
dist
.
get_rank
(
process_group
),
dist
.
get_world_size
(
process_group
))
t
.
reset_payload
(
sharded_payload
)
GLOBAL_MODEL_DATA_TRACER
.
add_tensor
(
t
.
payload
)
t
.
is_sharded
=
True
def
_gather_tensor
(
self
,
t
:
ShardedTensor
,
process_group
:
Optional
[
dist
.
ProcessGroup
]
=
None
):
...
...
@@ -56,10 +53,8 @@ class TensorShardStrategy(BaseShardStrategy):
else
:
buffer_list
.
append
(
torch
.
zeros
(
payload_numel
,
dtype
=
t
.
dtype
,
device
=
get_current_device
()))
GLOBAL_MODEL_DATA_TRACER
.
delete_tensor
(
t
.
payload
)
dist
.
all_gather
(
buffer_list
,
buffer_list
[
rank
],
group
=
process_group
,
async_op
=
False
)
gathered_payload
=
torch
.
narrow
(
torch
.
cat
(
buffer_list
),
0
,
0
,
t
.
origin_numel
).
reshape
(
t
.
origin_shape
)
t
.
reset_payload
(
gathered_payload
)
colo_model_data_tensor_move_inline
(
t
,
target_device
,
use_tracer
=
False
)
GLOBAL_MODEL_DATA_TRACER
.
delete_tensor
(
t
.
payload
)
t
.
is_sharded
=
False
colossalai/zero/sharded_model/sharded_model_v2.py
View file @
705f5610
...
...
@@ -11,6 +11,7 @@ from colossalai.engine.ophooks import register_ophooks_recursively
from
colossalai.engine.ophooks.zero_hook
import
ZeroHook
from
colossalai.engine.paramhooks
import
BaseParamHookMgr
from
colossalai.logging
import
get_dist_logger
from
colossalai.utils.memory_tracer.model_data_memtracer
import
GLOBAL_MODEL_DATA_TRACER
from
colossalai.utils.memory_utils.utils
import
colo_model_data_move_to_cpu
,
colo_cuda_memory_capacity
,
colo_model_tensor_clone
from
colossalai.utils.memory_tracer.memstats_collector
import
MemStatsCollector
from
colossalai.zero.shard_utils
import
BaseShardStrategy
...
...
@@ -83,6 +84,7 @@ class ShardedModelV2(nn.Module):
# Init Memory Statistics Collector
self
.
_use_memory_tracer
=
use_memory_tracer
if
self
.
_use_memory_tracer
:
GLOBAL_MODEL_DATA_TRACER
.
register_model
(
self
)
self
.
_memstats_collector
=
MemStatsCollector
()
else
:
self
.
_memstats_collector
=
None
...
...
@@ -147,14 +149,16 @@ class ShardedModelV2(nn.Module):
def
_update_memstats
(
self
):
if
self
.
_iter_cnter
==
0
and
self
.
_memstats_collector
:
self
.
_memstats_collector
.
finish_collection
()
self
.
logger
.
info
(
f
'model data cuda,
{
self
.
_memstats_collector
.
model_data_cuda
}
'
)
self
.
logger
.
info
(
f
'non-model data cuda,
{
self
.
_memstats_collector
.
non_model_data_cuda
}
'
)
if
self
.
_memstats_collector
:
self
.
_memstats_collector
.
reset_sampling_cnter
()
# cuda margin space = cuda mem capacity - max fwd/bwd cuda mem used.
# the way to calculate margin space is based on the assumption that
# model data is fixed in cuda during training.
# cuda margin space can be used to store OS.
self
.
_cuda_margin_space
=
colo_cuda_memory_capacity
()
-
max
(
self
.
_memstats_collector
.
_overall_cuda
)
self
.
_cuda_margin_space
=
colo_cuda_memory_capacity
()
-
max
(
self
.
_memstats_collector
.
overall_cuda
)
self
.
_iter_cnter
+=
1
@
torch
.
no_grad
()
...
...
colossalai/zero/sharded_optim/sharded_optim_v2.py
View file @
705f5610
...
...
@@ -9,7 +9,6 @@ from colossalai.context.parallel_mode import ParallelMode
from
colossalai.core
import
global_context
as
gpc
from
colossalai.logging
import
get_dist_logger
from
colossalai.nn.optimizer
import
ColossalaiOptimizer
from
colossalai.utils.memory_tracer.model_data_memtracer
import
GLOBAL_MODEL_DATA_TRACER
from
colossalai.zero.sharded_model
import
ShardedModelV2
from
colossalai.zero.sharded_model._utils
import
cast_tensor_to_fp32
from
torch
import
Tensor
...
...
@@ -218,9 +217,6 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
# We must set grad to None
# Because we will judge whether local grad accumulation
# is enabled by wheter grad is None
for
group
in
self
.
param_groups
:
for
p
in
group
[
'params'
]:
GLOBAL_MODEL_DATA_TRACER
.
delete_tensor
(
p
.
grad
)
self
.
optim
.
zero_grad
(
set_to_none
=
True
)
def
sync_grad
(
self
):
...
...
tests/test_utils/test_commons.py
View file @
705f5610
from
colossalai.utils.memory_tracer.model_data_memtracer
import
GLOBAL_MODEL_DATA_TRACER
from
colossalai.utils.memory_utils.utils
import
colo_model_data_tensor_move
,
colo_model_data_tensor_move_inline
from
colossalai.utils
import
free_port
from
colossalai.testing
import
rerun_on_exception
...
...
@@ -13,22 +12,15 @@ import torch.multiprocessing as mp
def
run_tensor_move
(
rank
):
colossalai
.
launch
(
config
=
{},
rank
=
0
,
world_size
=
1
,
host
=
'localhost'
,
port
=
free_port
(),
backend
=
'nccl'
)
assert
(
GLOBAL_MODEL_DATA_TRACER
.
cuda_usage
==
0
)
GLOBAL_MODEL_DATA_TRACER
.
start
()
src_t
=
torch
.
ones
(
2
,
3
).
cuda
()
GLOBAL_MODEL_DATA_TRACER
.
add_tensor
(
src_t
)
assert
(
GLOBAL_MODEL_DATA_TRACER
.
cuda_usage
==
24
)
tgt_t
=
torch
.
zeros
(
2
,
3
)
colo_model_data_tensor_move
(
src_t
,
tgt_t
)
assert
(
GLOBAL_MODEL_DATA_TRACER
.
cuda_usage
==
0
)
assert
(
torch
.
sum
(
tgt_t
)
==
6.0
),
f
"
{
torch
.
sum
(
tgt_t
.
payload
)
}
vs. 6.0"
src_t
=
torch
.
ones
(
2
,
3
)
tgt_t
=
torch
.
zeros
(
2
,
3
).
cuda
().
half
()
colo_model_data_tensor_move
(
src_t
,
tgt_t
)
assert
(
GLOBAL_MODEL_DATA_TRACER
.
cuda_usage
==
12
),
f
"cuda usage
{
GLOBAL_MODEL_DATA_TRACER
.
cuda_usage
}
"
# the src_t has been removed
assert
(
src_t
.
numel
()
==
0
)
assert
(
torch
.
sum
(
tgt_t
)
==
6.0
),
f
"
{
torch
.
sum
(
tgt_t
.
payload
)
}
vs. 6.0"
...
...
@@ -36,15 +28,11 @@ def run_tensor_move(rank):
src_t
=
ShardedTensor
(
torch
.
ones
(
2
,
3
))
tgt_t
=
ShardedTensor
(
torch
.
zeros
(
2
,
3
).
cuda
().
half
())
colo_model_data_tensor_move
(
src_t
,
tgt_t
)
assert
(
GLOBAL_MODEL_DATA_TRACER
.
cuda_usage
==
24
),
f
"cuda usage
{
GLOBAL_MODEL_DATA_TRACER
.
cuda_usage
}
"
assert
(
torch
.
sum
(
tgt_t
.
payload
)
==
6.0
),
f
"
{
torch
.
sum
(
tgt_t
.
payload
)
}
vs. 6.0"
assert
(
tgt_t
.
device
.
type
==
'cuda'
)
colo_model_data_tensor_move_inline
(
tgt_t
,
torch
.
device
(
'cpu'
))
assert
(
tgt_t
.
device
.
type
==
'cpu'
)
assert
(
GLOBAL_MODEL_DATA_TRACER
.
cuda_usage
==
12
),
f
"cuda usage
{
GLOBAL_MODEL_DATA_TRACER
.
cuda_usage
}
"
GLOBAL_MODEL_DATA_TRACER
.
close
()
@
rerun_on_exception
(
exception_type
=
mp
.
ProcessRaisedException
,
pattern
=
".*Address already in use.*"
)
...
...
tests/test_utils/test_tensor_move.py
View file @
705f5610
import
pytest
from
colossalai.utils.cuda
import
get_current_device
from
colossalai.utils.memory_tracer.model_data_memtracer
import
GLOBAL_MODEL_DATA_TRACER
from
colossalai.utils.memory_utils.utils
import
colo_model_data_tensor_move
,
colo_model_data_tensor_move_inline
from
colossalai.utils
import
free_port
from
colossalai.zero.sharded_param
import
ShardedTensor
import
colossalai
import
torch
from
functools
import
partial
import
torch.multiprocessing
as
mp
from
colossalai.utils
import
free_port
def
_run_colo_model_data_tensor_move_inline
():
assert
(
GLOBAL_MODEL_DATA_TRACER
.
cuda_usage
==
0
)
GLOBAL_MODEL_DATA_TRACER
.
start
()
for
t
in
[
torch
.
randn
(
2
,
3
),
ShardedTensor
(
torch
.
randn
(
2
,
3
))]:
GLOBAL_MODEL_DATA_TRACER
.
add_tensor
(
t
)
assert
GLOBAL_MODEL_DATA_TRACER
.
cpu_usage
==
2
*
3
*
4
assert
GLOBAL_MODEL_DATA_TRACER
.
cuda_usage
==
0
colo_model_data_tensor_move_inline
(
t
,
torch
.
device
(
f
"cuda:
{
get_current_device
()
}
"
))
assert
t
.
device
==
torch
.
device
(
f
"cuda:
{
get_current_device
()
}
"
)
assert
GLOBAL_MODEL_DATA_TRACER
.
cpu_usage
==
0
assert
GLOBAL_MODEL_DATA_TRACER
.
cuda_usage
==
2
*
3
*
4
GLOBAL_MODEL_DATA_TRACER
.
clear
()
GLOBAL_MODEL_DATA_TRACER
.
close
()
def
_run_colo_model_data_tensor_move
():
assert
(
GLOBAL_MODEL_DATA_TRACER
.
cuda_usage
==
0
)
GLOBAL_MODEL_DATA_TRACER
.
start
()
for
t
in
[(
torch
.
ones
(
2
,
3
),
torch
.
zeros
(
2
,
3
).
cuda
(
get_current_device
())),
(
ShardedTensor
(
torch
.
ones
(
2
,
3
)),
ShardedTensor
(
torch
.
zeros
(
2
,
3
).
cuda
(
get_current_device
())))]:
cpu_t
,
cuda_t
=
t
GLOBAL_MODEL_DATA_TRACER
.
add_tensor
(
cpu_t
)
assert
GLOBAL_MODEL_DATA_TRACER
.
cpu_usage
==
2
*
3
*
4
assert
GLOBAL_MODEL_DATA_TRACER
.
cuda_usage
==
0
colo_model_data_tensor_move
(
cpu_t
,
cuda_t
)
assert
GLOBAL_MODEL_DATA_TRACER
.
cpu_usage
==
0
assert
GLOBAL_MODEL_DATA_TRACER
.
cuda_usage
==
2
*
3
*
4
GLOBAL_MODEL_DATA_TRACER
.
clear
()
GLOBAL_MODEL_DATA_TRACER
.
close
()
def
run_dist
(
rank
,
world_size
,
port
):
...
...
tests/test_zero_data_parallel/test_init_context.py
View file @
705f5610
...
...
@@ -10,19 +10,21 @@ import torch.multiprocessing as mp
from
colossalai.testing
import
parameterize
from
colossalai.utils
import
free_port
from
colossalai.utils.cuda
import
get_current_device
from
colossalai.utils.memory_tracer.model_data_memtracer
import
\
GLOBAL_MODEL_DATA_TRACER
from
colossalai.utils.memory_tracer.model_data_memtracer
import
col_model_data_mem_usage
from
colossalai.zero.init_ctx
import
ZeroInitContext
from
colossalai.utils.memory_utils.memory_monitor
import
colo_cuda_memory_used
from
colossalai.zero.shard_utils
import
(
BucketTensorShardStrategy
,
TensorShardStrategy
)
from
colossalai.testing
import
rerun_on_exception
from
tests.components_to_test.registry
import
non_distributed_component_funcs
from
colossalai.logging
import
get_dist_logger
from
common
import
CONFIG
@
parameterize
(
"init_device_type"
,
[
'cpu'
,
'cuda'
])
@
parameterize
(
"shard_strategy_class"
,
[
TensorShardStrategy
,
BucketTensorShardStrategy
])
def
run_model_test
(
init_device_type
,
shard_strategy_class
):
logger
=
get_dist_logger
(
"test_zero_init"
)
for
get_components_func
in
non_distributed_component_funcs
:
model_builder
,
_
,
_
,
_
,
_
=
get_components_func
()
model_numel_tensor
=
torch
.
zeros
(
1
,
dtype
=
torch
.
int
)
...
...
@@ -32,6 +34,8 @@ def run_model_test(init_device_type, shard_strategy_class):
init_device
=
torch
.
device
(
"cpu"
)
else
:
continue
model_numel_tensor
=
torch
.
zeros
(
1
,
dtype
=
torch
.
int
)
with
ZeroInitContext
(
convert_fp16
=
True
,
target_device
=
init_device
,
shard_strategy
=
shard_strategy_class
(),
...
...
@@ -46,11 +50,13 @@ def run_model_test(init_device_type, shard_strategy_class):
assert
param
.
col_attr
.
sharded_data_tensor
.
is_sharded
assert
param
.
col_attr
.
sharded_data_tensor
.
payload
.
device
.
type
==
init_device
.
type
,
\
f
'
{
param
.
col_attr
.
sharded_data_tensor
.
payload
.
device
.
type
}
vs.
{
init_device
.
type
}
'
if
init_device
.
type
==
'cuda'
:
assert
(
GLOBAL_MODEL_DATA_TRACER
.
cuda_usage
>
0
)
else
:
assert
(
GLOBAL_MODEL_DATA_TRACER
.
cpu_usage
>
0
)
GLOBAL_MODEL_DATA_TRACER
.
clear
()
cuda_mem_use
,
cpu_mem_use
=
col_model_data_mem_usage
(
model
)
model_data_cuda_mem_MB
=
cuda_mem_use
/
1e6
logger
.
info
(
f
"Existing ZeRO Context.
\n
Model Data CUDA Memory
{
model_data_cuda_mem_MB
}
MB"
,
ranks
=
[
0
])
sys_cuda_mem_MB
=
colo_cuda_memory_used
()
/
1e6
logger
.
info
(
f
"System CUDA Memory Usage
{
sys_cuda_mem_MB
}
MB"
,
ranks
=
[
0
])
logger
.
info
(
f
"Model Number Parameter
{
model_numel_tensor
.
numpy
()[
0
]
/
1e6
}
M"
,
ranks
=
[
0
])
def
run_dist
(
rank
,
world_size
,
port
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment