OpenDAS / ColossalAI

Commit 0bebda6e (unverified)
[zero] fix init device bug in zero init context unittest (#516)

Authored Mar 25, 2022 by Jiarui Fang; committed via GitHub on Mar 25, 2022
Parent: a5131643

Showing 8 changed files with 55 additions and 37 deletions
colossalai/utils/memory_tracer/async_memtracer.py        +3   -19
colossalai/utils/memory_tracer/memstats_collector.py     +2   -2
colossalai/utils/memory_tracer/model_data_memtracer.py   +3   -0
colossalai/utils/memory_utils/memory_monitor.py          +22  -0
colossalai/utils/memory_utils/utils.py                   +1   -1
colossalai/zero/init_ctx/init_context.py                 +8   -7
colossalai/zero/shard_utils/tensor_shard_strategy.py     +3   -0
tests/test_zero_data_parallel/test_init_context.py       +13  -8
colossalai/utils/memory_tracer/async_memtracer.py

@@ -2,26 +2,10 @@ from concurrent.futures import ThreadPoolExecutor
 from time import sleep, time
 import pickle
 
-from colossalai.utils import get_current_device
-
-import torch
-
-
-def get_cuda_memory_used(device: torch.device) -> int:
-    """
-    Get the free memory info of device.
-    :param device: device id
-    :type device: torch.device
-    :return: current memory usage, sized by MB
-    :rtype: int
-    """
-    assert device.type == 'cuda'
-
-    ret: int = torch.cuda.memory_allocated(device)
-    # get the peak memory to report correct data, so reset the counter for the next call
-    if hasattr(torch.cuda, "reset_peak_memory_stats"):    # pytorch 1.4+
-        torch.cuda.reset_peak_memory_stats(device)
-    return ret
+from colossalai.utils import get_current_device
+from colossalai.utils.memory_utils.memory_monitor import colo_cuda_memory_used
 
 
 class AsyncMemoryMonitor:

@@ -97,7 +81,7 @@ class AsyncMemoryMonitor:
         while self.keep_measuring:
             max_usage = max(
                 max_usage,
-                get_cuda_memory_used(torch.device(f'cuda:{get_current_device()}')),
+                colo_cuda_memory_used(),
             )
             sleep(self.interval)
         return max_usage
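These two hunks replace the module-local get_cuda_memory_used with the shared colo_cuda_memory_used helper (added in colossalai/utils/memory_utils/memory_monitor.py below). For context, here is a minimal sketch of the background-sampling pattern AsyncMemoryMonitor is built on; PeakUsageSampler and query_usage are illustrative stand-ins, not the project's API.

# A minimal sketch of the sampling pattern, assuming `query_usage` stands in
# for colo_cuda_memory_used. Illustration only, not ColossalAI's implementation.
from concurrent.futures import ThreadPoolExecutor
from time import sleep


class PeakUsageSampler:

    def __init__(self, query_usage, interval: float = 0.01):
        self._query_usage = query_usage    # callable returning bytes in use
        self._interval = interval          # seconds between samples
        self._keep_measuring = False
        self._executor = ThreadPoolExecutor(max_workers=1)
        self._future = None

    def start(self):
        self._keep_measuring = True
        self._future = self._executor.submit(self._measure)

    def finish(self) -> int:
        self._keep_measuring = False
        return self._future.result()       # peak usage seen while measuring

    def _measure(self) -> int:
        max_usage = 0
        while self._keep_measuring:
            max_usage = max(max_usage, self._query_usage())
            sleep(self._interval)
        return max_usage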
colossalai/utils/memory_tracer/memstats_collector.py

 from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
-from .async_memtracer import get_cuda_memory_used
+from colossalai.utils.memory_utils.memory_monitor import colo_cuda_memory_used
 from colossalai.utils import get_current_device
 
 import torch

@@ -55,7 +55,7 @@ class MemStatsCollector:
         sampling_cnt = self._sampling_cnter.sampling_cnt
         assert sampling_cnt == len(self._overall_cuda)
         self._model_data_cuda.append(GLOBAL_MODEL_DATA_TRACER.cuda_usage)
-        self._overall_cuda.append(get_cuda_memory_used(torch.device(f'cuda:{get_current_device()}')))
+        self._overall_cuda.append(colo_cuda_memory_used(torch.device(f'cuda:{get_current_device()}')))
         self._sampling_cnter.advance()
 
     def fetch_memstats(self) -> (int, int):
colossalai/utils/memory_tracer/model_data_memtracer.py

@@ -44,6 +44,9 @@ class ModelDataTracer(metaclass=SingletonMeta):
         mem_use = _col_tensor_mem_usage(t)
         self._cuda_usage -= mem_use
 
+    def clear(self) -> None:
+        self._cuda_usage = 0
+
     @property
     def cpu_usage(self):
         return self._cpu_usage
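ModelDataTracer is a process-wide singleton, so usage recorded in one parameterized test run would otherwise bleed into the next. A short usage sketch of the new clear() hook; check_run is a hypothetical helper mirroring the real call site in tests/test_zero_data_parallel/test_init_context.py at the bottom of this commit.

# Hypothetical helper showing the intended reset between runs.
from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER


def check_run(build_model_on_cuda):
    build_model_on_cuda()                            # registers payloads via add_tensor
    assert GLOBAL_MODEL_DATA_TRACER.cuda_usage > 0   # usage from this run only
    GLOBAL_MODEL_DATA_TRACER.clear()                 # reset for the next parameterization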
colossalai/utils/memory_utils/memory_monitor.py

@@ -9,6 +9,28 @@ import torch
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
 from colossalai.logging import get_dist_logger
 from colossalai.utils.cuda import get_current_device
+from typing import Optional
+
+
+def colo_cuda_memory_used(device: Optional[torch.device] = None) -> int:
+    """
+    Get the free memory info of device.
+    :param device: a torch device instance or None
+    :type device: Optional[torch.device]
+    :return: current memory usage, sized by Byte
+    :rtype: int
+    """
+    if device:
+        assert device.type == 'cuda'
+    else:
+        device = torch.device(f'cuda:{get_current_device()}')
+
+    ret: int = torch.cuda.memory_allocated(device)
+    # get the peak memory to report correct data, so reset the counter for the next call
+    if hasattr(torch.cuda, "reset_peak_memory_stats"):    # pytorch 1.4+
+        torch.cuda.reset_peak_memory_stats(device)
+    return ret
+
 
 def bytes_to_GB(val, decimal=2):
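Note the documented unit changed from MB to Byte: colo_cuda_memory_used returns torch.cuda.memory_allocated for the chosen device, defaulting to the current one when no device is passed. A usage sketch (assumes a CUDA-capable machine with ColossalAI's device context initialized; 'cuda:0' is an arbitrary example):

# Usage sketch for the new helper.
import torch
from colossalai.utils.memory_utils.memory_monitor import colo_cuda_memory_used

x = torch.empty(1024, 1024, device='cuda')                 # allocate ~4 MB of float32
used_bytes = colo_cuda_memory_used()                       # current device
used_dev0 = colo_cuda_memory_used(torch.device('cuda:0'))  # explicit device
print(f"{used_bytes / 1e6:.2f} MB allocated")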
colossalai/utils/memory_utils/utils.py

@@ -3,7 +3,7 @@ from colossalai.utils import get_current_device
 from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor
 from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
 
-from typing import Union
+from typing import Union, Optional
 
 _GLOBAL_CUDA_MEM_FRACTION = 1.0
colossalai/zero/init_ctx/init_context.py

@@ -6,16 +6,14 @@ from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
 from colossalai.utils.memory_tracer.model_data_memtracer import \
     GLOBAL_MODEL_DATA_TRACER
+from colossalai.utils.memory_utils.memory_monitor import colo_cuda_memory_used
 from colossalai.zero.shard_utils import BaseShardStrategy
 from colossalai.zero.sharded_model._zero3_utils import cast_tensor_to_fp16
 from colossalai.zero.sharded_param import ShardedParamV2
 from torch.distributed import ProcessGroup
-from colossalai.logging import get_dist_logger, disable_existing_loggers
+from colossalai.logging import get_dist_logger
 
 # Inserts _post_init_method at the end of init method
 # for all sub classes of torch.nn.Module

@@ -144,8 +142,11 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
             del self.initialized_param_list
         GLOBAL_MODEL_DATA_TRACER.close()
-        cuda_mem_MB = GLOBAL_MODEL_DATA_TRACER.cuda_usage / 1e6
-        self.logger.info(f"Existing ZeRO Context Model Data CUDA Memory Usage {cuda_mem_MB} MB", [0])
+        model_data_cuda_mem_MB = GLOBAL_MODEL_DATA_TRACER.cuda_usage / 1e6
+        self.logger.info(f"Existing ZeRO Context: Model Data CUDA Memory {model_data_cuda_mem_MB} MB", ranks=[0])
+        sys_cuda_mem_MB = colo_cuda_memory_used() / 1e6
+        self.logger.info(f"System CUDA Memory Usage {sys_cuda_mem_MB} MB", ranks=[0])
+        self.logger.info(f"Model Number Parameter {self.model_numel_tensor.numpy()[0]/1e6} M", ranks=[0])
 
     def _post_init_method(self, module: torch.nn.Module):
         """

@@ -178,8 +179,8 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
             if self.shard_param:
                 self.shard_strategy.shard([param.col_attr.sharded_data_tensor], self.dp_process_group)
-                if param.col_attr.sharded_data_tensor.device.type == 'cuda':
-                    GLOBAL_MODEL_DATA_TRACER.add_tensor(param.col_attr.sharded_data_tensor.payload)
+            if param.col_attr.sharded_data_tensor.device.type == 'cuda':
+                GLOBAL_MODEL_DATA_TRACER.add_tensor(param.col_attr.sharded_data_tensor.payload)
             # if param.col_attr.grad and self.shard_grad:
             #     self.shard_strategy.shard([param.col_attr._grad_sharded_tensor], self.dp_process_group)
             #     GLOBAL_MODEL_DATA_TRACER.add_tensor(param.col_attr._grad_sharded_tensor.payload)
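Two fixes are visible in the exit hunk: the `[0]` rank filter is now passed as the `ranks=[0]` keyword (passed positionally it appears not to reach the logger's ranks parameter), and the context now logs system-wide CUDA usage and parameter count alongside tracked model data. An illustrative reading of those numbers, a reading aid rather than project code:

# Reading aid for the exit log, under the imports used by this file.
from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
from colossalai.utils.memory_utils.memory_monitor import colo_cuda_memory_used

model_data_MB = GLOBAL_MODEL_DATA_TRACER.cuda_usage / 1e6   # payloads registered via add_tensor
system_MB = colo_cuda_memory_used() / 1e6                   # everything the CUDA allocator holds
non_model_MB = system_MB - model_data_MB                    # buffers, activations, fragmentation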
colossalai/zero/shard_utils/tensor_shard_strategy.py

@@ -23,6 +23,9 @@ class TensorShardStrategy(BaseShardStrategy):
     def _shard_tensor(self, t: ShardedTensor, process_group: Optional[dist.ProcessGroup] = None):
         if t.is_sharded:
             return
+        if t.payload.device.type == 'cuda':
+            assert t.payload.device.index == get_current_device(), f"shard tensor on cuda device index {t.payload.device.index}," \
+                f" but current cuda device is {get_current_device()}"
         sharded_payload, _ = get_shard(t.payload, dist.get_rank(process_group), dist.get_world_size(process_group))
         t.reset_payload(sharded_payload)
         t.is_sharded = True
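The new guard asserts the payload already lives on the local CUDA device before sharding, which is the device mismatch the commit title refers to. For intuition, a hedged sketch of what a get_shard-style helper typically does; the names and the zero-padding scheme are assumptions, not the project's exact implementation:

# Sketch of rank-local sharding: flatten, pad to a multiple of world_size,
# and keep only this rank's slice. Illustrative, not colossalai's get_shard.
import torch


def shard_for_rank(payload: torch.Tensor, rank: int, world_size: int):
    flat = payload.flatten()
    pad = (-flat.numel()) % world_size        # elements needed for an even split
    if pad:
        flat = torch.cat([flat, flat.new_zeros(pad)])
    shard = flat.chunk(world_size)[rank].clone()
    return shard, pad                         # pad is needed to reassemble later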
tests/test_zero_data_parallel/test_init_context.py

@@ -19,17 +19,24 @@ from tests.components_to_test.registry import non_distributed_component_funcs
 from common import CONFIG
 
 
-@parameterize("init_device", [torch.device('cpu'), torch.device(f'cuda:{get_current_device()}')])
+@parameterize("init_device_type", ['cpu', 'cuda'])
 @parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy])
-def run_model_test(init_device, shard_strategy_class):
+def run_model_test(init_device_type, shard_strategy_class):
     for get_components_func in non_distributed_component_funcs:
         model_builder, _, _, _, _ = get_components_func()
         model_numel_tensor = torch.zeros(1, dtype=torch.int)
+        if init_device_type == 'cuda':
+            init_device = torch.device(f"cuda:{get_current_device()}")
+        elif init_device_type == 'cpu':
+            init_device = torch.device("cpu")
+        else:
+            continue
         with ZeroInitContext(convert_fp16=True,
                              target_device=init_device,
                              shard_strategy=shard_strategy_class(),
                              shard_param=True,
-                             model_numel_tensor=model_numel_tensor):
+                             model_numel_tensor=model_numel_tensor,
+                             rm_torch_payload_on_the_fly=False):
             model = model_builder(checkpoint=True)
         for param in model.parameters():

@@ -38,11 +45,9 @@ def run_model_test(init_device, shard_strategy_class):
             assert param.col_attr.sharded_data_tensor.is_sharded
             assert param.col_attr.sharded_data_tensor.payload.device.type == init_device.type, \
                 f'{param.col_attr.sharded_data_tensor.payload.device.type} vs. {init_device.type}'
 
-        print(f'cuda usgae {GLOBAL_MODEL_DATA_TRACER.cuda_usage}')
-        print(f'numel {model_numel_tensor}')
-        if init_device.type == 'cuda':
-            assert (GLOBAL_MODEL_DATA_TRACER.cuda_usage > 0)
+        if init_device.type == 'cuda':
+            assert (GLOBAL_MODEL_DATA_TRACER.cuda_usage > 0)
+        GLOBAL_MODEL_DATA_TRACER.clear()
 
 
 def run_dist(rank, world_size, port):
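The test now parameterizes over plain device-type strings and builds the torch.device inside the body, so the CUDA device is resolved inside the spawned worker rather than at decoration time, which appears to be the init-device bug the title refers to. Stacked @parameterize decorators execute the body over the cross product of values; a rough equivalent, calling the undecorated body (import locations assumed from the surrounding test):

# Rough expansion of the stacked decorators: one call per combination.
from itertools import product

from colossalai.zero.shard_utils import BucketTensorShardStrategy, TensorShardStrategy

for device_type, strategy_cls in product(['cpu', 'cuda'],
                                         [TensorShardStrategy, BucketTensorShardStrategy]):
    # `run_model_test` here means the undecorated function body.
    run_model_test(init_device_type=device_type, shard_strategy_class=strategy_cls)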