Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
05e33b25
Unverified
Commit
05e33b25
authored
Mar 25, 2022
by
Jiarui Fang
Committed by
GitHub
Mar 25, 2022
Browse files
[zero] fix grad offload (#528)
* [zero] fix grad offload * polish code
parent
105c5301
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
28 additions
and
7 deletions
+28
-7
colossalai/utils/memory_utils/utils.py
colossalai/utils/memory_utils/utils.py
+17
-0
colossalai/zero/sharded_model/sharded_model_v2.py
colossalai/zero/sharded_model/sharded_model_v2.py
+7
-7
colossalai/zero/sharded_optim/sharded_optim_v2.py
colossalai/zero/sharded_optim/sharded_optim_v2.py
+4
-0
No files found.
colossalai/utils/memory_utils/utils.py
View file @
05e33b25
...
...
@@ -114,3 +114,20 @@ def colo_model_data_move_to_cpu(t: Union[ShardedTensor, torch.Tensor]) -> None:
GLOBAL_MODEL_DATA_TRACER
.
delete_tensor
(
t_payload
)
t_payload
.
data
=
t_payload
.
data
.
cpu
()
GLOBAL_MODEL_DATA_TRACER
.
add_tensor
(
t_payload
)
def colo_model_tensor_clone(t: Union[ShardedTensor, torch.Tensor], target_device: torch.device) -> torch.Tensor:
    """Clone a model data tensor onto ``target_device`` and register the
    clone with the global model-data memory tracer.

    Args:
        t (Union[ShardedTensor, torch.Tensor]): a model data tensor
        target_device (torch.device): the target device

    Returns:
        torch.Tensor: a freshly allocated copy of ``t``'s payload on
            ``target_device``
    """
    # Unwrap the raw tensor when given a ShardedTensor wrapper.
    t_payload = t.payload if isinstance(t, ShardedTensor) else t
    # copy=True forces a new tensor even when t_payload already lives on
    # target_device: a plain .to(target_device) would return the same
    # object in that case, so the result would not actually be a clone
    # and the tracer would double-register the source tensor.
    ret = t_payload.to(target_device, copy=True)
    GLOBAL_MODEL_DATA_TRACER.add_tensor(ret)
    return ret
colossalai/zero/sharded_model/sharded_model_v2.py
View file @
05e33b25
...
...
@@ -11,7 +11,7 @@ from colossalai.engine.ophooks import register_ophooks_recursively
from
colossalai.engine.ophooks.zero_hook
import
ZeroHook
from
colossalai.engine.paramhooks
import
BaseParamHookMgr
from
colossalai.logging
import
get_dist_logger
from
colossalai.utils.memory_utils.utils
import
colo_model_data_move_to_cpu
,
colo_cuda_memory_capacity
from
colossalai.utils.memory_utils.utils
import
colo_model_data_move_to_cpu
,
colo_cuda_memory_capacity
,
colo_model_tensor_clone
from
colossalai.utils.memory_tracer.memstats_collector
import
MemStatsCollector
from
colossalai.zero.shard_utils
import
BaseShardStrategy
from
colossalai.zero.sharded_model.reduce_scatter
import
ReduceScatterBucketer
...
...
@@ -198,16 +198,16 @@ class ShardedModelV2(nn.Module):
# the shape `grad` is the same as unsharded param
# So we can just use `view(-1)` to ensure grad is a flat tensor shard
if
self
.
reuse_fp16_shard
:
grad
=
p
.
col_attr
.
sharded_data_tensor
.
payload
grad
_payload
=
p
.
col_attr
.
sharded_data_tensor
.
payload
else
:
grad
=
cast_tensor_to_fp32
(
p
.
col_attr
.
fp16_grad
)
grad
_payload
=
cast_tensor_to_fp32
(
p
.
col_attr
.
fp16_grad
)
if
p
.
col_attr
.
offload_grad
:
colo_model_data_move_to_cpu
(
grad
)
grad_payload
=
colo_model_tensor_clone
(
grad_payload
,
torch
.
device
(
'cpu'
)
)
if
p
.
col_attr
.
fp32_grad
is
not
None
:
assert
not
self
.
reuse_fp16_shard
,
'Gradien accumulation is not supported when reuse_fp16_shard=True'
p
.
col_attr
.
fp32_grad
.
add_
(
grad
.
view_as
(
p
.
col_attr
.
fp32_grad
))
grad
=
p
.
col_attr
.
fp32_grad
p
.
grad
.
data
=
grad
p
.
col_attr
.
fp32_grad
.
add_
(
grad
_payload
.
view_as
(
p
.
col_attr
.
fp32_grad
))
grad
_payload
=
p
.
col_attr
.
fp32_grad
p
.
grad
.
data
=
grad
_payload
p
.
col_attr
.
fp16_grad
=
None
p
.
col_attr
.
fp32_grad
=
None
...
...
colossalai/zero/sharded_optim/sharded_optim_v2.py
View file @
05e33b25
...
...
@@ -9,6 +9,7 @@ from colossalai.context.parallel_mode import ParallelMode
from
colossalai.core
import
global_context
as
gpc
from
colossalai.logging
import
get_dist_logger
from
colossalai.nn.optimizer
import
ColossalaiOptimizer
from
colossalai.utils.memory_tracer.model_data_memtracer
import
GLOBAL_MODEL_DATA_TRACER
from
colossalai.zero.sharded_model
import
ShardedModelV2
from
colossalai.zero.sharded_model._utils
import
cast_tensor_to_fp32
from
torch
import
Tensor
...
...
@@ -217,6 +218,9 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
# We must set grad to None
# Because we will judge whether local grad accumulation
# is enabled by whether grad is None
for
group
in
self
.
param_groups
:
for
p
in
group
[
'params'
]:
GLOBAL_MODEL_DATA_TRACER
.
delete_tensor
(
p
.
grad
)
self
.
optim
.
zero_grad
(
set_to_none
=
True
)
def
sync_grad
(
self
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment