Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
214da761
Unverified
Commit
214da761
authored
Mar 30, 2022
by
Jiarui Fang
Committed by
GitHub
Mar 30, 2022
Browse files
[zero] add stateful tensor (#549)
parent
107b99dd
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
123 additions
and
59 deletions
+123
-59
colossalai/engine/ophooks/zero_hook.py
colossalai/engine/ophooks/zero_hook.py
+2
-2
colossalai/zero/sharded_model/_utils.py
colossalai/zero/sharded_model/_utils.py
+8
-1
colossalai/zero/sharded_model/sharded_model_v2.py
colossalai/zero/sharded_model/sharded_model_v2.py
+11
-7
colossalai/zero/sharded_param/sharded_param.py
colossalai/zero/sharded_param/sharded_param.py
+7
-6
colossalai/zero/sharded_param/sharded_tensor.py
colossalai/zero/sharded_param/sharded_tensor.py
+8
-38
colossalai/zero/sharded_param/tensorful_state.py
colossalai/zero/sharded_param/tensorful_state.py
+81
-0
tests/test_zero_data_parallel/test_shard_param.py
tests/test_zero_data_parallel/test_shard_param.py
+6
-5
No files found.
colossalai/engine/ophooks/zero_hook.py
View file @
214da761
...
...
@@ -64,8 +64,8 @@ class ZeroHook(BaseOpHook):
if
param
.
grad
is
not
None
:
if
param
.
col_attr
.
bwd_count
==
0
:
# We haven't stored local accumulated grad yet
assert
param
.
col_attr
.
fp32_grad
is
None
param
.
col_attr
.
fp32_grad
=
param
.
grad
.
data
assert
param
.
col_attr
.
fp32_grad
.
is
_null
()
param
.
col_attr
.
fp32_grad
.
reset_payload
(
param
.
grad
.
data
)
param
.
grad
=
None
else
:
# We have stored local accumulated grad
...
...
colossalai/zero/sharded_model/_utils.py
View file @
214da761
...
...
@@ -2,6 +2,8 @@ from typing import Any, Callable, List, Tuple
import
torch
import
torch.nn.functional
as
F
from
typing
import
Union
from
colossalai.zero.sharded_param.tensorful_state
import
StatefulTensor
def
get_gradient_predivide_factor
(
world_size
:
int
)
->
float
:
...
...
@@ -30,12 +32,17 @@ def alloc_storage(data: torch.Tensor, size: torch.Size) -> None:
def cast_tensor_to_fp16(tensor: Union[torch.Tensor, "StatefulTensor"]) -> torch.Tensor:
    """Return ``tensor`` in half precision when it is a float32 tensor.

    Args:
        tensor: a plain ``torch.Tensor`` or a ``StatefulTensor``; the latter is
            unwrapped to its ``payload`` before the dtype check.

    Returns:
        ``tensor.half()`` for a float32 floating-point tensor, otherwise the
        (unwrapped) tensor itself, unchanged.
    """
    # Accept the wrapper type so callers do not have to unwrap it themselves.
    if isinstance(tensor, StatefulTensor):
        tensor = tensor.payload
    if torch.is_floating_point(tensor) and tensor.dtype is torch.float32:
        return tensor.half()
    return tensor
def cast_tensor_to_fp32(tensor: Union[torch.Tensor, StatefulTensor]) -> torch.Tensor:
    """Return ``tensor`` in single precision when it is a float16 tensor.

    A ``StatefulTensor`` argument is first replaced by its ``payload``. Tensors
    that are not float16 floating-point tensors are handed back as they are.
    """
    if isinstance(tensor, StatefulTensor):
        unwrapped = tensor.payload
    else:
        unwrapped = tensor

    # Guard clause: anything that is not an fp16 float tensor passes through.
    if not (torch.is_floating_point(unwrapped) and unwrapped.dtype is torch.float16):
        return unwrapped
    return unwrapped.float()
...
...
colossalai/zero/sharded_model/sharded_model_v2.py
View file @
214da761
...
...
@@ -25,6 +25,7 @@ from torch.nn.parameter import Parameter
from
._utils
import
(
cast_float_arguments
,
cast_tensor_to_fp16
,
cast_tensor_to_fp32
,
chunk_and_pad
,
free_storage
,
get_gradient_predivide_factor
)
from
colossalai.zero.sharded_param.tensorful_state
import
StatefulTensor
class
ShardedModelV2
(
nn
.
Module
):
...
...
@@ -233,16 +234,17 @@ class ShardedModelV2(nn.Module):
if
self
.
reuse_fp16_shard
:
grad_payload
=
p
.
col_attr
.
sharded_data_tensor
.
payload
else
:
grad_payload
=
cast_tensor_to_fp32
(
p
.
col_attr
.
fp16_grad
)
grad_payload
=
cast_tensor_to_fp32
(
p
.
col_attr
.
fp16_grad
.
payload
)
assert
isinstance
(
grad_payload
,
torch
.
Tensor
)
if
p
.
col_attr
.
offload_grad
:
colo_model_data_move_to_cpu
(
grad_payload
)
if
p
.
col_attr
.
fp32_grad
is
not
None
:
if
not
p
.
col_attr
.
fp32_grad
.
is
_null
()
:
assert
not
self
.
reuse_fp16_shard
,
'Gradien accumulation is not supported when reuse_fp16_shard=True'
p
.
col_attr
.
fp32_grad
.
add_
(
grad_payload
.
view_as
(
p
.
col_attr
.
fp32_grad
))
grad_payload
=
p
.
col_attr
.
fp32_grad
p
.
col_attr
.
fp32_grad
.
payload
.
add_
(
grad_payload
.
view_as
(
p
.
col_attr
.
fp32_grad
.
payload
))
grad_payload
=
p
.
col_attr
.
fp32_grad
.
payload
p
.
grad
.
data
=
grad_payload
p
.
col_attr
.
fp16_grad
=
None
p
.
col_attr
.
fp32_grad
=
None
p
.
col_attr
.
fp16_grad
.
set_null
()
p
.
col_attr
.
fp32_grad
.
set_null
()
@
torch
.
no_grad
()
def
_grad_post_backward_hook
(
self
,
param
:
Parameter
,
grad
:
torch
.
Tensor
)
->
Optional
[
torch
.
Tensor
]:
...
...
@@ -293,6 +295,8 @@ class ShardedModelV2(nn.Module):
return
empty_grad
def
_reduce_scatter_callback
(
self
,
param
:
Parameter
,
reduced_grad
:
torch
.
Tensor
)
->
None
:
assert
isinstance
(
reduced_grad
,
torch
.
Tensor
),
f
"_reduce_scatter_callback accept reduced_grad as
{
type
(
reduced_grad
)
}
"
reduced_grad
=
reduced_grad
.
view
(
-
1
)
if
self
.
gradient_postdivide_factor
>
1
:
# Average grad by world_size for consistency with PyTorch DDP.
...
...
@@ -301,7 +305,7 @@ class ShardedModelV2(nn.Module):
param
.
col_attr
.
sharded_data_tensor
.
reset_payload
(
reduced_grad
.
data
)
param
.
col_attr
.
sharded_data_tensor
.
is_sharded
=
True
else
:
param
.
col_attr
.
fp16_grad
=
reduced_grad
.
data
param
.
col_attr
.
fp16_grad
=
StatefulTensor
(
reduced_grad
.
data
)
def
state_dict
(
self
,
destination
=
None
,
prefix
=
''
,
keep_vars
=
False
)
->
'OrderedDict[str, torch.Tensor]'
:
self
.
shard_strategy
.
gather
([
p
.
col_attr
.
sharded_data_tensor
for
p
in
self
.
module
.
parameters
()],
...
...
colossalai/zero/sharded_param/sharded_param.py
View file @
214da761
...
...
@@ -3,6 +3,7 @@ import torch.distributed as dist
from
colossalai.zero.sharded_param
import
ShardedTensor
from
typing
import
Optional
,
Tuple
from
colossalai.utils.memory_utils.utils
import
colo_tensor_mem_usage
from
.tensorful_state
import
StatefulTensor
,
TensorState
class
ShardedParamV2
(
object
):
...
...
@@ -12,8 +13,8 @@ class ShardedParamV2(object):
process_group
:
Optional
[
dist
.
ProcessGroup
]
=
None
,
rm_torch_payload
=
False
)
->
None
:
self
.
_sharded_data_tensor
:
ShardedTensor
=
ShardedTensor
(
param
.
data
,
process_group
)
self
.
fp16_grad
:
Optional
[
torch
.
Tensor
]
=
None
self
.
fp32_grad
:
Optional
[
torch
.
Tensor
]
=
None
self
.
fp16_grad
:
Stateful
Tensor
=
StatefulTensor
(
None
,
TensorState
.
FREE
)
self
.
fp32_grad
:
Stateful
Tensor
=
StatefulTensor
(
None
,
TensorState
.
FREE
)
# This attribute must be initialized in ShardedModel
self
.
offload_grad
:
bool
=
False
...
...
@@ -64,12 +65,12 @@ class ShardedParamV2(object):
_update_mem_use
(
self
.
sharded_data_tensor
.
payload
)
address_set
.
add
(
self
.
sharded_data_tensor
.
payload
.
data_ptr
())
if
self
.
fp16_grad
is
not
None
and
self
.
fp16_grad
.
data_ptr
()
not
in
address_set
:
_update_mem_use
(
self
.
fp16_grad
)
if
not
self
.
fp16_grad
.
is
_null
()
and
self
.
fp16_grad
.
data_ptr
()
not
in
address_set
:
_update_mem_use
(
self
.
fp16_grad
.
payload
)
address_set
.
add
(
self
.
fp16_grad
.
data_ptr
())
if
self
.
fp32_grad
is
not
None
and
self
.
fp32_grad
.
data_ptr
()
not
in
address_set
:
_update_mem_use
(
self
.
fp32_grad
)
if
not
self
.
fp32_grad
.
is
_null
()
and
self
.
fp32_grad
.
data_ptr
()
not
in
address_set
:
_update_mem_use
(
self
.
fp32_grad
.
payload
)
address_set
.
add
(
self
.
fp32_grad
.
data_ptr
())
if
self
.
param
.
data
is
not
None
and
self
.
param
.
data
.
data_ptr
()
not
in
address_set
:
...
...
colossalai/zero/sharded_param/sharded_tensor.py
View file @
214da761
import
torch
import
torch.distributed
as
dist
from
typing
import
Optional
from
colossalai.zero.sharded_param.tensorful_state
import
StatefulTensor
,
TensorState
class
ShardedTensor
(
object
):
class
ShardedTensor
(
StatefulTensor
):
def
__init__
(
self
,
tensor
:
torch
.
Tensor
,
process_group
:
Optional
[
dist
.
ProcessGroup
]
=
None
)
->
None
:
r
"""
A tensor sharded in multiple processes. Constructed from an existing torch.Tensor instance.
"""
self
.
_payload
=
tensor
self
.
process_group
=
process_group
self
.
world_size
=
dist
.
get_world_size
(
self
.
process_group
)
self
.
local_rank
=
dist
.
get_rank
(
self
.
process_group
)
self
.
_is_sharded
=
False
super
().
__init__
(
tensor
)
self
.
trans_state
(
TensorState
.
HOLD
)
self
.
_origin_shape
=
tensor
.
shape
self
.
_origin_numel
=
tensor
.
numel
()
self
.
_origin_dtype
=
tensor
.
dtype
self
.
_is_sharded
=
False
@
property
def
origin_numel
(
self
):
def
origin_numel
(
self
)
->
int
:
return
self
.
_origin_numel
@
property
def
origin_shape
(
self
):
def
origin_shape
(
self
)
->
int
:
return
self
.
_origin_shape
@
property
...
...
@@ -34,33 +34,3 @@ class ShardedTensor(object):
@
is_sharded
.
setter
def
is_sharded
(
self
,
flag
:
bool
):
self
.
_is_sharded
=
flag
@
property
def
payload
(
self
):
return
self
.
_payload
def
copy_payload
(
self
,
tensor
):
self
.
_payload
.
view
(
-
1
).
copy_
(
tensor
.
view
(
-
1
))
def
reset_payload
(
self
,
tensor
):
del
self
.
_payload
self
.
_payload
=
tensor
@
property
def
device
(
self
):
return
self
.
_payload
.
device
@
property
def
dtype
(
self
):
assert
self
.
_payload
.
dtype
==
self
.
_origin_dtype
return
self
.
_origin_dtype
def
to
(
self
,
device
:
torch
.
device
):
raise
RuntimeError
(
"Use colo_model_tensor_move install of call .to() on ShardedTensor"
)
def
to_
(
self
,
device
:
torch
.
device
):
raise
RuntimeError
(
"Use colo_model_tensor_move install of call .to_() on ShardedTensor"
)
@
property
def
shape
(
self
):
return
self
.
_payload
.
shape
colossalai/zero/sharded_param/tensorful_state.py
0 → 100644
View file @
214da761
from enum import Enum
from logging import NullHandler
from typing import Optional

import torch
class TensorState(Enum):
    """Labels describing the lifecycle state of a ``StatefulTensor``."""
    FREE = 0    # no payload is held (see StatefulTensor.is_null)
    HOLD = 1
    HOLD_AFTER_FWD = 2
    HOLD_AFTER_BWD = 3


class StatefulTensor(object):
    """A Structure stores a Torch Tensor and labeled states.

    PatrickStar: Parallel Training of Pre-trained Models via Chunk-based Memory Management

    https://arxiv.org/abs/2108.05818

    Invariant maintained by the constructor, ``set_null`` and ``trans_state``:
    whenever the state is ``TensorState.FREE`` the payload is ``None``.
    """

    def __init__(self, tensor: Optional[torch.Tensor], state: TensorState = TensorState.HOLD) -> None:
        """Wrap ``tensor`` and label it with ``state``.

        Args:
            tensor: the payload to hold, or ``None`` for an empty wrapper.
            state: initial state. ``TensorState.FREE`` discards ``tensor``.
        """
        # A wrapper without a payload is by definition FREE; normalising here
        # keeps ``is_null()`` truthful for ``StatefulTensor(None)``.
        if tensor is None or state is TensorState.FREE:
            self._state = TensorState.FREE
            self._payload = None
        else:
            self._state = state
            self._payload = tensor

    def data_ptr(self):
        """Return the payload's ``data_ptr()``, or ``None`` when there is no payload."""
        if self._payload is None:
            return None
        return self._payload.data_ptr()

    @property
    def state(self) -> TensorState:
        """The current state label."""
        return self._state

    def set_null(self) -> None:
        """Drop the payload and mark the tensor as ``TensorState.FREE``."""
        self._state = TensorState.FREE
        self._payload = None

    def is_null(self) -> bool:
        """Return ``True`` when the state is ``TensorState.FREE``."""
        if self._state == TensorState.FREE:
            # Sanity check of the FREE <=> no-payload invariant.
            assert self._payload is None
            return True
        return False

    def trans_state(self, state: TensorState) -> None:
        """Switch to ``state``; moving to ``TensorState.FREE`` also drops the payload."""
        self._state = state
        if state == TensorState.FREE:
            self._payload = None

    @property
    def payload(self) -> Optional[torch.Tensor]:
        """The wrapped tensor, or ``None`` when no payload is held."""
        return self._payload

    def copy_payload(self, tensor) -> None:
        """Copy the flattened contents of ``tensor`` into the existing payload in place."""
        self._payload.view(-1).copy_(tensor.view(-1))

    def reset_payload(self, tensor) -> None:
        """Replace the payload with ``tensor`` and set the state to ``TensorState.HOLD``."""
        self._payload = tensor
        self.trans_state(TensorState.HOLD)

    @property
    def device(self) -> torch.device:
        """Device of the payload (requires a payload to be present)."""
        return self._payload.device

    @property
    def dtype(self) -> torch.dtype:
        """Dtype of the payload.

        Subclasses that record ``_origin_dtype`` (this base class does not set
        it) additionally get the payload dtype checked against that value.
        """
        origin_dtype = getattr(self, '_origin_dtype', None)
        if origin_dtype is None:
            return self._payload.dtype
        assert self._payload.dtype == origin_dtype
        return origin_dtype

    def to(self, device: torch.device):
        """Unsupported: always raises ``RuntimeError``."""
        raise RuntimeError(f"Use colo_model_tensor_move instead of calling .to() on {type(self).__name__}")

    def to_(self, device: torch.device):
        """Unsupported: always raises ``RuntimeError``."""
        raise RuntimeError(f"Use colo_model_tensor_move instead of calling .to_() on {type(self).__name__}")

    @property
    def shape(self):
        """Shape of the payload (requires a payload to be present)."""
        return self._payload.shape
tests/test_zero_data_parallel/test_shard_param.py
View file @
214da761
...
...
@@ -12,6 +12,7 @@ from colossalai.zero.sharded_param import ShardedTensor
from
colossalai.zero.sharded_param.sharded_param
import
ShardedParamV2
from
colossalai.testing
import
rerun_on_exception
from
tests.test_zero_data_parallel.common
import
CONFIG
,
allclose
from
colossalai.zero.sharded_param.tensorful_state
import
StatefulTensor
@
parameterize
(
"shard_strategy_class"
,
[
TensorShardStrategy
,
BucketTensorShardStrategy
])
...
...
@@ -52,7 +53,7 @@ def _run_shard_param_v2(rank, world_size, port):
allclose
(
sparam
.
sharded_data_tensor
.
payload
,
param_ref
.
data
)
# Test get memory usage
sparam
.
fp32_grad
=
torch
.
randn
(
2
,
3
)
sparam
.
fp32_grad
=
StatefulTensor
(
torch
.
randn
(
2
,
3
)
)
cuda_mem_use
,
cpu_mem_use
=
sparam
.
get_memory_usage
()
assert
cpu_mem_use
==
2
*
3
*
4
*
2
,
f
"cpu_mem_use:
{
cpu_mem_use
}
"
...
...
@@ -62,13 +63,13 @@ def _run_shard_param_v2(rank, world_size, port):
# 4 is size of dummy tensor of param.data
assert
cpu_mem_use
==
2
*
3
*
4
*
2
+
4
sparam
.
fp16_grad
=
torch
.
randn
(
2
,
3
).
cuda
().
half
()
sparam
.
fp16_grad
=
StatefulTensor
(
torch
.
randn
(
2
,
3
).
cuda
().
half
()
)
cuda_mem_use
,
cpu_mem_use
=
sparam
.
get_memory_usage
()
assert
cpu_mem_use
==
2
*
3
*
4
*
2
+
4
assert
cuda_mem_use
==
2
*
3
*
2
sparam
.
fp16_grad
=
None
sparam
.
fp32_grad
=
torch
.
randn
(
2
,
3
)
sparam
.
fp16_grad
=
StatefulTensor
(
None
)
sparam
.
fp32_grad
=
StatefulTensor
(
torch
.
randn
(
2
,
3
)
)
sparam
.
remove_torch_payload
()
cuda_mem_use
,
cpu_mem_use
=
sparam
.
get_memory_usage
()
assert
cpu_mem_use
==
2
*
3
*
4
*
2
+
4
...
...
@@ -82,7 +83,7 @@ def _run_shard_param_v2(rank, world_size, port):
assert
cuda_mem_use
==
0
# reuse torch grad for sparam
sparam
.
fp32_grad
=
param
.
grad
sparam
.
fp32_grad
=
StatefulTensor
(
param
.
grad
)
cuda_mem_use
,
cpu_mem_use
=
sparam
.
get_memory_usage
()
assert
cpu_mem_use
==
2
*
3
*
4
*
2
assert
cuda_mem_use
==
0
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment