OpenDAS / ColossalAI / Commits / 7c6c427d

Commit 7c6c427d (unverified) · authored Mar 31, 2022 by ver217 · committed by GitHub, Mar 31, 2022
[zero] trace states of fp16/32 grad and fp32 param (#571)
Parent: 7675366f

Changes: 5 changed files with 69 additions and 73 deletions (+69 −73)
    colossalai/utils/memory_utils/utils.py              +3   −3
    colossalai/zero/sharded_model/sharded_model_v2.py   +14  −23
    colossalai/zero/sharded_optim/sharded_optim_v2.py   +52  −36
    colossalai/zero/sharded_param/sharded_param.py      +0   −5
    tests/test_zero_data_parallel/test_shard_param.py   +0   −6
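The common thread of this commit is that fp16/fp32 gradients and fp32 master parameters become `StatefulTensor` wrappers whose `TensorState` is transitioned explicitly (`COMPUTE` while in use, `HOLD` afterwards, `FREE` when null). The following is a minimal sketch of such a state holder, inferred purely from the calls that appear in the diffs below (`payload`, `reset_payload`, `trans_state`, `is_null`, `set_null`); the actual `colossalai.zero.sharded_param.tensorful_state` implementation may differ.

```python
# Hypothetical sketch of a stateful tensor wrapper, inferred from the
# methods this commit calls; not the actual ColossalAI implementation.
from enum import Enum
from typing import Optional

import torch


class TensorState(Enum):
    FREE = 0      # no payload attached
    HOLD = 1      # payload kept, not in active use
    COMPUTE = 2   # payload participating in forward/backward/step


class StatefulTensor:

    def __init__(self, tensor: Optional[torch.Tensor], state: TensorState = TensorState.HOLD) -> None:
        self._payload = tensor
        self._state = TensorState.FREE if tensor is None else state

    @property
    def payload(self) -> torch.Tensor:
        assert self._payload is not None, 'payload accessed while FREE'
        return self._payload

    def is_null(self) -> bool:
        return self._state is TensorState.FREE

    def set_null(self) -> None:
        # Drop the payload so the memory tracer no longer counts it
        self._payload = None
        self._state = TensorState.FREE

    def reset_payload(self, tensor: torch.Tensor) -> None:
        # Replace the payload in place; callers transition state as needed
        self._payload = tensor
        self._state = TensorState.HOLD

    def trans_state(self, state: TensorState) -> None:
        self._state = state
```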
colossalai/utils/memory_utils/utils.py

```diff
@@ -93,7 +93,7 @@ def colo_model_data_tensor_move_inline(t: Union[StatefulTensor, torch.Tensor], t
         raise TypeError(f'colo_model_data_move_to_cpu does not accept type {type(t)}')
     if isinstance(target_device, int):
-        target_device = torch.cuda(f'device:{target_device}')
+        target_device = torch.device(f'cuda:{target_device}')
     # deal with torch.device('cpu') and torch.device('cpu:0')
     if t_payload.device.type == target_device.type:
```
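The fixed line normalizes an integer device index the way PyTorch expects; `torch.cuda` is a module and is not callable, which is what the removed line tripped over. A quick illustration in plain PyTorch, independent of ColossalAI:

```python
import torch

target_device = 0
# Correct: build a torch.device from an integer index
target_device = torch.device(f'cuda:{target_device}')
assert target_device.type == 'cuda' and target_device.index == 0

# Comparing .type treats torch.device('cpu') and torch.device('cpu:0') alike,
# which is what the subsequent `t_payload.device.type == target_device.type`
# check relies on.
assert torch.device('cpu').type == torch.device('cpu:0').type
```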
colossalai/zero/sharded_model/sharded_model_v2.py

```diff
@@ -18,7 +18,7 @@ from colossalai.utils.memory_tracer.model_data_memtracer import \
 from colossalai.utils.memory_utils.utils import (colo_cuda_memory_capacity, colo_model_data_move_to_cpu)
 from colossalai.zero.shard_utils import BaseShardStrategy
 from colossalai.zero.sharded_model.reduce_scatter import ReduceScatterBucketer
-from colossalai.zero.sharded_param.tensorful_state import (StatefulTensor, TensorState)
+from colossalai.zero.sharded_param.tensorful_state import TensorState
 from torch.distributed import ProcessGroup
 from torch.nn.parameter import Parameter
```

```diff
@@ -245,27 +245,7 @@ class ShardedModelV2(nn.Module):
             # We also allow interleaving no-sync passes with sync passes, if desired.
             if not self._require_backward_grad_sync:
                 continue
             # Reduced grad is saved in `p.colo_attr.saved_grad`
             # It can be on CPU or CUDA
             # It can be fp16 or fp32
             # We set `p.grad` to None here and ShardedOptimizer will prepare `p.grad` before `step()`.
-            if self.reuse_fp16_shard:
-                grad_fp16_payload = p.colo_attr.sharded_data_tensor.payload
-            else:
-                grad_fp16_payload = cast_tensor_to_fp32(p.colo_attr.fp16_grad.payload)
-            assert isinstance(grad_fp16_payload, torch.Tensor)
-            if p.colo_attr.offload_grad:
-                colo_model_data_move_to_cpu(grad_fp16_payload)
-            if not p.colo_attr.saved_grad.is_null():
-                assert not self.reuse_fp16_shard, 'Gradient accumulation is not supported when reuse_fp16_shard=True'
-                # Accumulate grad, saved grad must be fp32
-                p.colo_attr.saved_grad.reset_payload(cast_tensor_to_fp32(p.colo_attr.saved_grad.payload))
-                p.colo_attr.saved_grad.payload.add_(grad_fp16_payload.view_as(p.colo_attr.saved_grad.payload))
-            else:
-                p.colo_attr.saved_grad.reset_payload(grad_fp16_payload)
             p.grad = None
-            p.colo_attr.fp16_grad.set_null()

     @torch.no_grad()
     def _grad_post_backward_hook(self, param: Parameter, grad: torch.Tensor) -> Optional[torch.Tensor]:
```

```diff
@@ -322,11 +302,22 @@ class ShardedModelV2(nn.Module):
         if self.gradient_postdivide_factor > 1:
             # Average grad by world_size for consistency with PyTorch DDP.
             reduced_grad.data.div_(self.gradient_postdivide_factor)
         # FIXME(ver217): remove the line below when the eviction policy is implemented
         if param.colo_attr.offload_grad:
             colo_model_data_move_to_cpu(reduced_grad)
         if self.reuse_fp16_shard:
-            param.colo_attr.sharded_data_tensor.reset_payload(reduced_grad.data)
+            assert param.colo_attr.saved_grad.is_null(), 'Gradient accumulation is not supported when reuse_fp16_shard=True'
+            param.colo_attr.sharded_data_tensor.reset_payload(reduced_grad)
             param.colo_attr.sharded_data_tensor.is_sharded = True
+            param.colo_attr.saved_grad.reset_payload(param.colo_attr.sharded_data_tensor.payload)
         else:
-            param.colo_attr.fp16_grad = StatefulTensor(reduced_grad.data)
+            reduced_grad = cast_tensor_to_fp32(reduced_grad)
+            if param.colo_attr.saved_grad.is_null():
+                param.colo_attr.saved_grad.reset_payload(reduced_grad)
+            else:
+                param.colo_attr.saved_grad.payload.add_(reduced_grad.view_as(param.colo_attr.saved_grad.payload))
+        param.colo_attr.saved_grad.trans_state(TensorState.HOLD)

     def state_dict(self, destination=None, prefix='', keep_vars=False) -> 'OrderedDict[str, torch.Tensor]':
         self.shard_strategy.gather([p.colo_attr.sharded_data_tensor for p in self.module.parameters()],
```
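After this change, the reduce-scatter callback owns the whole gradient bookkeeping: with `reuse_fp16_shard` the reduced gradient replaces the parameter shard's payload and `saved_grad` aliases it; otherwise the gradient is cast to fp32 and either stored or accumulated into `saved_grad`, which then transitions to `HOLD`. Below is a sketch of that accumulate-or-reset core, using the hypothetical `StatefulTensor` sketch above; the real hook additionally handles offloading and gradient post-division.

```python
import torch


def save_reduced_grad(saved_grad: StatefulTensor, reduced_grad: torch.Tensor) -> None:
    """Sketch of the non-reuse branch: keep gradients in fp32 and accumulate."""
    reduced_grad = reduced_grad.float()          # stand-in for cast_tensor_to_fp32
    if saved_grad.is_null():
        saved_grad.reset_payload(reduced_grad)   # first micro-batch: just store
    else:
        # later micro-batches: accumulate in place; view_as reconciles shapes
        saved_grad.payload.add_(reduced_grad.view_as(saved_grad.payload))
    saved_grad.trans_state(TensorState.HOLD)
```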
colossalai/zero/sharded_optim/sharded_optim_v2.py

```diff
@@ -12,11 +12,12 @@ from colossalai.logging import get_dist_logger
 from colossalai.nn.optimizer import ColossalaiOptimizer
 from colossalai.utils.memory_tracer.model_data_memtracer import \
     GLOBAL_MODEL_DATA_TRACER
-from colossalai.utils.memory_utils.utils import (colo_model_data_tensor_move, colo_model_tensor_clone,
+from colossalai.utils.memory_utils.utils import (colo_model_data_tensor_move_inline, colo_model_tensor_clone,
                                                  colo_tensor_mem_usage)
 from colossalai.zero.sharded_model import ShardedModelV2
 from colossalai.zero.sharded_model._utils import cast_tensor_to_fp32
 from colossalai.zero.sharded_optim._utils import has_inf_or_nan
+from colossalai.zero.sharded_param.tensorful_state import (StatefulTensor, TensorState)
 from torch import Tensor
 from torch.distributed import ProcessGroup
 from torch.nn.parameter import Parameter
```
```diff
@@ -112,7 +113,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
         self._logger = get_dist_logger("ShardedOptimizerV2")

         # Store fp32 param shards
-        self.master_params: Dict[Parameter, Tensor] = {}
+        self.master_params: Dict[Parameter, StatefulTensor] = {}

         for group in self.optim.param_groups:
             for p in group['params']:
```
```diff
@@ -123,7 +124,8 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
                     # Param is not sharded, which means we use ZeRO-2 here
                     # As we only store the param shard, we shard it here
                     self.shard_strategy.shard([p.colo_attr.sharded_data_tensor], self.dp_process_group)
-                self.master_params[p] = cast_tensor_to_fp32(p.colo_attr.sharded_data_tensor.payload).to(self.device)
+                self.master_params[p] = StatefulTensor(
+                    cast_tensor_to_fp32(p.colo_attr.sharded_data_tensor.payload).to(self.device))
                 if not is_param_sharded:
                     # In this branch, there's no need to shard the param
                     # So we gather here
```
```diff
@@ -184,13 +186,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
             self.zero_grad()
             return

-        # assign master param pointers to p.data.
-        # We will not trigger data copy here.
-        for group in self.optim.param_groups:
-            for p in group['params']:
-                p.data = self.master_params[p]
-                # Now p.data is sharded
-                # So optimizer states are sharded naturally
+        self._prepare_data()

         self._logger.debug(
             f"Before step ShardedOptimizerV2 consumes {self.get_memory_usage()[0] / 1e6} MB CUDA Memory, {self.get_memory_usage()[1] / 1e6} MB CPU Memory!",
```
```diff
@@ -201,30 +197,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
         self._logger.debug(
             f"After step ShardedOptimizerV2 consumes {self.get_memory_usage()[0] / 1e6} MB CUDA Memory, {self.get_memory_usage()[1] / 1e6} MB CPU Memory!",
             ranks=[0])

-        # Copy master param data (fp32) to the payload of colo_attr (fp16)
-        # TODO() improve efficiency by gathering tensors into a chunk and transferring a chunk.
-        for group in self.optim.param_groups:
-            for p in group['params']:
-                is_param_sharded = p.colo_attr.sharded_data_tensor.is_sharded
-                if not is_param_sharded:
-                    # We use ZeRO-2 here
-                    # The `p.colo_attr.sharded_data_tensor` saves the full fp16 param
-                    # But we only have the updated fp32 param shard here
-                    # So we first shard the full fp16 param and copy the fp32 param shard to it
-                    # Then we will gather them
-                    self.shard_strategy.shard([p.colo_attr.sharded_data_tensor], self.dp_process_group)
-                # We have to use `copy_payload` instead of `reset_payload`
-                # Since p.data is fp32 and p.colo_attr.sharded_data_tensor is fp16
-                # TODO() optimize this line: CPU (fp32) -> GPU (fp16)
-                p.colo_attr.sharded_data_tensor.reset_payload(
-                    colo_model_tensor_clone(p.half(), torch.cuda.current_device()))
-                if not is_param_sharded:
-                    # We gather the full fp16 param here
-                    self.shard_strategy.gather([p.colo_attr.sharded_data_tensor], self.dp_process_group)
-                p.data = p.colo_attr.sharded_data_tensor.payload
+        self._write_back_data()
         return ret

     def backward(self, loss: Tensor) -> None:
```
```diff
@@ -276,6 +249,9 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
         # Because we will judge whether local grad accumulation
         # is enabled by whether grad is None
         self.optim.zero_grad(set_to_none=True)
+        for group in self.optim.param_groups:
+            for p in group['params']:
+                p.colo_attr.saved_grad.set_null()

     def sync_grad(self):
         pass
```
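`set_to_none=True` matters here because downstream code tests `p.grad is None` to decide whether local gradient accumulation is pending; zeroing in place would leave a stale tensor that looks like a pending gradient. Nulling `saved_grad` additionally releases the fp32 copy. A self-contained illustration in plain PyTorch:

```python
import torch

# Why set_to_none=True: `p.grad is None` must be an unambiguous "no gradient" signal.
p = torch.nn.Parameter(torch.randn(2, 3))
optim = torch.optim.SGD([p], lr=0.1)

p.sum().backward()
optim.zero_grad(set_to_none=False)
assert p.grad is not None      # stale zero tensor still attached

p.sum().backward()
optim.zero_grad(set_to_none=True)
assert p.grad is None          # reads as "no local gradient pending"
```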
```diff
@@ -288,9 +264,9 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
         fp32_shards_used_cuda_margin_mem = 0
         for group in self.optim.param_groups:
             for p in group['params']:
-                shard_mem = self.master_params[p].numel() * self.master_params[p].element_size()
+                shard_mem = self.master_params[p].payload.numel() * self.master_params[p].payload.element_size()
                 if fp32_shards_used_cuda_margin_mem + shard_mem < fp32_shards_available_cuda_margin_mem:
-                    self.master_params[p] = self.master_params[p].to(torch.cuda.current_device())
+                    colo_model_data_tensor_move_inline(self.master_params[p], torch.cuda.current_device())
                     p.grad.data = p.grad.data.to(torch.cuda.current_device())
                     p.colo_attr.offload_grad = False
                     fp32_shards_used_cuda_margin_mem += shard_mem
```
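The new lines read sizes off the wrapped payload and relocate it with the tracer-aware move helper rather than a raw `.to()`, so the transfer stays visible to the memory tracer. A sketch of the greedy fill-until-margin policy, with the hypothetical names `fill_cuda_margin` and `available_margin_bytes` standing in for the real method and `fp32_shards_available_cuda_margin_mem`:

```python
import torch


def fill_cuda_margin(master_params, params, available_margin_bytes):
    """Greedy sketch: promote fp32 master shards to CUDA until a byte budget is spent.

    `master_params` maps param -> StatefulTensor-like wrapper (hypothetical here);
    the real code moves payloads with colo_model_data_tensor_move_inline so the
    memory tracer observes the transfer.
    """
    used = 0
    for p in params:
        payload = master_params[p].payload
        shard_mem = payload.numel() * payload.element_size()   # shard size in bytes
        if used + shard_mem < available_margin_bytes:
            master_params[p].reset_payload(payload.to(torch.cuda.current_device()))
            p.colo_attr.offload_grad = False   # the grad may stay on GPU now
            used += shard_mem
```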
```diff
@@ -298,10 +274,50 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
     def _prepare_grads(self):
         for group in self.optim.param_groups:
             for p in group['params']:
+                p.colo_attr.saved_grad.trans_state(TensorState.COMPUTE)
                 # FIXME(ver217): p.data here is an empty tensor on CUDA and has no useful information
                 # If we change p.grad directly
                 # it may raise an error because of the different shape/dtype/device of p.data and p.grad
                 # We just set p.data = p.colo_attr.saved_grad.payload here
                 p.data = p.colo_attr.saved_grad.payload
                 p.grad = p.colo_attr.saved_grad.payload
                 # Set p.data to an empty tensor, in case of memory leaking
                 p.colo_attr.remove_torch_payload()

+    def _prepare_data(self):
+        # assign master param pointers to p.data.
+        # We will not trigger data copy here.
+        for group in self.optim.param_groups:
+            for p in group['params']:
+                self.master_params[p].trans_state(TensorState.COMPUTE)
+                p.data = self.master_params[p].payload
+                # Now p.data is sharded
+                # So optimizer states are sharded naturally
+
+    def _write_back_data(self):
+        # Copy master param data (fp32) to the payload of colo_attr (fp16)
+        # TODO() improve efficiency by gathering tensors into a chunk and transferring a chunk.
+        for group in self.optim.param_groups:
+            for p in group['params']:
+                is_param_sharded = p.colo_attr.sharded_data_tensor.is_sharded
+                if not is_param_sharded:
+                    # We use ZeRO-2 here
+                    # The `p.colo_attr.sharded_data_tensor` saves the full fp16 param
+                    # But we only have the updated fp32 param shard here
+                    # So we first shard the full fp16 param and copy the fp32 param shard to it
+                    # Then we will gather them
+                    self.shard_strategy.shard([p.colo_attr.sharded_data_tensor], self.dp_process_group)
+                # We have to use `copy_payload` instead of `reset_payload`
+                # Since p.data is fp32 and p.colo_attr.sharded_data_tensor is fp16
+                # TODO() optimize this line: CPU (fp32) -> GPU (fp16)
+                p.colo_attr.sharded_data_tensor.reset_payload(
+                    colo_model_tensor_clone(p.half(), torch.cuda.current_device()))
+                if not is_param_sharded:
+                    # We gather the full fp16 param here
+                    self.shard_strategy.gather([p.colo_attr.sharded_data_tensor], self.dp_process_group)
+                p.data = p.colo_attr.sharded_data_tensor.payload
+                self.master_params[p].trans_state(TensorState.HOLD)
+                p.colo_attr.saved_grad.set_null()
```
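Taken together, `step()` now reads as four phases: stage the fp32 gradients, point `p.data` at the fp32 master shards, run the wrapped optimizer, then write the updated shards back as fp16 payloads. A condensed sketch of that control flow (the wrapper function `sharded_step` is hypothetical; the method names come from this diff, and error handling and logging are omitted):

```python
def sharded_step(opt, *args, **kwargs):
    """Condensed sketch of ShardedOptimizerV2.step() after this commit."""
    opt._prepare_grads()      # p.grad <- saved_grad.payload (fp32), state -> COMPUTE
    opt._prepare_data()       # p.data <- master_params[p].payload, state -> COMPUTE;
                              # optimizer states are created per shard, so they shard naturally
    ret = opt.optim.step(*args, **kwargs)   # ordinary fp32 update on the shards
    opt._write_back_data()    # fp16 shard <- half(fp32 shard); master -> HOLD, saved_grad -> FREE
    return ret
```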
colossalai/zero/sharded_param/sharded_param.py

```diff
@@ -10,7 +10,6 @@ class ShardedParamV2(object):
     def __init__(self, param: torch.nn.Parameter, rm_torch_payload=False) -> None:
         self._sharded_data_tensor: ShardedTensor = ShardedTensor(param.data)
-        self.fp16_grad: StatefulTensor = StatefulTensor(None, TensorState.FREE)
         self.saved_grad: StatefulTensor = StatefulTensor(None, TensorState.FREE)
         # This attribute must be initialized in ShardedModel
         self.offload_grad: bool = False
```
```diff
@@ -57,10 +56,6 @@ class ShardedParamV2(object):
             _update_mem_use(self.sharded_data_tensor.payload)
             address_set.add(self.sharded_data_tensor.payload.data_ptr())

-        if not self.fp16_grad.is_null() and self.fp16_grad.data_ptr() not in address_set:
-            _update_mem_use(self.fp16_grad.payload)
-            address_set.add(self.fp16_grad.data_ptr())
-
         if not self.saved_grad.is_null() and self.saved_grad.data_ptr() not in address_set:
             _update_mem_use(self.saved_grad.payload)
             address_set.add(self.saved_grad.data_ptr())
```
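With `fp16_grad` gone, `get_memory_usage` only has to consider the data tensor and `saved_grad`, still deduplicating by `data_ptr()` because `reuse_fp16_shard` can make `saved_grad` alias the parameter payload. A sketch of that dedup pattern (the helper `tensor_mem_bytes` is hypothetical):

```python
import torch


def tensor_mem_bytes(tensors) -> int:
    """Sum tensor storage sizes, counting each underlying buffer once (sketch)."""
    seen, total = set(), 0
    for t in tensors:
        if t is None or t.data_ptr() in seen:
            continue   # aliased payloads (e.g. reuse_fp16_shard) count once
        total += t.numel() * t.element_size()
        seen.add(t.data_ptr())
    return total


# Aliased tensors share a data_ptr, so they are not double counted:
a = torch.randn(2, 3)
assert tensor_mem_bytes([a, a.view(6)]) == a.numel() * a.element_size()
```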
tests/test_zero_data_parallel/test_shard_param.py

```diff
@@ -63,12 +63,6 @@ def _run_shard_param_v2(rank, world_size, port):
     # 4 is size of dummy tensor of param.data
     assert cpu_mem_use == 2 * 3 * 4 * 2 + 4

-    sparam.fp16_grad = StatefulTensor(torch.randn(2, 3).cuda().half())
-    cuda_mem_use, cpu_mem_use = sparam.get_memory_usage()
-    assert cpu_mem_use == 2 * 3 * 4 * 2 + 4
-    assert cuda_mem_use == 2 * 3 * 2
-
-    sparam.fp16_grad = StatefulTensor(None)
     sparam.saved_grad = StatefulTensor(torch.randn(2, 3))
     sparam.remove_torch_payload()

     cuda_mem_use, cpu_mem_use = sparam.get_memory_usage()
```
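The surviving assertion is straightforward byte counting: a 2×3 fp32 tensor occupies 2·3·4 = 24 bytes, there are two of them on the CPU (the data tensor and `saved_grad`), plus the 4-byte one-element dummy left behind by `remove_torch_payload`, giving 52 bytes:

```python
# 2 * 3 * 4 * 2 + 4:
#   2 * 3 elements, 4 bytes each (fp32), times 2 tensors (data + saved_grad),
#   plus a 4-byte one-element dummy replacing the removed torch payload
expected_cpu_bytes = 2 * 3 * 4 * 2 + 4
assert expected_cpu_bytes == 52
```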