Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
014bac0c
Unverified
Commit
014bac0c
authored
Mar 30, 2022
by
ver217
Committed by
GitHub
Mar 30, 2022
Browse files
[zero] hijack p.grad in sharded model (#554)
* hijack p.grad in sharded model * polish comments * polish comments
parent
f552b112
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
45 additions
and
55 deletions
+45
-55
colossalai/engine/ophooks/zero_hook.py
colossalai/engine/ophooks/zero_hook.py
+3
-16
colossalai/zero/sharded_model/sharded_model_v2.py
colossalai/zero/sharded_model/sharded_model_v2.py
+13
-17
colossalai/zero/sharded_optim/sharded_optim_v2.py
colossalai/zero/sharded_optim/sharded_optim_v2.py
+20
-9
colossalai/zero/sharded_param/sharded_param.py
colossalai/zero/sharded_param/sharded_param.py
+4
-9
tests/test_zero_data_parallel/common.py
tests/test_zero_data_parallel/common.py
+2
-1
tests/test_zero_data_parallel/test_shard_param.py
tests/test_zero_data_parallel/test_shard_param.py
+3
-3
No files found.
colossalai/engine/ophooks/zero_hook.py
View file @
014bac0c
...
...
@@ -9,7 +9,9 @@ from colossalai.zero.shard_utils import BaseShardStrategy
from
colossalai.zero.sharded_param.tensorful_state
import
TensorState
from
._base_ophook
import
BaseOpHook
from
colossalai.utils.memory_utils.utils
import
colo_model_data_tensor_move_inline
from
colossalai.utils.memory_utils.utils
import
\
colo_model_data_tensor_move_inline
@
OPHOOKS
.
register_module
...
...
@@ -67,21 +69,6 @@ class ZeroHook(BaseOpHook):
for
param
in
module
.
parameters
(
recurse
=
False
):
colo_model_data_tensor_move_inline
(
param
.
col_attr
.
sharded_data_tensor
,
self
.
computing_device
)
param
.
data
=
param
.
col_attr
.
sharded_data_tensor
.
payload
# Store local accumulated grad shard
if
param
.
grad
is
not
None
:
if
param
.
col_attr
.
bwd_count
==
0
:
# We haven't stored local accumulated grad yet
assert
param
.
col_attr
.
fp32_grad
.
is_null
()
# Allocate grad fp32 memory space here
param
.
col_attr
.
fp32_grad
.
reset_payload
(
param
.
grad
.
data
)
# TODO(jiaruifang) we should set grad fp16 state to HOLD here.
param
.
grad
=
None
else
:
# We have stored local accumulated grad
# The grad here must be locally computed full grad in this backward pass
assert
param
.
grad
.
shape
==
param
.
col_attr
.
sharded_data_tensor
.
origin_shape
param
.
col_attr
.
bwd_count
+=
1
if
self
.
_memstarts_collector
:
self
.
_memstarts_collector
.
sample_memstats
()
...
...
colossalai/zero/sharded_model/sharded_model_v2.py
View file @
014bac0c
...
...
@@ -12,20 +12,18 @@ from colossalai.engine.ophooks.zero_hook import ZeroHook
from
colossalai.engine.paramhooks
import
BaseParamHookMgr
from
colossalai.logging
import
get_dist_logger
from
colossalai.utils
import
get_current_device
from
colossalai.utils.memory_tracer.model_data_memtracer
import
GLOBAL_MODEL_DATA_TRACER
from
colossalai.utils.memory_utils.utils
import
colo_cuda_memory_capacity
from
colossalai.utils.memory_tracer.memstats_collector
import
MemStatsCollector
from
colossalai.utils.memory_tracer.model_data_memtracer
import
\
GLOBAL_MODEL_DATA_TRACER
from
colossalai.utils.memory_utils.utils
import
(
colo_cuda_memory_capacity
,
colo_model_data_move_to_cpu
)
from
colossalai.zero.shard_utils
import
BaseShardStrategy
from
colossalai.zero.sharded_model.reduce_scatter
import
ReduceScatterBucketer
from
colossalai.zero.sharded_param.tensorful_state
import
(
StatefulTensor
,
TensorState
)
from
torch.distributed
import
ProcessGroup
from
torch.nn.parameter
import
Parameter
from
._utils
import
(
cast_float_arguments
,
cast_tensor_to_fp16
,
cast_tensor_to_fp32
,
chunk_and_pad
,
free_storage
,
get_gradient_predivide_factor
)
from
colossalai.zero.sharded_param.tensorful_state
import
StatefulTensor
,
TensorState
class
ShardedModelV2
(
nn
.
Module
):
...
...
@@ -233,11 +231,11 @@ class ShardedModelV2(nn.Module):
if
not
p
.
col_attr
.
param_is_sharded
:
tensor_list
.
append
(
p
.
col_attr
.
sharded_data_tensor
)
p
.
col_attr
.
sharded_data_tensor
.
trans_state
(
TensorState
.
HOLD_AFTER_BWD
)
p
.
col_attr
.
remove_torch_payload
()
self
.
shard_strategy
.
shard
(
tensor_list
,
self
.
process_group
)
# 4. move sharded param grad payload to param.grad
for
p
in
self
.
module
.
parameters
():
p
.
col_attr
.
bwd_count
=
0
if
not
p
.
requires_grad
:
continue
# Leave the gradient accumulation state (_require_backward_grad_sync) as-is if not synchronizing this pass.
...
...
@@ -247,14 +245,10 @@ class ShardedModelV2(nn.Module):
# We also allow interleaving no-sync passes with sync passes, if desired.
if
not
self
.
_require_backward_grad_sync
:
continue
# Write grad payload kept by sharded param back to p.grad,
# and set p.col_attr.grad to None
# As the sharded optimizer only updates a shard of each param,
# no matter whether we shard param in sharded model
# We have to make sure the grad is a flat tensor shard
# If world size == 1 and param is sharded,
# the shape of `grad` is the same as that of the unsharded param
# So we can just use `view(-1)` to ensure grad is a flat tensor shard
# Reduced grad is saved in `p.col_attr.saved_grad`
# It can be on CPU or CUDA
# It can be fp16 or fp32
# We set `p.grad` to None here and ShardedOptimizer will prepare `p.grad` before `step()`.
if
self
.
reuse_fp16_shard
:
grad_fp16_payload
=
p
.
col_attr
.
sharded_data_tensor
.
payload
else
:
...
...
@@ -262,13 +256,15 @@ class ShardedModelV2(nn.Module):
assert
isinstance
(
grad_fp16_payload
,
torch
.
Tensor
)
if
p
.
col_attr
.
offload_grad
:
colo_model_data_move_to_cpu
(
grad_fp16_payload
)
if
not
p
.
col_attr
.
fp32
_grad
.
is_null
():
if
not
p
.
col_attr
.
saved
_grad
.
is_null
():
assert
not
self
.
reuse_fp16_shard
,
'Gradien accumulation is not supported when reuse_fp16_shard=True'
p
.
col_attr
.
fp32_grad
.
payload
.
add_
(
grad_fp16_payload
.
view_as
(
p
.
col_attr
.
fp32_grad
.
payload
))
grad_fp16_payload
=
p
.
col_attr
.
fp32_grad
.
payload
p
.
col_attr
.
fp32_grad
.
set_null
()
# Accumulate grad, saved grad must be fp32
p
.
col_attr
.
saved_grad
.
reset_payload
(
cast_tensor_to_fp32
(
p
.
col_attr
.
saved_grad
.
payload
))
p
.
col_attr
.
saved_grad
.
payload
.
add_
(
grad_fp16_payload
.
view_as
(
p
.
col_attr
.
saved_grad
.
payload
))
else
:
p
.
col_attr
.
saved_grad
.
reset_payload
(
grad_fp16_payload
)
p
.
grad
.
data
=
grad_fp16_payload
p
.
grad
=
None
p
.
col_attr
.
fp16_grad
.
set_null
()
@
torch
.
no_grad
()
...
...
colossalai/zero/sharded_optim/sharded_optim_v2.py
View file @
014bac0c
...
...
@@ -5,23 +5,22 @@ from typing import Dict, Optional, Tuple
import
torch
import
torch.distributed
as
dist
import
torch.nn
as
nn
from
torch
import
Tensor
from
torch.distributed
import
ProcessGroup
from
torch.nn.parameter
import
Parameter
from
torch.optim
import
Optimizer
from
colossalai.amp.naive_amp.grad_scaler
import
DynamicGradScaler
from
colossalai.context.parallel_mode
import
ParallelMode
from
colossalai.core
import
global_context
as
gpc
from
colossalai.logging
import
get_dist_logger
from
colossalai.nn.optimizer
import
ColossalaiOptimizer
from
colossalai.utils.memory_utils.utils
import
(
colo_model_tensor_clone
,
colo_tensor_mem_usage
)
from
colossalai.utils.memory_tracer.model_data_memtracer
import
\
GLOBAL_MODEL_DATA_TRACER
from
colossalai.utils.memory_utils.utils
import
(
colo_model_data_tensor_move
,
colo_model_tensor_clone
,
colo_tensor_mem_usage
)
from
colossalai.zero.sharded_model
import
ShardedModelV2
from
colossalai.zero.sharded_model._utils
import
cast_tensor_to_fp32
from
colossalai.zero.sharded_optim._utils
import
has_inf_or_nan
from
colossalai.zero.sharded_optim._utils
import
has_inf_or_nan
from
colossalai.utils.memory_utils.utils
import
colo_model_data_tensor_move
,
colo_tensor_mem_usage
from
colossalai.utils.memory_tracer.model_data_memtracer
import
GLOBAL_MODEL_DATA_TRACER
from
torch
import
Tensor
from
torch.distributed
import
ProcessGroup
from
torch.nn.parameter
import
Parameter
from
torch.optim
import
Optimizer
class
OptimState
(
Enum
):
...
...
@@ -170,6 +169,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
return
cuda_use
,
cpu_use
def
step
(
self
,
*
args
,
**
kwargs
):
self
.
_prepare_grads
()
self
.
_maybe_move_fp32_shards
()
# unscale grads if scaled
...
...
@@ -294,3 +294,14 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
p
.
grad
.
data
=
p
.
grad
.
data
.
to
(
torch
.
cuda
.
current_device
())
p
.
col_attr
.
offload_grad
=
False
fp32_shards_used_cuda_margin_mem
+=
shard_mem
def _prepare_grads(self):
    """Expose each parameter's saved gradient as ``p.grad`` for the optimizer.

    Called at the start of ``step()``. For every parameter in the wrapped
    optimizer's ``param_groups``, the payload held in
    ``p.col_attr.saved_grad`` is installed as both ``p.data`` and ``p.grad``,
    after which ``saved_grad`` is reset to null.

    Parameters whose ``saved_grad`` is null are skipped and left untouched,
    rather than having a missing payload assigned to ``p.data`` / ``p.grad``.
    """
    for group in self.optim.param_groups:
        for p in group['params']:
            saved_grad = p.col_attr.saved_grad
            if saved_grad.is_null():
                # Nothing was saved for this parameter in the last backward
                # pass, so there is no gradient to hand to the optimizer.
                continue
            # FIXME(ver217): p.data here is an empty tensor on CUDA and has no useful information
            # If we change p.grad directly
            # it may raise error because of different shape/dtype/device of p.data and p.grad
            # We just set p.data = p.col_attr.saved_grad.payload here
            p.data = saved_grad.payload
            p.grad = saved_grad.payload
            saved_grad.set_null()
colossalai/zero/sharded_param/sharded_param.py
View file @
014bac0c
...
...
@@ -11,7 +11,7 @@ class ShardedParamV2(object):
def
__init__
(
self
,
param
:
torch
.
nn
.
Parameter
,
rm_torch_payload
=
False
)
->
None
:
self
.
_sharded_data_tensor
:
ShardedTensor
=
ShardedTensor
(
param
.
data
)
self
.
fp16_grad
:
StatefulTensor
=
StatefulTensor
(
None
,
TensorState
.
FREE
)
self
.
fp32
_grad
:
StatefulTensor
=
StatefulTensor
(
None
,
TensorState
.
FREE
)
self
.
saved
_grad
:
StatefulTensor
=
StatefulTensor
(
None
,
TensorState
.
FREE
)
# This attribute must be initialized in ShardedModel
self
.
offload_grad
:
bool
=
False
...
...
@@ -24,11 +24,6 @@ class ShardedParamV2(object):
if
rm_torch_payload
:
self
.
remove_torch_payload
()
# Backward count for handling local grad accumulation
# This value will increment by 1 in every pre-bwd hook
# And will be reset to 0 in every final-bwd hook
self
.
bwd_count
=
0
def
remove_torch_payload
(
self
):
self
.
param
.
data
=
torch
.
empty
([],
dtype
=
self
.
param
.
dtype
,
device
=
self
.
param
.
device
)
...
...
@@ -66,9 +61,9 @@ class ShardedParamV2(object):
_update_mem_use
(
self
.
fp16_grad
.
payload
)
address_set
.
add
(
self
.
fp16_grad
.
data_ptr
())
if
not
self
.
fp32
_grad
.
is_null
()
and
self
.
fp32
_grad
.
data_ptr
()
not
in
address_set
:
_update_mem_use
(
self
.
fp32
_grad
.
payload
)
address_set
.
add
(
self
.
fp32
_grad
.
data_ptr
())
if
not
self
.
saved
_grad
.
is_null
()
and
self
.
saved
_grad
.
data_ptr
()
not
in
address_set
:
_update_mem_use
(
self
.
saved
_grad
.
payload
)
address_set
.
add
(
self
.
saved
_grad
.
data_ptr
())
if
self
.
param
.
data
is
not
None
and
self
.
param
.
data
.
data_ptr
()
not
in
address_set
:
_update_mem_use
(
self
.
param
.
data
)
...
...
tests/test_zero_data_parallel/common.py
View file @
014bac0c
...
...
@@ -92,7 +92,8 @@ def check_params(model, zero_model, loose=False):
def
check_grads_padding
(
model
,
zero_model
,
loose
=
False
):
rank
=
dist
.
get_rank
()
for
p
,
zero_p
in
zip
(
model
.
parameters
(),
zero_model
.
parameters
()):
zero_grad
=
zero_p
.
grad
.
clone
().
to
(
p
.
device
)
# zero_grad = zero_p.grad.clone().to(p.device)
zero_grad
=
zero_p
.
col_attr
.
saved_grad
.
payload
.
clone
().
to
(
p
.
device
)
chunks
=
torch
.
flatten
(
p
.
grad
).
chunk
(
dist
.
get_world_size
())
if
rank
>=
len
(
chunks
):
continue
...
...
tests/test_zero_data_parallel/test_shard_param.py
View file @
014bac0c
...
...
@@ -53,7 +53,7 @@ def _run_shard_param_v2(rank, world_size, port):
allclose
(
sparam
.
sharded_data_tensor
.
payload
,
param_ref
.
data
)
# Test get memory usage
sparam
.
fp32
_grad
=
StatefulTensor
(
torch
.
randn
(
2
,
3
))
sparam
.
saved
_grad
=
StatefulTensor
(
torch
.
randn
(
2
,
3
))
cuda_mem_use
,
cpu_mem_use
=
sparam
.
get_memory_usage
()
assert
cpu_mem_use
==
2
*
3
*
4
*
2
,
f
"cpu_mem_use:
{
cpu_mem_use
}
"
...
...
@@ -69,7 +69,7 @@ def _run_shard_param_v2(rank, world_size, port):
assert
cuda_mem_use
==
2
*
3
*
2
sparam
.
fp16_grad
=
StatefulTensor
(
None
)
sparam
.
fp32
_grad
=
StatefulTensor
(
torch
.
randn
(
2
,
3
))
sparam
.
saved
_grad
=
StatefulTensor
(
torch
.
randn
(
2
,
3
))
sparam
.
remove_torch_payload
()
cuda_mem_use
,
cpu_mem_use
=
sparam
.
get_memory_usage
()
assert
cpu_mem_use
==
2
*
3
*
4
*
2
+
4
...
...
@@ -83,7 +83,7 @@ def _run_shard_param_v2(rank, world_size, port):
assert
cuda_mem_use
==
0
# reuse torch grad for sparam
sparam
.
fp32
_grad
=
StatefulTensor
(
param
.
grad
)
sparam
.
saved
_grad
=
StatefulTensor
(
param
.
grad
)
cuda_mem_use
,
cpu_mem_use
=
sparam
.
get_memory_usage
()
assert
cpu_mem_use
==
2
*
3
*
4
*
2
assert
cuda_mem_use
==
0
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment