Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
c11ff81b
Unverified
Commit
c11ff81b
authored
Mar 29, 2022
by
Jiarui Fang
Committed by
GitHub
Mar 29, 2022
Browse files
[zero] get memory usage of sharded optim v2. (#542)
parent
a30e2b4c
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
81 additions
and
23 deletions
+81
-23
colossalai/utils/memory_tracer/model_data_memtracer.py
colossalai/utils/memory_tracer/model_data_memtracer.py
+1
-8
colossalai/utils/memory_utils/utils.py
colossalai/utils/memory_utils/utils.py
+21
-1
colossalai/zero/sharded_optim/sharded_optim_v2.py
colossalai/zero/sharded_optim/sharded_optim_v2.py
+48
-5
colossalai/zero/sharded_param/sharded_param.py
colossalai/zero/sharded_param/sharded_param.py
+4
-4
tests/test_zero_data_parallel/test_sharded_optim_v2.py
tests/test_zero_data_parallel/test_sharded_optim_v2.py
+7
-5
No files found.
colossalai/utils/memory_tracer/model_data_memtracer.py
View file @
c11ff81b
from
colossalai.context.singleton_meta
import
SingletonMeta
from
colossalai.context.singleton_meta
import
SingletonMeta
from
colossalai.zero.sharded_param.sharded_tensor
import
ShardedTensor
from
colossalai.zero.sharded_param.sharded_tensor
import
ShardedTensor
from
colossalai.utils.memory_utils.utils
import
colo_tensor_mem_usage
import
torch
import
torch
from
typing
import
Union
,
Tuple
,
Optional
from
typing
import
Union
,
Tuple
,
Optional
from
colossalai.logging
import
DistributedLogger
from
colossalai.logging
import
DistributedLogger
def _col_tensor_mem_usage(t: Union[torch.Tensor, ShardedTensor]) -> int:
    """Return the size in bytes of the elements held by ``t``.

    Args:
        t: Either a ``ShardedTensor``, in which case its ``payload`` is
            measured, or an object measured directly.

    Returns:
        int: ``numel() * element_size()`` of the measured object.
    """
    # Unwrap the sharded wrapper so the measured object exposes
    # ``numel`` and ``element_size``.
    data = t.payload if isinstance(t, ShardedTensor) else t
    elem_count = data.numel()
    bytes_per_elem = data.element_size()
    return elem_count * bytes_per_elem
def
col_model_data_mem_usage
(
model
:
torch
.
nn
.
Module
)
->
Tuple
[
int
,
int
]:
def
col_model_data_mem_usage
(
model
:
torch
.
nn
.
Module
)
->
Tuple
[
int
,
int
]:
"""
"""
Trace the model memory usage.
Trace the model memory usage.
...
...
colossalai/utils/memory_utils/utils.py
View file @
c11ff81b
from
psutil
import
cpu_count
import
torch
import
torch
from
colossalai.utils
import
get_current_device
from
colossalai.utils
import
get_current_device
from
colossalai.zero.sharded_param.sharded_tensor
import
ShardedTensor
from
colossalai.zero.sharded_param.sharded_tensor
import
ShardedTensor
from
typing
import
Union
from
typing
import
Tuple
,
Union
_GLOBAL_CUDA_MEM_FRACTION
=
1.0
_GLOBAL_CUDA_MEM_FRACTION
=
1.0
def colo_tensor_mem_usage(tensor: Union[torch.Tensor, ShardedTensor]) -> Tuple[int, int]:
    """Report how many bytes ``tensor`` occupies on CUDA and on CPU.

    Args:
        tensor: A ``ShardedTensor`` (its ``payload`` is measured) or a
            ``torch.Tensor``. Any other object contributes nothing.

    Returns:
        Tuple[int, int]: ``(cuda_bytes, cpu_bytes)``. At most one entry is
        non-zero; both are zero when the data lives on a device whose type is
        neither ``'cuda'`` nor ``'cpu'``, or when the argument is not a tensor.
    """
    if isinstance(tensor, ShardedTensor):
        data = tensor.payload
    elif isinstance(tensor, torch.Tensor):
        data = tensor
    else:
        # Not a tensor-like object: nothing to account for.
        return 0, 0

    nbytes = data.numel() * data.element_size()
    device_type = data.device.type
    if device_type == 'cuda':
        return nbytes, 0
    if device_type == 'cpu':
        return 0, nbytes
    # Other device types are deliberately not counted.
    return 0, 0
def
colo_set_process_memory_fraction
(
ratio
:
float
)
->
None
:
def
colo_set_process_memory_fraction
(
ratio
:
float
)
->
None
:
"""colo_set_process_memory_fraction
"""colo_set_process_memory_fraction
...
...
colossalai/zero/sharded_optim/sharded_optim_v2.py
View file @
c11ff81b
from
enum
import
Enum
from
enum
import
Enum
from
typing
import
Dict
,
Optional
from
os
import
stat
from
typing
import
Dict
,
Optional
,
Tuple
import
torch
import
torch
import
torch.distributed
as
dist
import
torch.distributed
as
dist
...
@@ -16,7 +17,7 @@ from torch.distributed import ProcessGroup
...
@@ -16,7 +17,7 @@ from torch.distributed import ProcessGroup
from
torch.nn.parameter
import
Parameter
from
torch.nn.parameter
import
Parameter
from
torch.optim
import
Optimizer
from
torch.optim
import
Optimizer
from
colossalai.zero.sharded_optim._utils
import
has_inf_or_nan
from
colossalai.zero.sharded_optim._utils
import
has_inf_or_nan
from
colossalai.utils.memory_utils.utils
import
colo_model_data_tensor_move
from
colossalai.utils.memory_utils.utils
import
colo_model_data_tensor_move
,
colo_tensor_mem_usage
class
OptimState
(
Enum
):
class
OptimState
(
Enum
):
...
@@ -26,14 +27,20 @@ class OptimState(Enum):
...
@@ -26,14 +27,20 @@ class OptimState(Enum):
class
ShardedOptimizerV2
(
ColossalaiOptimizer
):
class
ShardedOptimizerV2
(
ColossalaiOptimizer
):
"""A wrapper for optimizer. `ShardedOptimizerV2` and `ShardedModelV2` implement Zero Redundancy Optimizer (ZeRO).
"""A wrapper for optimizer. `ShardedOptimizerV2` and `ShardedModelV2` implement Zero Redundancy Optimizer (ZeRO).
By default, ZeRO optimizer stage 3 offloads optimizer states (OS) to the CPU.
By default, ZeRO optimizer stage 3 offloads optimizer states (OS) to the CPU.
We apply the Device-aware Operator Placement technique for OS placement from the following paper.
We apply the Device-aware Operator Placement technique for OS placement from the following paper.
PatrickStar: Parallel Training of Pre-trained Models via Chunk-based Memory Management
PatrickStar: Parallel Training of Pre-trained Models via Chunk-based Memory Management
https://arxiv.org/abs/2108.05818
https://arxiv.org/abs/2108.05818
GPU margin space is the remaining space after removing peak non-model data from the overall GPU memory,
GPU margin space is the remaining space after removing peak non-model data from the overall GPU memory,
which is detected by a runtime memory tracer.
which is detected by a runtime memory tracer.
We place as many OS chunks in the margin space as possible.
We place as many OS chunks in the margin space as possible.
The size of margin space can be controlled by `gpu_margin_mem_ratio`
The size of margin space can be controlled by `gpu_margin_mem_ratio`.
If it is set as 0.0, it is the same as classical ZeRO optimizer.
If it is set as 0.0, it is the same as classical ZeRO optimizer.
NOTE() You must use `ShardedOptimizerV2` with `ShardedModelV2`.
NOTE() You must use `ShardedOptimizerV2` with `ShardedModelV2`.
...
@@ -99,7 +106,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
...
@@ -99,7 +106,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
hysteresis
=
hysteresis
,
hysteresis
=
hysteresis
,
max_scale
=
max_scale
)
max_scale
=
max_scale
)
self
.
_found_overflow
:
Tensor
=
torch
.
FloatTensor
([
0
]).
to
(
torch
.
cuda
.
current_device
())
self
.
_found_overflow
:
Tensor
=
torch
.
FloatTensor
([
0
]).
to
(
torch
.
cuda
.
current_device
())
self
.
_logger
=
get_dist_logger
()
self
.
_logger
=
get_dist_logger
(
"ShardedOptimizerV2"
)
# Store fp32 param shards
# Store fp32 param shards
self
.
master_params
:
Dict
[
Parameter
,
Tensor
]
=
{}
self
.
master_params
:
Dict
[
Parameter
,
Tensor
]
=
{}
...
@@ -119,6 +126,37 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
...
@@ -119,6 +126,37 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
# So we gather here
# So we gather here
self
.
shard_strategy
.
gather
([
p
.
col_attr
.
sharded_data_tensor
],
self
.
dp_process_group
)
self
.
shard_strategy
.
gather
([
p
.
col_attr
.
sharded_data_tensor
],
self
.
dp_process_group
)
self
.
_logger
.
debug
(
f
"After init ShardedOptimizerV2 consumes
{
self
.
get_memory_usage
()[
0
]
/
1e6
}
MB CUDA Memory!"
,
ranks
=
[
0
])
def get_memory_usage(self) -> Tuple[int, int]:
    """Get the memory usage of the optimizer.

    The total covers every value in ``self.master_params`` (fp32 parameter
    shards) plus every value stored in the wrapped optimizer's per-parameter
    state, e.g. momentum (``state[p]['exp_avg']``) and variance
    (``state[p]['exp_avg_sq']``). Each value is measured with
    ``colo_tensor_mem_usage``.

    This method only reads state: it never creates entries in
    ``self.optim.state``.

    Returns:
        Tuple[int, int]: cuda/cpu memory usage in Byte.
    """
    cuda_use = 0
    cpu_use = 0

    def update_mem_use(t):
        # Fold one value's footprint into the running totals.
        nonlocal cuda_use, cpu_use
        t_cuda_use, t_cpu_use = colo_tensor_mem_usage(t)
        cuda_use += t_cuda_use
        cpu_use += t_cpu_use

    for p_fp32 in self.master_params.values():
        update_mem_use(p_fp32)

    optim_state = self.optim.state
    for group in self.optim.param_groups:
        for p in group['params']:
            # NOTE(review): presumably self.optim is a torch.optim.Optimizer,
            # whose `state` is a defaultdict(dict). Indexing it with `[p]`
            # would insert an empty entry for every parameter that has no
            # state yet, so use `.get` to keep this getter side-effect free.
            for v in optim_state.get(p, {}).values():
                update_mem_use(v)

    return cuda_use, cpu_use
def
step
(
self
,
*
args
,
**
kwargs
):
def
step
(
self
,
*
args
,
**
kwargs
):
self
.
_maybe_move_fp32_shards
()
self
.
_maybe_move_fp32_shards
()
...
@@ -130,7 +168,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
...
@@ -130,7 +168,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
self
.
grad_scaler
.
update
(
found_inf
)
self
.
grad_scaler
.
update
(
found_inf
)
if
found_inf
:
if
found_inf
:
self
.
_logger
.
info
(
'found inf during ShardedOptimV2 step'
)
self
.
_logger
.
warning
(
'found inf during ShardedOptimV2 step'
)
self
.
zero_grad
()
self
.
zero_grad
()
return
return
...
@@ -142,8 +180,13 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
...
@@ -142,8 +180,13 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
# Now p.data is sharded
# Now p.data is sharded
# So optimizer states are sharded naturally
# So optimizer states are sharded naturally
self
.
_logger
.
debug
(
f
"Before step ShardedOptimizerV2 consumes
{
self
.
get_memory_usage
()[
0
]
/
1e6
}
MB CUDA Memory!"
,
ranks
=
[
0
])
ret
=
self
.
optim
.
step
(
*
args
,
**
kwargs
)
ret
=
self
.
optim
.
step
(
*
args
,
**
kwargs
)
self
.
_logger
.
debug
(
f
"After step ShardedOptimizerV2 consumes
{
self
.
get_memory_usage
()[
0
]
/
1e6
}
MB CUDA Memory!"
,
ranks
=
[
0
])
# Copy master param data (fp32) to payload of col_attr (fp16)
# Copy master param data (fp32) to payload of col_attr (fp16)
# TODO() improve efficiency by gathering tensors into a chunk and transferring
# TODO() improve efficiency by gathering tensors into a chunk and transferring
# a chunk.
# a chunk.
...
...
colossalai/zero/sharded_param/sharded_param.py
View file @
c11ff81b
...
@@ -2,6 +2,7 @@ import torch
...
@@ -2,6 +2,7 @@ import torch
import
torch.distributed
as
dist
import
torch.distributed
as
dist
from
colossalai.zero.sharded_param
import
ShardedTensor
from
colossalai.zero.sharded_param
import
ShardedTensor
from
typing
import
Optional
,
Tuple
from
typing
import
Optional
,
Tuple
from
colossalai.utils.memory_utils.utils
import
colo_tensor_mem_usage
class
ShardedParamV2
(
object
):
class
ShardedParamV2
(
object
):
...
@@ -55,10 +56,9 @@ class ShardedParamV2(object):
...
@@ -55,10 +56,9 @@ class ShardedParamV2(object):
assert
isinstance
(
t
,
torch
.
Tensor
)
assert
isinstance
(
t
,
torch
.
Tensor
)
nonlocal
cuda_mem_use
nonlocal
cuda_mem_use
nonlocal
cpu_mem_use
nonlocal
cpu_mem_use
if
t
.
device
.
type
==
'cpu'
:
t_cuda
,
t_cpu
=
colo_tensor_mem_usage
(
t
)
cpu_mem_use
+=
t
.
numel
()
*
t
.
element_size
()
cuda_mem_use
+=
t_cuda
elif
t
.
device
.
type
==
'cuda'
:
cpu_mem_use
+=
t_cpu
cuda_mem_use
+=
t
.
numel
()
*
t
.
element_size
()
address_set
=
set
()
address_set
=
set
()
_update_mem_use
(
self
.
sharded_data_tensor
.
payload
)
_update_mem_use
(
self
.
sharded_data_tensor
.
payload
)
...
...
tests/test_zero_data_parallel/test_sharded_optim_v2.py
View file @
c11ff81b
from
functools
import
partial
from
functools
import
partial
import
colossalai
import
colossalai
from
colossalai.utils.cuda
import
get_current_device
import
pytest
import
pytest
import
torch
import
torch
import
torch.distributed
as
dist
import
torch.distributed
as
dist
...
@@ -57,8 +58,9 @@ def _run_test_sharded_optim_v2(cpu_offload, shard_strategy_class, use_cpuadam, g
...
@@ -57,8 +58,9 @@ def _run_test_sharded_optim_v2(cpu_offload, shard_strategy_class, use_cpuadam, g
get_components_func
=
non_distributed_component_funcs
.
get_callable
(
model_name
)
get_components_func
=
non_distributed_component_funcs
.
get_callable
(
model_name
)
model_builder
,
train_dataloader
,
_
,
optimizer_class
,
criterion
=
get_components_func
()
model_builder
,
train_dataloader
,
_
,
optimizer_class
,
criterion
=
get_components_func
()
with
ZeroInitContext
(
convert_fp16
=
True
,
with
ZeroInitContext
(
target_device
=
torch
.
device
(
f
'cpu:0'
),
convert_fp16
=
True
,
target_device
=
torch
.
device
(
f
'cpu:0'
)
if
cpu_offload
else
torch
.
device
(
f
'cuda:
{
get_current_device
()
}
'
),
shard_strategy
=
shard_strategy
,
shard_strategy
=
shard_strategy
,
shard_param
=
True
,
shard_param
=
True
,
rm_torch_payload_on_the_fly
=
False
):
rm_torch_payload_on_the_fly
=
False
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment