Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
8d8c5407
Unverified
Commit
8d8c5407
authored
Mar 25, 2022
by
Jiarui Fang
Committed by
GitHub
Mar 25, 2022
Browse files
[zero] refactor model data tracing (#522)
parent
3601b2ba
Changes
8
Show whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
128 additions
and
28 deletions
+128
-28
colossalai/utils/memory_tracer/model_data_memtracer.py
colossalai/utils/memory_tracer/model_data_memtracer.py
+20
-8
colossalai/utils/memory_utils/utils.py
colossalai/utils/memory_utils/utils.py
+10
-11
colossalai/zero/init_ctx/init_context.py
colossalai/zero/init_ctx/init_context.py
+3
-5
colossalai/zero/shard_utils/bucket_tensor_shard_strategy.py
colossalai/zero/shard_utils/bucket_tensor_shard_strategy.py
+7
-0
colossalai/zero/shard_utils/tensor_shard_strategy.py
colossalai/zero/shard_utils/tensor_shard_strategy.py
+16
-2
colossalai/zero/sharded_param/sharded_tensor.py
colossalai/zero/sharded_param/sharded_tensor.py
+4
-1
tests/test_utils/test_tensor_move.py
tests/test_utils/test_tensor_move.py
+66
-0
tests/test_zero_data_parallel/test_init_context.py
tests/test_zero_data_parallel/test_init_context.py
+2
-1
No files found.
colossalai/utils/memory_tracer/model_data_memtracer.py
View file @
8d8c5407
...
...
@@ -22,6 +22,7 @@ class ModelDataTracer(metaclass=SingletonMeta):
def
__init__
(
self
)
->
None
:
self
.
_cuda_usage
=
0
self
.
_cpu_usage
=
0
self
.
_start_flag
=
False
def
start
(
self
)
->
None
:
...
...
@@ -30,22 +31,33 @@ class ModelDataTracer(metaclass=SingletonMeta):
def
close
(
self
)
->
None
:
self
.
_start_flag
=
False
def
add_tensor
(
self
,
t
:
torch
.
Tensor
)
->
None
:
def
add_tensor
(
self
,
t
:
Union
[
torch
.
Tensor
,
ShardedTensor
]
)
->
None
:
if
not
self
.
_start_flag
:
return
assert
isinstance
(
t
,
torch
.
Tensor
),
f
"ModelDataTracer add_tensor() should accept a torch.Tensor"
mem_use
=
_col_tensor_mem_usage
(
t
)
t_payload
=
t
.
payload
if
isinstance
(
t
,
ShardedTensor
)
else
t
mem_use
=
_col_tensor_mem_usage
(
t_payload
)
if
t_payload
.
device
.
type
==
'cuda'
:
self
.
_cuda_usage
+=
mem_use
elif
t_payload
.
device
.
type
==
'cpu'
:
self
.
_cpu_usage
+=
mem_use
else
:
raise
TypeError
def
delete_tensor
(
self
,
t
:
torch
.
Tensor
)
->
None
:
def
delete_tensor
(
self
,
t
:
Union
[
torch
.
Tensor
,
ShardedTensor
]
)
->
None
:
if
not
self
.
_start_flag
:
return
assert
isinstance
(
t
,
torch
.
Tensor
),
f
"ModelDataTracer delete_tensor() should accept a torch.Tensor"
mem_use
=
_col_tensor_mem_usage
(
t
)
t_payload
=
t
.
payload
if
isinstance
(
t
,
ShardedTensor
)
else
t
mem_use
=
_col_tensor_mem_usage
(
t_payload
)
if
t_payload
.
device
.
type
==
'cuda'
:
self
.
_cuda_usage
-=
mem_use
elif
t_payload
.
device
.
type
==
'cpu'
:
self
.
_cpu_usage
-=
mem_use
else
:
raise
TypeError
def
clear
(
self
)
->
None
:
self
.
_cuda_usage
=
0
self
.
_cpu_usage
=
0
@
property
def
cpu_usage
(
self
):
...
...
colossalai/utils/memory_utils/utils.py
View file @
8d8c5407
...
...
@@ -3,7 +3,7 @@ from colossalai.utils import get_current_device
from
colossalai.zero.sharded_param.sharded_tensor
import
ShardedTensor
from
colossalai.utils.memory_tracer.model_data_memtracer
import
GLOBAL_MODEL_DATA_TRACER
from
typing
import
Union
,
Optional
from
typing
import
Union
_GLOBAL_CUDA_MEM_FRACTION
=
1.0
...
...
@@ -52,11 +52,9 @@ def colo_model_data_tensor_move(src_t: Union[ShardedTensor, torch.Tensor], tgt_t
tgt_t_payload
=
tgt_t
.
data
tgt_dev
=
tgt_t_payload
.
device
if
src_dev
.
type
==
'cuda'
and
tgt_dev
.
type
==
'cpu'
:
GLOBAL_MODEL_DATA_TRACER
.
delete_tensor
(
src_t_payload
)
elif
src_dev
.
type
==
'cpu'
and
tgt_dev
.
type
==
'cuda'
:
GLOBAL_MODEL_DATA_TRACER
.
add_tensor
(
tgt_t_payload
)
tgt_t_payload
.
copy_
(
src_t_payload
)
GLOBAL_MODEL_DATA_TRACER
.
add_tensor
(
tgt_t_payload
)
# remove payload of src_t
if
isinstance
(
src_t
,
ShardedTensor
):
...
...
@@ -65,7 +63,9 @@ def colo_model_data_tensor_move(src_t: Union[ShardedTensor, torch.Tensor], tgt_t
src_t
.
data
=
torch
.
tensor
([],
device
=
src_dev
,
dtype
=
src_t_payload
.
dtype
)
def
colo_model_data_tensor_move_inline
(
t
:
Union
[
ShardedTensor
,
torch
.
Tensor
],
target_device
:
torch
.
device
)
->
None
:
def
colo_model_data_tensor_move_inline
(
t
:
Union
[
ShardedTensor
,
torch
.
Tensor
],
target_device
:
torch
.
device
,
use_tracer
:
bool
=
True
)
->
None
:
"""
move a tensor to the target_device
Args:
...
...
@@ -84,13 +84,11 @@ def colo_model_data_tensor_move_inline(t: Union[ShardedTensor, torch.Tensor], ta
# deal with torch.device('cpu') and torch.device('cpu:0)
if
t_payload
.
device
.
type
==
target_device
.
type
:
return
if
target_device
.
type
==
'cuda'
:
GLOBAL_MODEL_DATA_TRACER
.
add_tensor
(
t_payload
)
elif
target_device
.
type
==
'cpu'
:
if
use_tracer
:
GLOBAL_MODEL_DATA_TRACER
.
delete_tensor
(
t_payload
)
t_payload
.
data
=
t_payload
.
data
.
to
(
target_device
)
if
use_tracer
:
GLOBAL_MODEL_DATA_TRACER
.
add_tensor
(
t_payload
)
def
colo_model_data_move_to_cpu
(
t
:
Union
[
ShardedTensor
,
torch
.
Tensor
])
->
None
:
...
...
@@ -115,3 +113,4 @@ def colo_model_data_move_to_cpu(t: Union[ShardedTensor, torch.Tensor]) -> None:
# TODO() optimize the tensor moving with non-blocking
GLOBAL_MODEL_DATA_TRACER
.
delete_tensor
(
t_payload
)
t_payload
.
data
=
t_payload
.
data
.
cpu
()
GLOBAL_MODEL_DATA_TRACER
.
add_tensor
(
t_payload
)
colossalai/zero/init_ctx/init_context.py
View file @
8d8c5407
...
...
@@ -177,13 +177,11 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
self
.
initialized_param_list
.
append
(
param
)
GLOBAL_MODEL_DATA_TRACER
.
add_tensor
(
param
.
col_attr
.
sharded_data_tensor
)
if
self
.
shard_param
:
self
.
shard_strategy
.
shard
([
param
.
col_attr
.
sharded_data_tensor
],
self
.
dp_process_group
)
if
param
.
col_attr
.
sharded_data_tensor
.
device
.
type
==
'cuda'
:
GLOBAL_MODEL_DATA_TRACER
.
add_tensor
(
param
.
col_attr
.
sharded_data_tensor
.
payload
)
# if param.col_attr.grad and self.shard_grad:
# self.shard_strategy.shard([param.col_attr._grad_sharded_tensor], self.dp_process_group)
# GLOBAL_MODEL_DATA_TRACER.add_tensor(param.col_attr._grad_sharded_tensor.payload)
# We must cast buffers
# If we use BN, buffers may be on CPU and Float
# We must cast them
...
...
colossalai/zero/shard_utils/bucket_tensor_shard_strategy.py
View file @
8d8c5407
...
...
@@ -7,6 +7,7 @@ from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor
from
torch._utils
import
_flatten_dense_tensors
as
flatten
from
.tensor_shard_strategy
import
TensorShardStrategy
from
colossalai.utils.memory_tracer.model_data_memtracer
import
GLOBAL_MODEL_DATA_TRACER
class
BucketTensorShardStrategy
(
TensorShardStrategy
):
...
...
@@ -17,6 +18,9 @@ class BucketTensorShardStrategy(TensorShardStrategy):
"""
def
gather
(
self
,
tensor_list
:
List
[
ShardedTensor
],
process_group
:
Optional
[
dist
.
ProcessGroup
]
=
None
):
for
t
in
tensor_list
:
GLOBAL_MODEL_DATA_TRACER
.
delete_tensor
(
t
)
tensor_list
:
List
[
ShardedTensor
]
=
[
t
for
t
in
tensor_list
if
t
.
is_sharded
]
if
len
(
tensor_list
)
==
0
:
return
...
...
@@ -46,3 +50,6 @@ class BucketTensorShardStrategy(TensorShardStrategy):
t
.
reset_payload
(
gathered_payload
)
t
.
is_sharded
=
False
offset
+=
tensor_numels
[
i
]
for
t
in
tensor_list
:
GLOBAL_MODEL_DATA_TRACER
.
add_tensor
(
t
)
colossalai/zero/shard_utils/tensor_shard_strategy.py
View file @
8d8c5407
...
...
@@ -3,13 +3,16 @@ from typing import List, Optional
import
torch
import
torch.distributed
as
dist
from
colossalai.utils
import
get_current_device
from
colossalai.utils.memory_utils.utils
import
colo_model_data_tensor_move
,
colo_model_data_tensor_move_inline
from
colossalai.zero.shard_utils
import
BaseShardStrategy
from
colossalai.zero.shard_utils.commons
import
get_shard
from
colossalai.zero.sharded_param.sharded_tensor
import
ShardedTensor
from
colossalai.utils.memory_tracer.model_data_memtracer
import
GLOBAL_MODEL_DATA_TRACER
class
TensorShardStrategy
(
BaseShardStrategy
):
"""A naive implementation which shard each tensor evenly over all ranks
"""
A naive implementation which shard each tensor evenly over all ranks
"""
def
shard
(
self
,
tensor_list
:
List
[
ShardedTensor
],
process_group
:
Optional
[
dist
.
ProcessGroup
]
=
None
):
...
...
@@ -21,13 +24,22 @@ class TensorShardStrategy(BaseShardStrategy):
self
.
_gather_tensor
(
t
,
process_group
)
def
_shard_tensor
(
self
,
t
:
ShardedTensor
,
process_group
:
Optional
[
dist
.
ProcessGroup
]
=
None
):
""" Shard tensor among processes.
Args:
t (ShardedTensor): a tensor to be sharded.
process_group (Optional[dist.ProcessGroup], optional): the process group among which tensor shards.
Defaults to None.
"""
if
t
.
is_sharded
:
return
if
t
.
payload
.
device
.
type
==
'cuda'
:
assert
t
.
payload
.
device
.
index
==
get_current_device
(),
f
"shard tensor on cuda device index
{
t
.
payload
.
device
.
index
}
,"
\
f
" but current cuda device is
{
get_current_device
()
}
"
GLOBAL_MODEL_DATA_TRACER
.
delete_tensor
(
t
.
payload
)
sharded_payload
,
_
=
get_shard
(
t
.
payload
,
dist
.
get_rank
(
process_group
),
dist
.
get_world_size
(
process_group
))
t
.
reset_payload
(
sharded_payload
)
GLOBAL_MODEL_DATA_TRACER
.
add_tensor
(
t
.
payload
)
t
.
is_sharded
=
True
def
_gather_tensor
(
self
,
t
:
ShardedTensor
,
process_group
:
Optional
[
dist
.
ProcessGroup
]
=
None
):
...
...
@@ -44,8 +56,10 @@ class TensorShardStrategy(BaseShardStrategy):
else
:
buffer_list
.
append
(
torch
.
zeros
(
payload_numel
,
dtype
=
t
.
dtype
,
device
=
get_current_device
()))
GLOBAL_MODEL_DATA_TRACER
.
delete_tensor
(
t
.
payload
)
dist
.
all_gather
(
buffer_list
,
buffer_list
[
rank
],
group
=
process_group
,
async_op
=
False
)
gathered_payload
=
torch
.
narrow
(
torch
.
cat
(
buffer_list
),
0
,
0
,
t
.
origin_numel
).
reshape
(
t
.
origin_shape
)
t
.
reset_payload
(
gathered_payload
)
t
.
to
(
target_device
)
colo_model_data_tensor_move_inline
(
t
,
target_device
,
use_tracer
=
False
)
GLOBAL_MODEL_DATA_TRACER
.
delete_tensor
(
t
.
payload
)
t
.
is_sharded
=
False
colossalai/zero/sharded_param/sharded_tensor.py
View file @
8d8c5407
...
...
@@ -56,7 +56,10 @@ class ShardedTensor(object):
return
self
.
_origin_dtype
def
to
(
self
,
device
:
torch
.
device
):
self
.
_payload
=
self
.
_payload
.
to
(
device
)
raise
RuntimeError
(
"Use colo_model_tensor_move install of call .to() on ShardedTensor"
)
def
to_
(
self
,
device
:
torch
.
device
):
raise
RuntimeError
(
"Use colo_model_tensor_move install of call .to_() on ShardedTensor"
)
@
property
def
shape
(
self
):
...
...
tests/test_utils/test_tensor_move.py
0 → 100644
View file @
8d8c5407
import
pytest
from
colossalai.utils.cuda
import
get_current_device
from
colossalai.utils.memory_tracer.model_data_memtracer
import
GLOBAL_MODEL_DATA_TRACER
from
colossalai.utils.memory_utils.utils
import
colo_model_data_tensor_move
,
colo_model_data_tensor_move_inline
from
colossalai.zero.sharded_param
import
ShardedTensor
import
colossalai
import
torch
from
functools
import
partial
import
torch.multiprocessing
as
mp
from
colossalai.utils
import
free_port
def
_run_colo_model_data_tensor_move_inline
():
assert
(
GLOBAL_MODEL_DATA_TRACER
.
cuda_usage
==
0
)
GLOBAL_MODEL_DATA_TRACER
.
start
()
for
t
in
[
torch
.
randn
(
2
,
3
),
ShardedTensor
(
torch
.
randn
(
2
,
3
))]:
GLOBAL_MODEL_DATA_TRACER
.
add_tensor
(
t
)
assert
GLOBAL_MODEL_DATA_TRACER
.
cpu_usage
==
2
*
3
*
4
assert
GLOBAL_MODEL_DATA_TRACER
.
cuda_usage
==
0
colo_model_data_tensor_move_inline
(
t
,
torch
.
device
(
f
"cuda:
{
get_current_device
()
}
"
))
assert
t
.
device
==
torch
.
device
(
f
"cuda:
{
get_current_device
()
}
"
)
assert
GLOBAL_MODEL_DATA_TRACER
.
cpu_usage
==
0
assert
GLOBAL_MODEL_DATA_TRACER
.
cuda_usage
==
2
*
3
*
4
GLOBAL_MODEL_DATA_TRACER
.
clear
()
GLOBAL_MODEL_DATA_TRACER
.
close
()
def
_run_colo_model_data_tensor_move
():
assert
(
GLOBAL_MODEL_DATA_TRACER
.
cuda_usage
==
0
)
GLOBAL_MODEL_DATA_TRACER
.
start
()
for
t
in
[(
torch
.
ones
(
2
,
3
),
torch
.
zeros
(
2
,
3
).
cuda
(
get_current_device
())),
(
ShardedTensor
(
torch
.
ones
(
2
,
3
)),
ShardedTensor
(
torch
.
zeros
(
2
,
3
).
cuda
(
get_current_device
())))]:
cpu_t
,
cuda_t
=
t
GLOBAL_MODEL_DATA_TRACER
.
add_tensor
(
cpu_t
)
assert
GLOBAL_MODEL_DATA_TRACER
.
cpu_usage
==
2
*
3
*
4
assert
GLOBAL_MODEL_DATA_TRACER
.
cuda_usage
==
0
colo_model_data_tensor_move
(
cpu_t
,
cuda_t
)
assert
GLOBAL_MODEL_DATA_TRACER
.
cpu_usage
==
0
assert
GLOBAL_MODEL_DATA_TRACER
.
cuda_usage
==
2
*
3
*
4
GLOBAL_MODEL_DATA_TRACER
.
clear
()
GLOBAL_MODEL_DATA_TRACER
.
close
()
def
run_dist
(
rank
,
world_size
,
port
):
colossalai
.
launch
(
config
=
{},
rank
=
rank
,
world_size
=
world_size
,
host
=
'localhost'
,
port
=
port
,
backend
=
'nccl'
)
_run_colo_model_data_tensor_move_inline
()
_run_colo_model_data_tensor_move
()
@
pytest
.
mark
.
dist
@
pytest
.
mark
.
parametrize
(
"world_size"
,
[
1
,
4
])
def
test_tensor_move
(
world_size
):
run_func
=
partial
(
run_dist
,
world_size
=
world_size
,
port
=
free_port
())
mp
.
spawn
(
run_func
,
nprocs
=
world_size
)
if
__name__
==
'__main__'
:
test_tensor_move
(
4
)
tests/test_zero_data_parallel/test_init_context.py
View file @
8d8c5407
...
...
@@ -48,6 +48,8 @@ def run_model_test(init_device_type, shard_strategy_class):
f
'
{
param
.
col_attr
.
sharded_data_tensor
.
payload
.
device
.
type
}
vs.
{
init_device
.
type
}
'
if
init_device
.
type
==
'cuda'
:
assert
(
GLOBAL_MODEL_DATA_TRACER
.
cuda_usage
>
0
)
else
:
assert
(
GLOBAL_MODEL_DATA_TRACER
.
cpu_usage
>
0
)
GLOBAL_MODEL_DATA_TRACER
.
clear
()
...
...
@@ -65,5 +67,4 @@ def test_zero_init_context(world_size):
if
__name__
==
'__main__'
:
# test_zero_init_context(2, torch.device('cpu'), TensorShardStrategy)
test_zero_init_context
(
4
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment