OpenDAS / ColossalAI / Commits

Commit 014bac0c (unverified)
Authored Mar 30, 2022 by ver217, committed by GitHub on Mar 30, 2022

[zero] hijack p.grad in sharded model (#554)

* hijack p.grad in sharded model
* polish comments
* polish comments
Parent: f552b112

Showing 6 changed files with 45 additions and 55 deletions (+45, -55):
  colossalai/engine/ophooks/zero_hook.py             (+3,  -16)
  colossalai/zero/sharded_model/sharded_model_v2.py  (+13, -17)
  colossalai/zero/sharded_optim/sharded_optim_v2.py  (+20, -9)
  colossalai/zero/sharded_param/sharded_param.py     (+4,  -9)
  tests/test_zero_data_parallel/common.py            (+2,  -1)
  tests/test_zero_data_parallel/test_shard_param.py  (+3,  -3)
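Before the per-file diffs, the whole change in one self-contained sketch: instead of letting autograd keep gradients in p.grad, the sharded model parks each reduced gradient in a per-parameter holder (saved_grad) and sets p.grad to None; the sharded optimizer re-attaches the payload to p.grad just before step(). The code below is illustrative plain PyTorch under that assumption; SavedGrad is a stand-in for ColossalAI's StatefulTensor, not the library's actual API.

import torch

class SavedGrad:
    """Minimal stand-in for ColossalAI's StatefulTensor (illustrative only)."""

    def __init__(self):
        self.payload = None

    def is_null(self):
        return self.payload is None

    def reset_payload(self, tensor):
        self.payload = tensor

    def set_null(self):
        self.payload = None

def post_backward(p: torch.nn.Parameter, saved: SavedGrad):
    # Park the gradient in the holder and clear p.grad ("hijack" it).
    if saved.is_null():
        saved.reset_payload(p.grad.detach().clone())
    else:
        # Accumulation across backward passes happens in fp32 in the holder.
        saved.reset_payload(saved.payload.float())
        saved.payload.add_(p.grad.detach().view_as(saved.payload))
    p.grad = None

def prepare_grad(p: torch.nn.Parameter, saved: SavedGrad):
    # Re-attach the saved payload to p.grad right before optimizer.step().
    p.grad = saved.payload
    saved.set_null()

With this split the ZeroHook no longer needs to count backward passes per parameter, which is why bwd_count disappears in the first file below.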
colossalai/engine/ophooks/zero_hook.py

@@ -9,7 +9,9 @@ from colossalai.zero.shard_utils import BaseShardStrategy
 from colossalai.zero.sharded_param.tensorful_state import TensorState
 from ._base_ophook import BaseOpHook
-from colossalai.utils.memory_utils.utils import colo_model_data_tensor_move_inline
+from colossalai.utils.memory_utils.utils import \
+    colo_model_data_tensor_move_inline
 
 @OPHOOKS.register_module
@@ -67,21 +69,6 @@ class ZeroHook(BaseOpHook):
         for param in module.parameters(recurse=False):
             colo_model_data_tensor_move_inline(param.col_attr.sharded_data_tensor, self.computing_device)
             param.data = param.col_attr.sharded_data_tensor.payload
-            # Store local accumulated grad shard
-            if param.grad is not None:
-                if param.col_attr.bwd_count == 0:
-                    # We haven't stored local accumulated grad yet
-                    assert param.col_attr.fp32_grad.is_null()
-                    # Allocate grad fp32 memory space here
-                    param.col_attr.fp32_grad.reset_payload(param.grad.data)
-                    # TODO(jiaruifang) we should set grad fp16 state to HOLD here.
-                    param.grad = None
-                else:
-                    # We have stored local accumulated grad
-                    # The grad here must be locally computed full grad in this backward pass
-                    assert param.grad.shape == param.col_attr.sharded_data_tensor.origin_shape
-            param.col_attr.bwd_count += 1
         if self._memstarts_collector:
             self._memstarts_collector.sample_memstats()
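A tiny plain-PyTorch check (not ColossalAI code) that accumulating in a side buffer while clearing p.grad after every backward, as this hook's caller now does, yields the same total gradient as ordinary accumulation:

import torch

w = torch.ones(2, requires_grad=True)
saved = None
for step in range(3):
    loss = (w * float(step + 1)).sum()   # d(loss)/dw = step + 1, elementwise
    loss.backward()
    g = w.grad.detach().clone()
    saved = g if saved is None else saved + g   # accumulate outside p.grad
    w.grad = None                               # keep p.grad hijacked
assert torch.allclose(saved, torch.full((2,), 6.0))  # 1 + 2 + 3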
colossalai/zero/sharded_model/sharded_model_v2.py

@@ -12,20 +12,18 @@ from colossalai.engine.ophooks.zero_hook import ZeroHook
 from colossalai.engine.paramhooks import BaseParamHookMgr
 from colossalai.logging import get_dist_logger
 from colossalai.utils import get_current_device
-from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
-from colossalai.utils.memory_utils.utils import colo_cuda_memory_capacity
 from colossalai.utils.memory_tracer.memstats_collector import MemStatsCollector
 from colossalai.utils.memory_tracer.model_data_memtracer import \
     GLOBAL_MODEL_DATA_TRACER
 from colossalai.utils.memory_utils.utils import (colo_cuda_memory_capacity, colo_model_data_move_to_cpu)
 from colossalai.zero.shard_utils import BaseShardStrategy
 from colossalai.zero.sharded_model.reduce_scatter import ReduceScatterBucketer
+from colossalai.zero.sharded_param.tensorful_state import (StatefulTensor, TensorState)
 from torch.distributed import ProcessGroup
 from torch.nn.parameter import Parameter
 from ._utils import (cast_float_arguments, cast_tensor_to_fp16, cast_tensor_to_fp32, chunk_and_pad, free_storage,
                      get_gradient_predivide_factor)
-from colossalai.zero.sharded_param.tensorful_state import StatefulTensor, TensorState
 
 class ShardedModelV2(nn.Module):
@@ -233,11 +231,11 @@ class ShardedModelV2(nn.Module):
             if not p.col_attr.param_is_sharded:
                 tensor_list.append(p.col_attr.sharded_data_tensor)
                 p.col_attr.sharded_data_tensor.trans_state(TensorState.HOLD_AFTER_BWD)
+                p.col_attr.remove_torch_payload()
         self.shard_strategy.shard(tensor_list, self.process_group)
         # 4. move sharded param grad payload to param.grad
         for p in self.module.parameters():
-            p.col_attr.bwd_count = 0
             if not p.requires_grad:
                 continue
             # Leave the gradient accumulation state (_require_backward_grad_sync) as-is if not synchronizing this pass.
@@ -247,14 +245,10 @@ class ShardedModelV2(nn.Module):
             # We also allows to interleave no-sync pass with sync passes, if desired.
             if not self._require_backward_grad_sync:
                 continue
-            # Write grad payload kept by sharded param back to p.grad,
-            # and set p.col_attr.grad to None
-            # As sharded optimizer only update a shard of param,
-            # no matter whether we shard param in sharded model
-            # We have to make sure the grad is a flat tensor shard
-            # If world size == 1 and param is sharded,
-            # the shape `grad` is the same as unsharded param
-            # So we can just use `view(-1)` to ensure grad is a flat tensor shard
+            # Reduced grad is saved in `p.col_attr.saved_grad`
+            # It can be on CPU or CUDA
+            # It can be fp16 or fp32
+            # We set `p.grad` to None here and ShardedOptimizer will prepare `p.grad` before `step()`.
             if self.reuse_fp16_shard:
                 grad_fp16_payload = p.col_attr.sharded_data_tensor.payload
             else:
@@ -262,13 +256,15 @@ class ShardedModelV2(nn.Module):
             assert isinstance(grad_fp16_payload, torch.Tensor)
             if p.col_attr.offload_grad:
                 colo_model_data_move_to_cpu(grad_fp16_payload)
-            if not p.col_attr.fp32_grad.is_null():
+            if not p.col_attr.saved_grad.is_null():
                 assert not self.reuse_fp16_shard, 'Gradien accumulation is not supported when reuse_fp16_shard=True'
-                p.col_attr.fp32_grad.payload.add_(grad_fp16_payload.view_as(p.col_attr.fp32_grad.payload))
-                grad_fp16_payload = p.col_attr.fp32_grad.payload
-                p.col_attr.fp32_grad.set_null()
+                # Accumulate grad, saved grad must be fp32
+                p.col_attr.saved_grad.reset_payload(cast_tensor_to_fp32(p.col_attr.saved_grad.payload))
+                p.col_attr.saved_grad.payload.add_(grad_fp16_payload.view_as(p.col_attr.saved_grad.payload))
+            else:
+                p.col_attr.saved_grad.reset_payload(grad_fp16_payload)
 
-            p.grad.data = grad_fp16_payload
+            p.grad = None
             p.col_attr.fp16_grad.set_null()

     @torch.no_grad()
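The accumulation branch above promotes saved_grad to fp32 before add_, so repeated micro-batches do not lose precision in fp16. A minimal illustration of that rule, with a hypothetical cast_tensor_to_fp32 mirroring the helper imported in this file (the real signature may differ):

import torch

def cast_tensor_to_fp32(t: torch.Tensor) -> torch.Tensor:
    # hypothetical mirror of the imported helper
    return t.float() if t.dtype == torch.float16 else t

saved = torch.ones(4, dtype=torch.float16)     # previously saved grad
new_grad = torch.ones(4, dtype=torch.float16)  # grad from this backward
saved = cast_tensor_to_fp32(saved)             # promote before accumulating
saved.add_(new_grad.view_as(saved))            # fp16 operand promoted by add_
assert saved.dtype == torch.float32
assert torch.allclose(saved, torch.full((4,), 2.0))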
colossalai/zero/sharded_optim/sharded_optim_v2.py

@@ -5,23 +5,22 @@ from typing import Dict, Optional, Tuple
 import torch
 import torch.distributed as dist
 import torch.nn as nn
+from torch import Tensor
+from torch.distributed import ProcessGroup
+from torch.nn.parameter import Parameter
+from torch.optim import Optimizer
 from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
 from colossalai.logging import get_dist_logger
 from colossalai.nn.optimizer import ColossalaiOptimizer
-from colossalai.utils.memory_utils.utils import (colo_model_tensor_clone, colo_tensor_mem_usage)
+from colossalai.utils.memory_tracer.model_data_memtracer import \
+    GLOBAL_MODEL_DATA_TRACER
+from colossalai.utils.memory_utils.utils import (colo_model_data_tensor_move, colo_model_tensor_clone,
+                                                 colo_tensor_mem_usage)
 from colossalai.zero.sharded_model import ShardedModelV2
 from colossalai.zero.sharded_model._utils import cast_tensor_to_fp32
 from colossalai.zero.sharded_optim._utils import has_inf_or_nan
-from torch import Tensor
-from colossalai.utils.memory_utils.utils import colo_model_data_tensor_move, colo_tensor_mem_usage
-from torch.distributed import ProcessGroup
-from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
-from torch.nn.parameter import Parameter
-from torch.optim import Optimizer
 
 class OptimState(Enum):
@@ -170,6 +169,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
         return cuda_use, cpu_use
 
     def step(self, *args, **kwargs):
+        self._prepare_grads()
         self._maybe_move_fp32_shards()
         # unscale grads if scaled
@@ -294,3 +294,14 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
                     p.grad.data = p.grad.data.to(torch.cuda.current_device())
                     p.col_attr.offload_grad = False
                 fp32_shards_used_cuda_margin_mem += shard_mem
+
+    def _prepare_grads(self):
+        for group in self.optim.param_groups:
+            for p in group['params']:
+                # FIXME(ver217): p.data here is an empty tensor on CUDA and has no useful infomation
+                # If we change p.grad directly
+                # it may raise error because of different shape/dtype/device of p.data and p.grad
+                # We just set p.data = p.col_attr.saved_grad.payload here
+                p.data = p.col_attr.saved_grad.payload
+                p.grad = p.col_attr.saved_grad.payload
+                p.col_attr.saved_grad.set_null()
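step() now calls _prepare_grads() first, so the wrapped optimizer sees an ordinary p.grad even though autograd never produced one. A hedged sketch of that hand-off in plain PyTorch (Holder is a stand-in for the saved_grad holder, not the library API):

import torch

class Holder:
    def __init__(self, payload):
        self.payload = payload

    def set_null(self):
        self.payload = None

w = torch.nn.Parameter(torch.zeros(3))
opt = torch.optim.SGD([w], lr=0.1)
saved = Holder(torch.ones(3))   # grad parked here by the sharded model

# _prepare_grads() equivalent: attach the payload to p.grad, then release it.
w.grad = saved.payload
saved.set_null()

opt.step()                      # SGD: w <- w - lr * grad
assert torch.allclose(w.data, torch.full((3,), -0.1))

The FIXME above also points p.data at the payload first because, at this point in the real flow, p.data may be an empty placeholder whose shape, dtype, or device would otherwise disagree with the new p.grad.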
colossalai/zero/sharded_param/sharded_param.py

@@ -11,7 +11,7 @@ class ShardedParamV2(object):
     def __init__(self, param: torch.nn.Parameter, rm_torch_payload=False) -> None:
         self._sharded_data_tensor: ShardedTensor = ShardedTensor(param.data)
         self.fp16_grad: StatefulTensor = StatefulTensor(None, TensorState.FREE)
-        self.fp32_grad: StatefulTensor = StatefulTensor(None, TensorState.FREE)
+        self.saved_grad: StatefulTensor = StatefulTensor(None, TensorState.FREE)
         # This attribute must be initialized in ShardedModel
         self.offload_grad: bool = False
@@ -24,11 +24,6 @@ class ShardedParamV2(object):
         if rm_torch_payload:
             self.remove_torch_payload()
-
-        # Backward count for handle local grad accumulation
-        # This value will increment by 1 in every pre-bwd hook
-        # And will be reset to 0 in every final-bwd hook
-        self.bwd_count = 0
 
     def remove_torch_payload(self):
         self.param.data = torch.empty([], dtype=self.param.dtype, device=self.param.device)
@@ -66,9 +61,9 @@ class ShardedParamV2(object):
             _update_mem_use(self.fp16_grad.payload)
             address_set.add(self.fp16_grad.data_ptr())
 
-        if not self.fp32_grad.is_null() and self.fp32_grad.data_ptr() not in address_set:
-            _update_mem_use(self.fp32_grad.payload)
-            address_set.add(self.fp32_grad.data_ptr())
+        if not self.saved_grad.is_null() and self.saved_grad.data_ptr() not in address_set:
+            _update_mem_use(self.saved_grad.payload)
+            address_set.add(self.saved_grad.data_ptr())
 
         if self.param.data is not None and self.param.data.data_ptr() not in address_set:
            _update_mem_use(self.param.data)
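get_memory_usage() dedupes by data_ptr() so that saved_grad reusing storage that is already counted (for example param.grad itself, as in the test file below) is not billed twice. A plain-PyTorch illustration of the idea, with a hypothetical mem_bytes helper rather than the library function:

import torch

def mem_bytes(tensors):
    seen, total = set(), 0
    for t in tensors:
        if t is None or t.data_ptr() in seen:
            continue                    # alias of a tensor already counted
        seen.add(t.data_ptr())
        total += t.numel() * t.element_size()
    return total

base = torch.randn(2, 3)                # 6 fp32 elements = 24 bytes
alias = base.view(6)                     # shares storage with base
assert mem_bytes([base, alias]) == 24    # counted once, not twice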
tests/test_zero_data_parallel/common.py

@@ -92,7 +92,8 @@ def check_params(model, zero_model, loose=False):
 def check_grads_padding(model, zero_model, loose=False):
     rank = dist.get_rank()
     for p, zero_p in zip(model.parameters(), zero_model.parameters()):
-        zero_grad = zero_p.grad.clone().to(p.device)
+        # zero_grad = zero_p.grad.clone().to(p.device)
+        zero_grad = zero_p.col_attr.saved_grad.payload.clone().to(p.device)
         chunks = torch.flatten(p.grad).chunk(dist.get_world_size())
         if rank >= len(chunks):
             continue
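Only the source of the gradient changes here; the padding check itself is untouched: the dense reference grad is flattened and chunked per rank, and ranks past the last chunk are skipped. A standalone illustration of that chunking (values made up):

import torch

world_size, rank = 4, 3
ref_grad = torch.arange(10, dtype=torch.float32)
chunks = torch.flatten(ref_grad).chunk(world_size)  # sizes 3, 3, 3, 1
if rank < len(chunks):
    shard = chunks[rank]
    assert shard.numel() == 1 and shard[0] == 9.0   # last, shorter chunk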
tests/test_zero_data_parallel/test_shard_param.py

@@ -53,7 +53,7 @@ def _run_shard_param_v2(rank, world_size, port):
     allclose(sparam.sharded_data_tensor.payload, param_ref.data)
 
     # Test get memory usage
-    sparam.fp32_grad = StatefulTensor(torch.randn(2, 3))
+    sparam.saved_grad = StatefulTensor(torch.randn(2, 3))
     cuda_mem_use, cpu_mem_use = sparam.get_memory_usage()
     assert cpu_mem_use == 2 * 3 * 4 * 2, f"cpu_mem_use: {cpu_mem_use}"
@@ -69,7 +69,7 @@ def _run_shard_param_v2(rank, world_size, port):
     assert cuda_mem_use == 2 * 3 * 2
 
     sparam.fp16_grad = StatefulTensor(None)
-    sparam.fp32_grad = StatefulTensor(torch.randn(2, 3))
+    sparam.saved_grad = StatefulTensor(torch.randn(2, 3))
     sparam.remove_torch_payload()
 
     cuda_mem_use, cpu_mem_use = sparam.get_memory_usage()
     assert cpu_mem_use == 2 * 3 * 4 * 2 + 4
@@ -83,7 +83,7 @@ def _run_shard_param_v2(rank, world_size, port):
     assert cuda_mem_use == 0
 
     # reuse torch grad for sparam
-    sparam.fp32_grad = StatefulTensor(param.grad)
+    sparam.saved_grad = StatefulTensor(param.grad)
     cuda_mem_use, cpu_mem_use = sparam.get_memory_usage()
     assert cpu_mem_use == 2 * 3 * 4 * 2
     assert cuda_mem_use == 0
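The byte arithmetic behind these assertions, spelled out: a (2, 3) fp32 tensor is 2 * 3 * 4 = 24 bytes, so counting both the data payload and saved_grad on CPU gives 2 * 3 * 4 * 2 = 48; the extra "+ 4" appears after remove_torch_payload(), which leaves a zero-dim fp32 placeholder (one element, 4 bytes) as param.data:

import torch

t = torch.randn(2, 3)
assert t.numel() * t.element_size() == 24        # one fp32 (2, 3) tensor

placeholder = torch.empty([], dtype=torch.float32)
assert placeholder.numel() * placeholder.element_size() == 4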