OpenDAS / ColossalAI · Commit f552b112 (unverified)

[zero] label state for param fp16 and grad (#551)

Authored Mar 30, 2022 by Jiarui Fang; committed via GitHub on Mar 30, 2022.
Parent: 92f42248

Showing 7 changed files with 75 additions and 39 deletions (+75 -39).
colossalai/engine/ophooks/zero_hook.py             +17  -0
colossalai/zero/sharded_model/sharded_model_v2.py  +43 -18
colossalai/zero/sharded_optim/sharded_optim_v2.py   +0  -1
colossalai/zero/sharded_param/sharded_param.py      +2  -5
colossalai/zero/sharded_param/sharded_tensor.py     +4  -6
colossalai/zero/sharded_param/tensorful_state.py    +8  -8
tests/test_zero_data_parallel/test_shard_param.py   +1  -1
colossalai/engine/ophooks/zero_hook.py

@@ -6,6 +6,7 @@ from colossalai.registry import OPHOOKS
 from colossalai.utils import get_current_device
 from colossalai.utils.memory_tracer.memstats_collector import MemStatsCollector
 from colossalai.zero.shard_utils import BaseShardStrategy
+from colossalai.zero.sharded_param.tensorful_state import TensorState
 from ._base_ophook import BaseOpHook
 from colossalai.utils.memory_utils.utils import colo_model_data_tensor_move_inline

@@ -42,7 +43,13 @@ class ZeroHook(BaseOpHook):
         if self._memstarts_collector:
             self._memstarts_collector.sample_memstats()

+        for param in module.parameters(recurse=False):
+            param.col_attr.sharded_data_tensor.trans_state(TensorState.COMPUTE)
+
     def post_fwd_exec(self, module: torch.nn.Module, *args):
+        for param in module.parameters(recurse=False):
+            param.col_attr.sharded_data_tensor.trans_state(TensorState.HOLD_AFTER_FWD)
+
         tensor_list = []
         for param in module.parameters(recurse=False):
             assert hasattr(param, 'col_attr')

@@ -65,7 +72,10 @@ class ZeroHook(BaseOpHook):
             if param.col_attr.bwd_count == 0:
                 # We haven't stored local accumulated grad yet
+                assert param.col_attr.fp32_grad.is_null()
+                # Allocate grad fp32 memory space here
                 param.col_attr.fp32_grad.reset_payload(param.grad.data)
+                # TODO(jiaruifang) we should set grad fp16 state to HOLD here.
                 param.grad = None
             else:
                 # We have stored local accumulated grad

@@ -75,12 +85,19 @@ class ZeroHook(BaseOpHook):
         if self._memstarts_collector:
             self._memstarts_collector.sample_memstats()

+        for param in module.parameters(recurse=False):
+            param.col_attr.sharded_data_tensor.trans_state(TensorState.COMPUTE)
+
     def post_bwd_exec(self, module: torch.nn.Module, input):
+        for param in module.parameters(recurse=False):
+            param.col_attr.sharded_data_tensor.trans_state(TensorState.HOLD_AFTER_BWD)
+
         tensor_list = []
         for param in module.parameters(recurse=False):
             assert hasattr(param, 'col_attr')
             tensor_list.append(param.col_attr.sharded_data_tensor)
         self.shard_strategy.shard(tensor_list, self.process_group)
         for param in module.parameters(recurse=False):
             param.col_attr.remove_torch_payload()
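Read together, the four hooks drive each parameter's sharded data tensor through a fixed state cycle per training step. The standalone sketch below traces that cycle; TensorState mirrors the enum in tensorful_state.py (the value of FREE is assumed), and ToyStatefulTensor is a hypothetical stand-in for the real sharded tensor, used only to make the transitions visible.

```python
from enum import Enum


class TensorState(Enum):
    # Mirrors colossalai/zero/sharded_param/tensorful_state.py; FREE's value is assumed.
    FREE = 0
    HOLD = 1
    HOLD_AFTER_FWD = 2
    HOLD_AFTER_BWD = 3
    COMPUTE = 4


class ToyStatefulTensor:
    """Hypothetical stand-in for the sharded data tensor: only tracks the state label."""

    def __init__(self, state: TensorState = TensorState.HOLD) -> None:
        self._state = state

    def trans_state(self, state: TensorState) -> None:
        self._state = state


# One training step, as labeled by ZeroHook in the diff above:
t = ToyStatefulTensor()                     # HOLD (set before the forward pass)
t.trans_state(TensorState.COMPUTE)          # pre_fwd_exec: consumed by the forward kernel
t.trans_state(TensorState.HOLD_AFTER_FWD)   # post_fwd_exec: forward done, kept for backward
t.trans_state(TensorState.COMPUTE)          # pre_bwd_exec: consumed by the backward kernel
t.trans_state(TensorState.HOLD_AFTER_BWD)   # post_bwd_exec: backward done, safe to re-shard
assert t._state is TensorState.HOLD_AFTER_BWD
```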
colossalai/zero/sharded_model/sharded_model_v2.py
@@ -25,7 +25,7 @@ from torch.nn.parameter import Parameter
 from ._utils import (cast_float_arguments, cast_tensor_to_fp16, cast_tensor_to_fp32, chunk_and_pad, free_storage,
                      get_gradient_predivide_factor)
-from colossalai.zero.sharded_param.tensorful_state import StatefulTensor
+from colossalai.zero.sharded_param.tensorful_state import StatefulTensor, TensorState


 class ShardedModelV2(nn.Module):
@@ -158,12 +158,25 @@ class ShardedModelV2(nn.Module):
                 f.write(str(self._memstats_collector.non_model_data_cuda_GB))
                 f.write('\n')

-    def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor:
+    def _pre_forward_operations(self):
         if self._iter_cnter == 0 and self._memstats_collector:
-            # the opeartion will affect the flag in ZeroHook
+            # the operation will affect the memory tracer behavior in ZeroHook
             self._memstats_collector.start_collection()

+        for p in self.module.parameters():
+            if hasattr(p, 'col_attr'):
+                p.col_attr.sharded_data_tensor.trans_state(TensorState.HOLD)
+
+    def _post_forward_operations(self):
+        for p in self.module.parameters():
+            if hasattr(p, 'col_attr'):
+                p.col_attr.sharded_data_tensor.trans_state(TensorState.HOLD)
+
+    def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor:
+        self._pre_forward_operations()
         args, kwargs = cast_float_arguments(cast_tensor_to_fp16, *args, **kwargs)
         outputs = self.module(*args, **kwargs)
+        self._post_forward_operations()
         return outputs

     def backward(self, loss):
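The forward path is now bracketed by explicit pre/post phases that relabel every parameter as HOLD. A minimal sketch of the same wrap-the-forward pattern; ToyWrapper and the print placeholders are hypothetical and not part of ColossalAI.

```python
import torch
import torch.nn as nn


class ToyWrapper(nn.Module):
    """Hypothetical module showing the pre/post forward bracketing from the diff."""

    def __init__(self) -> None:
        super().__init__()
        self.module = nn.Linear(4, 4)

    def _pre_forward_operations(self) -> None:
        # In ShardedModelV2 this starts memory-stat collection on the first
        # iteration and labels every parameter HOLD before the forward pass.
        print("label params HOLD; maybe start memstats collection")

    def _post_forward_operations(self) -> None:
        # In ShardedModelV2 this labels every parameter HOLD again after forward.
        print("label params HOLD")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        self._pre_forward_operations()
        out = self.module(x)
        self._post_forward_operations()
        return out


y = ToyWrapper()(torch.randn(2, 4))
```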
@@ -195,9 +208,15 @@ class ShardedModelV2(nn.Module):
     def _post_backward_operations(self) -> None:
+        """
+        The method includes operations required to be processed after backward
+        1. update memory tracer.
+        2. flush the gradient in buckets. Reducing partial gradients in each process.
+        3. shard tensors not dealt with in the zero hook
+        4. move sharded param grad payload to param.grad
+        """
+        # 1. update memory tracer.
         self._update_memstats()

+        # 2. flush the gradient in buckets. Reducing partial gradients in each process.
         if self._require_backward_grad_sync:
             # Flush any unreduced buckets in the post_backward stream.
             with torch.cuda.stream(self.comm_stream):
@@ -207,44 +226,50 @@ class ShardedModelV2(nn.Module):
             # Wait for the non-blocking GPU -> CPU grad transfers to finish.
             torch.cuda.current_stream().synchronize()
         self.reducer.free()

-        # In case some post bwd hook is not fired
+        # 3. shard tensors not dealt with in the zero hook
         if self.shard_param:
             tensor_list = []
             for p in self.module.parameters():
                 if not p.col_attr.param_is_sharded:
                     tensor_list.append(p.col_attr.sharded_data_tensor)
+                    p.col_attr.sharded_data_tensor.trans_state(TensorState.HOLD_AFTER_BWD)
             self.shard_strategy.shard(tensor_list, self.process_group)

+        # 4. move sharded param grad payload to param.grad
         for p in self.module.parameters():
             p.col_attr.bwd_count = 0
             if not p.requires_grad:
                 continue
-            # Leave the gradient accumulation state as-is if not synchronizing this pass. This ensures p.grad
-            # remains the unsharded gradient accumulated from prior no-sync passes, and _saved_grad_shard
-            # remains the sharded gradient from the last synchronized pass. This also allows interleaved no-sync and
-            # sync passes, if desired.
+            # Leave the gradient accumulation state (_require_backward_grad_sync) as-is if not synchronizing this pass.
+            # NOTE: a no-sync pass skips, and a sync pass performs, gradient all-reduce across the process group.
+            # If _require_backward_grad_sync is True,
+            # p.grad remains the accumulated unsharded gradient from prior no-sync passes.
+            # We also allow interleaving no-sync passes with sync passes, if desired.
             if not self._require_backward_grad_sync:
                 continue
-            # Write grad back to p.grad and set p.col_attr.grad to None
+            # Write the grad payload kept by the sharded param back to p.grad,
+            # and set p.col_attr.grad to None
             # As sharded optimizer only update a shard of param,
             # no matter whether we shard param in sharded model
             # We have to make sure the grad is a flat tensor shard
-            # If world size == 1 and sharded param,
+            # If world size == 1 and param is sharded,
             # the shape `grad` is the same as unsharded param
             # So we can just use `view(-1)` to ensure grad is a flat tensor shard
             if self.reuse_fp16_shard:
-                grad_payload = p.col_attr.sharded_data_tensor.payload
+                grad_fp16_payload = p.col_attr.sharded_data_tensor.payload
             else:
-                grad_payload = cast_tensor_to_fp32(p.col_attr.fp16_grad.payload)
-            assert isinstance(grad_payload, torch.Tensor)
+                grad_fp16_payload = cast_tensor_to_fp32(p.col_attr.fp16_grad.payload)
+            assert isinstance(grad_fp16_payload, torch.Tensor)
             if p.col_attr.offload_grad:
-                colo_model_data_move_to_cpu(grad_payload)
+                colo_model_data_move_to_cpu(grad_fp16_payload)
             if not p.col_attr.fp32_grad.is_null():
                 assert not self.reuse_fp16_shard, 'Gradient accumulation is not supported when reuse_fp16_shard=True'
-                p.col_attr.fp32_grad.payload.add_(grad_payload.view_as(p.col_attr.fp32_grad.payload))
-                grad_payload = p.col_attr.fp32_grad.payload
-            p.grad.data = grad_payload
+                p.col_attr.fp32_grad.payload.add_(grad_fp16_payload.view_as(p.col_attr.fp32_grad.payload))
+                grad_fp16_payload = p.col_attr.fp32_grad.payload
+                p.col_attr.fp32_grad.set_null()
+            p.grad.data = grad_fp16_payload
             p.col_attr.fp16_grad.set_null()
-            p.col_attr.fp32_grad.set_null()

     @torch.no_grad()
     def _grad_post_backward_hook(self, param: Parameter, grad: torch.Tensor) -> Optional[torch.Tensor]:
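Step 4's control flow is easier to follow outside the parameter loop: pick the half-precision payload (reuse the fp16 shard, or cast the stored fp16 grad up to fp32), fold it into any accumulated fp32 gradient, and the result becomes p.grad.data. Below is a self-contained sketch under those assumptions; write_back_grad and its arguments are hypothetical, and .float() stands in for cast_tensor_to_fp32.

```python
from typing import Optional

import torch


def write_back_grad(fp16_grad: torch.Tensor,
                    fp32_accum: Optional[torch.Tensor],
                    reuse_fp16_shard: bool) -> torch.Tensor:
    """Hypothetical distillation of step 4 of _post_backward_operations."""
    # A reused fp16 shard is taken as-is; otherwise cast the fp16 grad up to fp32.
    grad_fp16_payload = fp16_grad if reuse_fp16_shard else fp16_grad.float()
    if fp32_accum is not None:  # plays the role of `not fp32_grad.is_null()`
        assert not reuse_fp16_shard, \
            'Gradient accumulation is not supported when reuse_fp16_shard=True'
        # Fold this pass's gradient into the accumulated fp32 gradient.
        fp32_accum.add_(grad_fp16_payload.view_as(fp32_accum))
        grad_fp16_payload = fp32_accum
    return grad_fp16_payload  # assigned to p.grad.data in the real code


g = write_back_grad(torch.ones(4, dtype=torch.float16), torch.full((4,), 2.0), False)
assert torch.equal(g, torch.full((4,), 3.0))
```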
colossalai/zero/sharded_optim/sharded_optim_v2.py

@@ -79,7 +79,6 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
                  growth_interval: float = 1000,
                  hysteresis: float = 2,
                  max_scale: int = 2**32,
                  use_memory_tracer=False,
                  dp_process_group: Optional[ProcessGroup] = None,
                  mp_process_group: Optional[ProcessGroup] = None) -> None:
         assert isinstance(sharded_model, ShardedModelV2), 'model must be wrapped with ShardedModel'
colossalai/zero/sharded_param/sharded_param.py

@@ -8,11 +8,8 @@ from .tensorful_state import StatefulTensor, TensorState
 class ShardedParamV2(object):

-    def __init__(self, param: torch.nn.Parameter, process_group: Optional[dist.ProcessGroup] = None,
-                 rm_torch_payload=False) -> None:
-        self._sharded_data_tensor: ShardedTensor = ShardedTensor(param.data, process_group)
+    def __init__(self, param: torch.nn.Parameter, rm_torch_payload=False) -> None:
+        self._sharded_data_tensor: ShardedTensor = ShardedTensor(param.data)
         self.fp16_grad: StatefulTensor = StatefulTensor(None, TensorState.FREE)
         self.fp32_grad: StatefulTensor = StatefulTensor(None, TensorState.FREE)
         # This attribute must be initialized in ShardedModel
colossalai/zero/sharded_param/sharded_tensor.py

 import torch
-import torch.distributed as dist
-from typing import Optional
 from colossalai.zero.sharded_param.tensorful_state import StatefulTensor, TensorState
+from typing import Optional


 class ShardedTensor(StatefulTensor):

-    def __init__(self, tensor: torch.Tensor, process_group: Optional[dist.ProcessGroup] = None) -> None:
+    def __init__(self, tensor: torch.Tensor, state: TensorState = TensorState.HOLD) -> None:
         r"""
         A tensor sharded in multiple processes. Constructed from an existing torch.Tensor instance.
         """
-        super().__init__(tensor)
-        self.trans_state(TensorState.HOLD)
+        super().__init__(tensor, state)

         # kept the shape, numel and dtype of the init tensor.
         self._origin_shape = tensor.shape
         self._origin_numel = tensor.numel()
         self._origin_dtype = tensor.dtype
         self._is_sharded = False

     @property
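The constructor change means a ShardedTensor now simply forwards the caller's initial state to its StatefulTensor base instead of forcing HOLD after construction. A minimal mirror of both classes, sketched from this diff rather than the installed package; the enum values are assumed.

```python
from enum import Enum

import torch


class TensorState(Enum):
    FREE = 0  # value assumed
    HOLD = 1


class StatefulTensorSketch:
    """Mirrors the new StatefulTensor(tensor, state) constructor."""

    def __init__(self, tensor, state: TensorState = TensorState.HOLD) -> None:
        self._payload = tensor
        self._state = state


class ShardedTensorSketch(StatefulTensorSketch):
    """Mirrors the diff: pass state through instead of calling trans_state(HOLD)."""

    def __init__(self, tensor: torch.Tensor, state: TensorState = TensorState.HOLD) -> None:
        super().__init__(tensor, state)
        # Keep the shape, numel and dtype of the initial tensor.
        self._origin_shape = tensor.shape
        self._origin_numel = tensor.numel()
        self._origin_dtype = tensor.dtype
        self._is_sharded = False


st = ShardedTensorSketch(torch.randn(2, 3))
assert st._state is TensorState.HOLD
```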
colossalai/zero/sharded_param/tensorful_state.py

 from enum import Enum
-from logging import NullHandler
+from typing import Optional

 import torch

@@ -8,22 +8,22 @@ class TensorState(Enum):
     HOLD = 1
     HOLD_AFTER_FWD = 2
     HOLD_AFTER_BWD = 3
     COMPUTE = 4


 class StatefulTensor(object):
     """A Structure stores a Torch Tensor and labeled states.
     Inspired from the paper:
     PatrickStar: Parallel Training of Pre-trained Models via Chunk-based Memory Management
     https://arxiv.org/abs/2108.05818
     """

-    def __init__(self, tensor: torch.Tensor, state: TensorState = TensorState.HOLD) -> None:
+    def __init__(self, tensor: torch.Tensor, state: Optional[TensorState] = TensorState.HOLD) -> None:
         self._state = state
-        if state is not TensorState.FREE:
-            self._payload = tensor
-        else:
-            self._payload = None
+        self._payload = tensor
+        if self._state == TensorState.FREE:
+            assert self._payload is None, f"payload has to be None if {self._state}"

     def data_ptr(self):
         if self._payload is None:
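The rewritten constructor replaces the old branch (silently drop the payload when FREE) with an invariant check: a FREE tensor must be constructed with no payload at all. A runnable mirror of the new behavior, reconstructed from the diff; the value of FREE is assumed since the hunk above starts at HOLD = 1.

```python
from enum import Enum
from typing import Optional

import torch


class TensorState(Enum):
    FREE = 0  # value assumed
    HOLD = 1
    HOLD_AFTER_FWD = 2
    HOLD_AFTER_BWD = 3
    COMPUTE = 4


class StatefulTensorSketch:
    """Mirror of the new StatefulTensor.__init__ from this commit."""

    def __init__(self, tensor: Optional[torch.Tensor],
                 state: Optional[TensorState] = TensorState.HOLD) -> None:
        self._state = state
        self._payload = tensor
        if self._state == TensorState.FREE:
            # New invariant: FREE tensors carry no payload.
            assert self._payload is None, f"payload has to be None if {self._state}"


fp32_grad = StatefulTensorSketch(None, TensorState.FREE)    # as in ShardedParamV2.__init__
held = StatefulTensorSketch(torch.zeros(4))                 # defaults to HOLD
try:
    StatefulTensorSketch(torch.zeros(4), TensorState.FREE)  # violates the invariant
except AssertionError as e:
    print(e)
```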
tests/test_zero_data_parallel/test_shard_param.py

@@ -48,7 +48,7 @@ def _run_shard_param_v2(rank, world_size, port):
     param = torch.nn.Parameter(torch.randn(2, 3))
     param_ref = deepcopy(param)

-    sparam = ShardedParamV2(param=param, process_group=None)
+    sparam = ShardedParamV2(param=param)

     allclose(sparam.sharded_data_tensor.payload, param_ref.data)