Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
f552b112
Unverified
Commit
f552b112
authored
Mar 30, 2022
by
Jiarui Fang
Committed by
GitHub
Mar 30, 2022
Browse files
[zero] label state for param fp16 and grad (#551)
parent
92f42248
Changes
7
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
75 additions
and
39 deletions
+75
-39
colossalai/engine/ophooks/zero_hook.py
colossalai/engine/ophooks/zero_hook.py
+17
-0
colossalai/zero/sharded_model/sharded_model_v2.py
colossalai/zero/sharded_model/sharded_model_v2.py
+43
-18
colossalai/zero/sharded_optim/sharded_optim_v2.py
colossalai/zero/sharded_optim/sharded_optim_v2.py
+0
-1
colossalai/zero/sharded_param/sharded_param.py
colossalai/zero/sharded_param/sharded_param.py
+2
-5
colossalai/zero/sharded_param/sharded_tensor.py
colossalai/zero/sharded_param/sharded_tensor.py
+4
-6
colossalai/zero/sharded_param/tensorful_state.py
colossalai/zero/sharded_param/tensorful_state.py
+8
-8
tests/test_zero_data_parallel/test_shard_param.py
tests/test_zero_data_parallel/test_shard_param.py
+1
-1
No files found.
colossalai/engine/ophooks/zero_hook.py
View file @
f552b112
...
@@ -6,6 +6,7 @@ from colossalai.registry import OPHOOKS
...
@@ -6,6 +6,7 @@ from colossalai.registry import OPHOOKS
from
colossalai.utils
import
get_current_device
from
colossalai.utils
import
get_current_device
from
colossalai.utils.memory_tracer.memstats_collector
import
MemStatsCollector
from
colossalai.utils.memory_tracer.memstats_collector
import
MemStatsCollector
from
colossalai.zero.shard_utils
import
BaseShardStrategy
from
colossalai.zero.shard_utils
import
BaseShardStrategy
from
colossalai.zero.sharded_param.tensorful_state
import
TensorState
from
._base_ophook
import
BaseOpHook
from
._base_ophook
import
BaseOpHook
from
colossalai.utils.memory_utils.utils
import
colo_model_data_tensor_move_inline
from
colossalai.utils.memory_utils.utils
import
colo_model_data_tensor_move_inline
...
@@ -42,7 +43,13 @@ class ZeroHook(BaseOpHook):
...
@@ -42,7 +43,13 @@ class ZeroHook(BaseOpHook):
if
self
.
_memstarts_collector
:
if
self
.
_memstarts_collector
:
self
.
_memstarts_collector
.
sample_memstats
()
self
.
_memstarts_collector
.
sample_memstats
()
for
param
in
module
.
parameters
(
recurse
=
False
):
param
.
col_attr
.
sharded_data_tensor
.
trans_state
(
TensorState
.
COMPUTE
)
def
post_fwd_exec
(
self
,
module
:
torch
.
nn
.
Module
,
*
args
):
def
post_fwd_exec
(
self
,
module
:
torch
.
nn
.
Module
,
*
args
):
for
param
in
module
.
parameters
(
recurse
=
False
):
param
.
col_attr
.
sharded_data_tensor
.
trans_state
(
TensorState
.
HOLD_AFTER_FWD
)
tensor_list
=
[]
tensor_list
=
[]
for
param
in
module
.
parameters
(
recurse
=
False
):
for
param
in
module
.
parameters
(
recurse
=
False
):
assert
hasattr
(
param
,
'col_attr'
)
assert
hasattr
(
param
,
'col_attr'
)
...
@@ -65,7 +72,10 @@ class ZeroHook(BaseOpHook):
...
@@ -65,7 +72,10 @@ class ZeroHook(BaseOpHook):
if
param
.
col_attr
.
bwd_count
==
0
:
if
param
.
col_attr
.
bwd_count
==
0
:
# We haven't stored local accumulated grad yet
# We haven't stored local accumulated grad yet
assert
param
.
col_attr
.
fp32_grad
.
is_null
()
assert
param
.
col_attr
.
fp32_grad
.
is_null
()
# Allocate grad fp32 memory space here
param
.
col_attr
.
fp32_grad
.
reset_payload
(
param
.
grad
.
data
)
param
.
col_attr
.
fp32_grad
.
reset_payload
(
param
.
grad
.
data
)
# TODO(jiaruifang) we should set grad fp16 state to HOLD here.
param
.
grad
=
None
param
.
grad
=
None
else
:
else
:
# We have stored local accumulated grad
# We have stored local accumulated grad
...
@@ -75,12 +85,19 @@ class ZeroHook(BaseOpHook):
...
@@ -75,12 +85,19 @@ class ZeroHook(BaseOpHook):
if
self
.
_memstarts_collector
:
if
self
.
_memstarts_collector
:
self
.
_memstarts_collector
.
sample_memstats
()
self
.
_memstarts_collector
.
sample_memstats
()
for
param
in
module
.
parameters
(
recurse
=
False
):
param
.
col_attr
.
sharded_data_tensor
.
trans_state
(
TensorState
.
COMPUTE
)
def
post_bwd_exec
(
self
,
module
:
torch
.
nn
.
Module
,
input
):
def
post_bwd_exec
(
self
,
module
:
torch
.
nn
.
Module
,
input
):
for
param
in
module
.
parameters
(
recurse
=
False
):
param
.
col_attr
.
sharded_data_tensor
.
trans_state
(
TensorState
.
HOLD_AFTER_BWD
)
tensor_list
=
[]
tensor_list
=
[]
for
param
in
module
.
parameters
(
recurse
=
False
):
for
param
in
module
.
parameters
(
recurse
=
False
):
assert
hasattr
(
param
,
'col_attr'
)
assert
hasattr
(
param
,
'col_attr'
)
tensor_list
.
append
(
param
.
col_attr
.
sharded_data_tensor
)
tensor_list
.
append
(
param
.
col_attr
.
sharded_data_tensor
)
self
.
shard_strategy
.
shard
(
tensor_list
,
self
.
process_group
)
self
.
shard_strategy
.
shard
(
tensor_list
,
self
.
process_group
)
for
param
in
module
.
parameters
(
recurse
=
False
):
for
param
in
module
.
parameters
(
recurse
=
False
):
param
.
col_attr
.
remove_torch_payload
()
param
.
col_attr
.
remove_torch_payload
()
...
...
colossalai/zero/sharded_model/sharded_model_v2.py
View file @
f552b112
...
@@ -25,7 +25,7 @@ from torch.nn.parameter import Parameter
...
@@ -25,7 +25,7 @@ from torch.nn.parameter import Parameter
from
._utils
import
(
cast_float_arguments
,
cast_tensor_to_fp16
,
cast_tensor_to_fp32
,
chunk_and_pad
,
free_storage
,
from
._utils
import
(
cast_float_arguments
,
cast_tensor_to_fp16
,
cast_tensor_to_fp32
,
chunk_and_pad
,
free_storage
,
get_gradient_predivide_factor
)
get_gradient_predivide_factor
)
from
colossalai.zero.sharded_param.tensorful_state
import
StatefulTensor
from
colossalai.zero.sharded_param.tensorful_state
import
StatefulTensor
,
TensorState
class
ShardedModelV2
(
nn
.
Module
):
class
ShardedModelV2
(
nn
.
Module
):
...
@@ -158,12 +158,25 @@ class ShardedModelV2(nn.Module):
...
@@ -158,12 +158,25 @@ class ShardedModelV2(nn.Module):
f
.
write
(
str
(
self
.
_memstats_collector
.
non_model_data_cuda_GB
))
f
.
write
(
str
(
self
.
_memstats_collector
.
non_model_data_cuda_GB
))
f
.
write
(
'
\n
'
)
f
.
write
(
'
\n
'
)
def
forward
(
self
,
*
args
:
Any
,
**
kwargs
:
Any
)
->
torch
.
Tensor
:
def
_pre_
forward
_operations
(
self
)
:
if
self
.
_iter_cnter
==
0
and
self
.
_memstats_collector
:
if
self
.
_iter_cnter
==
0
and
self
.
_memstats_collector
:
# the ope
a
rtion will affect the
flag
in ZeroHook
# the oper
a
tion will affect the
memory tracer behavior
in ZeroHook
self
.
_memstats_collector
.
start_collection
()
self
.
_memstats_collector
.
start_collection
()
for
p
in
self
.
module
.
parameters
():
if
hasattr
(
p
,
'col_attr'
):
p
.
col_attr
.
sharded_data_tensor
.
trans_state
(
TensorState
.
HOLD
)
def
_post_forward_operations
(
self
):
for
p
in
self
.
module
.
parameters
():
if
hasattr
(
p
,
'col_attr'
):
p
.
col_attr
.
sharded_data_tensor
.
trans_state
(
TensorState
.
HOLD
)
def
forward
(
self
,
*
args
:
Any
,
**
kwargs
:
Any
)
->
torch
.
Tensor
:
self
.
_pre_forward_operations
()
args
,
kwargs
=
cast_float_arguments
(
cast_tensor_to_fp16
,
*
args
,
**
kwargs
)
args
,
kwargs
=
cast_float_arguments
(
cast_tensor_to_fp16
,
*
args
,
**
kwargs
)
outputs
=
self
.
module
(
*
args
,
**
kwargs
)
outputs
=
self
.
module
(
*
args
,
**
kwargs
)
self
.
_post_forward_operations
()
return
outputs
return
outputs
def
backward
(
self
,
loss
):
def
backward
(
self
,
loss
):
...
@@ -195,9 +208,15 @@ class ShardedModelV2(nn.Module):
...
@@ -195,9 +208,15 @@ class ShardedModelV2(nn.Module):
def
_post_backward_operations
(
self
)
->
None
:
def
_post_backward_operations
(
self
)
->
None
:
"""
"""
The method includes operations required to be processed after backward
The method includes operations required to be processed after backward
1. update memory tracer.
2. flush the gradient in buckets. Reducing partial gradients in each process.
3. shard tensors not dealed in the zero hook
4. move sharded param grad payload to param.grad
"""
"""
# 1. update memory tracer.
self
.
_update_memstats
()
self
.
_update_memstats
()
# 2. flush the gradient in buckets. Reducing partial gradients in each process.
if
self
.
_require_backward_grad_sync
:
if
self
.
_require_backward_grad_sync
:
# Flush any unreduced buckets in the post_backward stream.
# Flush any unreduced buckets in the post_backward stream.
with
torch
.
cuda
.
stream
(
self
.
comm_stream
):
with
torch
.
cuda
.
stream
(
self
.
comm_stream
):
...
@@ -207,45 +226,51 @@ class ShardedModelV2(nn.Module):
...
@@ -207,45 +226,51 @@ class ShardedModelV2(nn.Module):
# Wait for the non-blocking GPU -> CPU grad transfers to finish.
# Wait for the non-blocking GPU -> CPU grad transfers to finish.
torch
.
cuda
.
current_stream
().
synchronize
()
torch
.
cuda
.
current_stream
().
synchronize
()
self
.
reducer
.
free
()
self
.
reducer
.
free
()
#
In case some post bwd hook is not fired
#
3. shard tensors not dealed in the zero hook
if
self
.
shard_param
:
if
self
.
shard_param
:
tensor_list
=
[]
tensor_list
=
[]
for
p
in
self
.
module
.
parameters
():
for
p
in
self
.
module
.
parameters
():
if
not
p
.
col_attr
.
param_is_sharded
:
if
not
p
.
col_attr
.
param_is_sharded
:
tensor_list
.
append
(
p
.
col_attr
.
sharded_data_tensor
)
tensor_list
.
append
(
p
.
col_attr
.
sharded_data_tensor
)
p
.
col_attr
.
sharded_data_tensor
.
trans_state
(
TensorState
.
HOLD_AFTER_BWD
)
self
.
shard_strategy
.
shard
(
tensor_list
,
self
.
process_group
)
self
.
shard_strategy
.
shard
(
tensor_list
,
self
.
process_group
)
# 4. move sharded param grad payload to param.grad
for
p
in
self
.
module
.
parameters
():
for
p
in
self
.
module
.
parameters
():
p
.
col_attr
.
bwd_count
=
0
p
.
col_attr
.
bwd_count
=
0
if
not
p
.
requires_grad
:
if
not
p
.
requires_grad
:
continue
continue
# Leave the gradient accumulation state as-is if not synchronizing this pass. This ensures p.grad
# Leave the gradient accumulation state (_require_backward_grad_sync) as-is if not synchronizing this pass.
# remains the unsharded gradient accumulated from prior no-sync passes, and _saved_grad_shard
# NOTE() (no-sync)/sync pass: (not conduct)/conduct gradient allreducing between process group.
# remains the sharded gradient from the last synchronized pass. This also allows interleaved no-sync and
# If _require_backward_grad_sync is True,
# sync passes, if desired.
# p.grad remains the accumulated unsharded gradient from prior no-sync passes.
# We also allows to interleave no-sync pass with sync passes, if desired.
if
not
self
.
_require_backward_grad_sync
:
if
not
self
.
_require_backward_grad_sync
:
continue
continue
# Write grad back to p.grad and set p.col_attr.grad to None
# Write grad payload kept by sharded param back to p.grad,
# and set p.col_attr.grad to None
# As sharded optimizer only update a shard of param,
# As sharded optimizer only update a shard of param,
# no matter whether we shard param in sharded model
# no matter whether we shard param in sharded model
# We have to make sure the grad is a flat tensor shard
# We have to make sure the grad is a flat tensor shard
# If world size == 1 and sharded
param
,
# If world size == 1 and
param is
sharded,
# the shape `grad` is the same as unsharded param
# the shape `grad` is the same as unsharded param
# So we can just use `view(-1)` to ensure grad is a flat tensor shard
# So we can just use `view(-1)` to ensure grad is a flat tensor shard
if
self
.
reuse_fp16_shard
:
if
self
.
reuse_fp16_shard
:
grad_payload
=
p
.
col_attr
.
sharded_data_tensor
.
payload
grad_
fp16_
payload
=
p
.
col_attr
.
sharded_data_tensor
.
payload
else
:
else
:
grad_payload
=
cast_tensor_to_fp32
(
p
.
col_attr
.
fp16_grad
.
payload
)
grad_
fp16_
payload
=
cast_tensor_to_fp32
(
p
.
col_attr
.
fp16_grad
.
payload
)
assert
isinstance
(
grad_payload
,
torch
.
Tensor
)
assert
isinstance
(
grad_
fp16_
payload
,
torch
.
Tensor
)
if
p
.
col_attr
.
offload_grad
:
if
p
.
col_attr
.
offload_grad
:
colo_model_data_move_to_cpu
(
grad_payload
)
colo_model_data_move_to_cpu
(
grad_
fp16_
payload
)
if
not
p
.
col_attr
.
fp32_grad
.
is_null
():
if
not
p
.
col_attr
.
fp32_grad
.
is_null
():
assert
not
self
.
reuse_fp16_shard
,
'Gradien accumulation is not supported when reuse_fp16_shard=True'
assert
not
self
.
reuse_fp16_shard
,
'Gradien accumulation is not supported when reuse_fp16_shard=True'
p
.
col_attr
.
fp32_grad
.
payload
.
add_
(
grad_payload
.
view_as
(
p
.
col_attr
.
fp32_grad
.
payload
))
p
.
col_attr
.
fp32_grad
.
payload
.
add_
(
grad_fp16_payload
.
view_as
(
p
.
col_attr
.
fp32_grad
.
payload
))
grad_payload
=
p
.
col_attr
.
fp32_grad
.
payload
grad_fp16_payload
=
p
.
col_attr
.
fp32_grad
.
payload
p
.
grad
.
data
=
grad_payload
p
.
col_attr
.
fp16_grad
.
set_null
()
p
.
col_attr
.
fp32_grad
.
set_null
()
p
.
col_attr
.
fp32_grad
.
set_null
()
p
.
grad
.
data
=
grad_fp16_payload
p
.
col_attr
.
fp16_grad
.
set_null
()
@
torch
.
no_grad
()
@
torch
.
no_grad
()
def
_grad_post_backward_hook
(
self
,
param
:
Parameter
,
grad
:
torch
.
Tensor
)
->
Optional
[
torch
.
Tensor
]:
def
_grad_post_backward_hook
(
self
,
param
:
Parameter
,
grad
:
torch
.
Tensor
)
->
Optional
[
torch
.
Tensor
]:
"""
"""
...
...
colossalai/zero/sharded_optim/sharded_optim_v2.py
View file @
f552b112
...
@@ -79,7 +79,6 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
...
@@ -79,7 +79,6 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
growth_interval
:
float
=
1000
,
growth_interval
:
float
=
1000
,
hysteresis
:
float
=
2
,
hysteresis
:
float
=
2
,
max_scale
:
int
=
2
**
32
,
max_scale
:
int
=
2
**
32
,
use_memory_tracer
=
False
,
dp_process_group
:
Optional
[
ProcessGroup
]
=
None
,
dp_process_group
:
Optional
[
ProcessGroup
]
=
None
,
mp_process_group
:
Optional
[
ProcessGroup
]
=
None
)
->
None
:
mp_process_group
:
Optional
[
ProcessGroup
]
=
None
)
->
None
:
assert
isinstance
(
sharded_model
,
ShardedModelV2
),
'model must be wrapped with ShardedModel'
assert
isinstance
(
sharded_model
,
ShardedModelV2
),
'model must be wrapped with ShardedModel'
...
...
colossalai/zero/sharded_param/sharded_param.py
View file @
f552b112
...
@@ -8,11 +8,8 @@ from .tensorful_state import StatefulTensor, TensorState
...
@@ -8,11 +8,8 @@ from .tensorful_state import StatefulTensor, TensorState
class
ShardedParamV2
(
object
):
class
ShardedParamV2
(
object
):
def
__init__
(
self
,
def
__init__
(
self
,
param
:
torch
.
nn
.
Parameter
,
rm_torch_payload
=
False
)
->
None
:
param
:
torch
.
nn
.
Parameter
,
self
.
_sharded_data_tensor
:
ShardedTensor
=
ShardedTensor
(
param
.
data
)
process_group
:
Optional
[
dist
.
ProcessGroup
]
=
None
,
rm_torch_payload
=
False
)
->
None
:
self
.
_sharded_data_tensor
:
ShardedTensor
=
ShardedTensor
(
param
.
data
,
process_group
)
self
.
fp16_grad
:
StatefulTensor
=
StatefulTensor
(
None
,
TensorState
.
FREE
)
self
.
fp16_grad
:
StatefulTensor
=
StatefulTensor
(
None
,
TensorState
.
FREE
)
self
.
fp32_grad
:
StatefulTensor
=
StatefulTensor
(
None
,
TensorState
.
FREE
)
self
.
fp32_grad
:
StatefulTensor
=
StatefulTensor
(
None
,
TensorState
.
FREE
)
# This attribute must be initialized in ShardedModel
# This attribute must be initialized in ShardedModel
...
...
colossalai/zero/sharded_param/sharded_tensor.py
View file @
f552b112
import
torch
import
torch
import
torch.distributed
as
dist
from
typing
import
Optional
from
colossalai.zero.sharded_param.tensorful_state
import
StatefulTensor
,
TensorState
from
colossalai.zero.sharded_param.tensorful_state
import
StatefulTensor
,
TensorState
from
typing
import
Optional
class
ShardedTensor
(
StatefulTensor
):
class
ShardedTensor
(
StatefulTensor
):
def
__init__
(
self
,
tensor
:
torch
.
Tensor
,
process_group
:
Optional
[
dist
.
ProcessGroup
]
=
None
)
->
None
:
def
__init__
(
self
,
tensor
:
torch
.
Tensor
,
state
:
TensorState
=
TensorState
.
HOLD
)
->
None
:
r
"""
r
"""
A tensor sharded in multiple processes. Constructed from an existing torch.Tensor instance.
A tensor sharded in multiple processes. Constructed from an existing torch.Tensor instance.
"""
"""
super
().
__init__
(
tensor
)
super
().
__init__
(
tensor
,
state
)
self
.
trans_state
(
TensorState
.
HOLD
)
# kept the shape, numel and dtype of the init tensor.
self
.
_origin_shape
=
tensor
.
shape
self
.
_origin_shape
=
tensor
.
shape
self
.
_origin_numel
=
tensor
.
numel
()
self
.
_origin_numel
=
tensor
.
numel
()
self
.
_origin_dtype
=
tensor
.
dtype
self
.
_origin_dtype
=
tensor
.
dtype
self
.
_is_sharded
=
False
self
.
_is_sharded
=
False
@
property
@
property
...
...
colossalai/zero/sharded_param/tensorful_state.py
View file @
f552b112
from
enum
import
Enum
from
enum
import
Enum
from
logg
ing
import
NullHandler
from
typ
ing
import
Optional
import
torch
import
torch
...
@@ -8,22 +8,22 @@ class TensorState(Enum):
...
@@ -8,22 +8,22 @@ class TensorState(Enum):
HOLD
=
1
HOLD
=
1
HOLD_AFTER_FWD
=
2
HOLD_AFTER_FWD
=
2
HOLD_AFTER_BWD
=
3
HOLD_AFTER_BWD
=
3
COMPUTE
=
4
class
StatefulTensor
(
object
):
class
StatefulTensor
(
object
):
"""A Structure stores a Torch Tensor and labeled states.
"""A Structure stores a Torch Tensor and labeled states.
Inspired from the paper:
PatrickStar: Parallel Training of Pre-trained Models via Chunk-based Memory Management
PatrickStar: Parallel Training of Pre-trained Models via Chunk-based Memory Management
https://arxiv.org/abs/2108.05818
https://arxiv.org/abs/2108.05818
"""
"""
def
__init__
(
self
,
tensor
:
torch
.
Tensor
,
state
:
TensorState
=
TensorState
.
HOLD
)
->
None
:
def
__init__
(
self
,
tensor
:
torch
.
Tensor
,
state
:
Optional
[
TensorState
]
=
TensorState
.
HOLD
)
->
None
:
self
.
_state
=
state
self
.
_state
=
state
if
state
is
not
TensorState
.
FREE
:
self
.
_payload
=
tensor
self
.
_payload
=
tensor
else
:
if
self
.
_state
==
TensorState
.
FREE
:
self
.
_payload
=
None
assert
self
.
_payload
is
None
,
f
"payload has to None if
{
self
.
_state
}
"
def
data_ptr
(
self
):
def
data_ptr
(
self
):
if
self
.
_payload
is
None
:
if
self
.
_payload
is
None
:
...
...
tests/test_zero_data_parallel/test_shard_param.py
View file @
f552b112
...
@@ -48,7 +48,7 @@ def _run_shard_param_v2(rank, world_size, port):
...
@@ -48,7 +48,7 @@ def _run_shard_param_v2(rank, world_size, port):
param
=
torch
.
nn
.
Parameter
(
torch
.
randn
(
2
,
3
))
param
=
torch
.
nn
.
Parameter
(
torch
.
randn
(
2
,
3
))
param_ref
=
deepcopy
(
param
)
param_ref
=
deepcopy
(
param
)
sparam
=
ShardedParamV2
(
param
=
param
,
process_group
=
None
)
sparam
=
ShardedParamV2
(
param
=
param
)
allclose
(
sparam
.
sharded_data_tensor
.
payload
,
param_ref
.
data
)
allclose
(
sparam
.
sharded_data_tensor
.
payload
,
param_ref
.
data
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment