OpenDAS / ColossalAI · Commits

Commit ee112fe1 (unverified), authored Apr 08, 2022 by HELSON, committed via GitHub on Apr 08, 2022

[zero] adapt zero hooks for unsharded module (#699)

Parent: 896ade15
Showing 12 changed files with 70 additions and 58 deletions (+70 -58).
colossalai/engine/ophooks/zero_hook.py                  +39 -20
colossalai/zero/init_ctx/init_context.py                 +9 -15
colossalai/zero/sharded_model/sharded_model_v2.py        +2  -4
colossalai/zero/sharded_optim/sharded_optim_v2.py        +3  -5
colossalai/zero/sharded_param/sharded_param.py           +6  -1
tests/test_moe/test_moe_zero_init.py                     +0  -1
tests/test_moe/test_moe_zero_model.py                    +1  -1
tests/test_moe/test_moe_zero_optim.py                    +2  -2
tests/test_zero_data_parallel/common.py                  +3  -4
tests/test_zero_data_parallel/test_shard_model_v2.py     +1  -1
tests/test_zero_data_parallel/test_shard_param.py        +3  -3
tests/test_zero_data_parallel/test_state_dict.py         +1  -1
colossalai/engine/ophooks/zero_hook.py

@@ -36,6 +36,7 @@ class ZeroHook(BaseOpHook):
        self._stateful_tensor_mgr = stateful_tensor_mgr

    def pre_fwd_exec(self, module: torch.nn.Module, *args):
        for param in module.parameters(recurse=False):
            param.colo_attr.sharded_data_tensor.trans_state(TensorState.COMPUTE)

@@ -45,12 +46,15 @@ class ZeroHook(BaseOpHook):
        for param in module.parameters(recurse=False):
            colo_model_data_tensor_move_inline(param.colo_attr.sharded_data_tensor, self.computing_device)

        # gather sharded parameters
        if module.param_is_sharded:
            tensor_list = []
            for param in module.parameters(recurse=False):
                assert hasattr(param, 'colo_attr')
                tensor_list.append(param.colo_attr.sharded_data_tensor)
            self.shard_strategy.gather(tensor_list, self.process_group)

        # record memory statistics
        if self._memstarts_collector:
            self._memstarts_collector.sample_memstats()

@@ -59,18 +63,25 @@ class ZeroHook(BaseOpHook):
            assert param.data.device.type == 'cuda', f"PRE FWD param.data must be on CUDA"

    def post_fwd_exec(self, module: torch.nn.Module, *args):
        # change tensor state to HOLD_AFTER_FWD
        for param in module.parameters(recurse=False):
            param.colo_attr.sharded_data_tensor.trans_state(TensorState.HOLD_AFTER_FWD)

        # shard gathered parameters
        if module.param_is_sharded:
            tensor_list = []
            for param in module.parameters(recurse=False):
                assert hasattr(param, 'colo_attr')
                tensor_list.append(param.colo_attr.sharded_data_tensor)
            self.shard_strategy.shard(tensor_list, self.process_group)

        # remove torch payload
        for param in module.parameters(recurse=False):
            param.colo_attr.remove_torch_payload()

    def pre_bwd_exec(self, module: torch.nn.Module, input, output):
        for param in module.parameters(recurse=False):
            param.colo_attr.sharded_data_tensor.trans_state(TensorState.COMPUTE)

@@ -80,12 +91,15 @@ class ZeroHook(BaseOpHook):
        for param in module.parameters(recurse=False):
            colo_model_data_tensor_move_inline(param.colo_attr.sharded_data_tensor, self.computing_device)

        # gather sharded parameters
        if module.param_is_sharded:
            tensor_list = []
            for param in module.parameters(recurse=False):
                assert hasattr(param, 'colo_attr')
                tensor_list.append(param.colo_attr.sharded_data_tensor)
            self.shard_strategy.gather(tensor_list, self.process_group)

        # record memory statistics
        if self._memstarts_collector:
            self._memstarts_collector.sample_memstats()

@@ -94,15 +108,20 @@ class ZeroHook(BaseOpHook):
            assert param.data.device.type == 'cuda', f"PRE BWD param.data must be on CUDA"

    def post_bwd_exec(self, module: torch.nn.Module, input):
        # change tensor state to HOLD_AFTER_BWD
        for param in module.parameters(recurse=False):
            param.colo_attr.sharded_data_tensor.trans_state(TensorState.HOLD_AFTER_BWD)

        # shard gathered parameters
        if module.param_is_sharded:
            tensor_list = []
            for param in module.parameters(recurse=False):
                assert hasattr(param, 'colo_attr')
                tensor_list.append(param.colo_attr.sharded_data_tensor)
            self.shard_strategy.shard(tensor_list, self.process_group)

        # remove torch payload
        for param in module.parameters(recurse=False):
            param.colo_attr.remove_torch_payload()
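The recurring change in this hook is that gather/shard is now guarded by module.param_is_sharded, so the same hook can also wrap modules whose parameters were never sharded. Below is a minimal, self-contained sketch of that lifecycle; the names ShardStub, pre_fwd and post_fwd are illustrative stand-ins, not ColossalAI APIs.

import torch

class ShardStub:
    """Stand-in for a sharded data tensor: holds either a shard or the full tensor."""

    def __init__(self, full: torch.Tensor, sharded: bool):
        self.is_sharded = sharded
        self._full = full
        self.payload = full.chunk(2)[0].clone() if sharded else full.clone()

    def gather(self):
        # pretend the all-gather across ranks has completed
        self.payload = self._full.clone()

    def shard(self):
        self.payload = self._full.chunk(2)[0].clone()

def pre_fwd(stubs):
    # gather only the sharded parameters; unsharded ones are already whole
    for s in stubs:
        if s.is_sharded:
            s.gather()

def post_fwd(stubs):
    # re-shard what was gathered so the full copy is not kept around
    for s in stubs:
        if s.is_sharded:
            s.shard()

stubs = [ShardStub(torch.randn(8), sharded=True), ShardStub(torch.randn(8), sharded=False)]
pre_fwd(stubs)
print([s.payload.numel() for s in stubs])   # [8, 8]: both full for compute
post_fwd(stubs)
print([s.payload.numel() for s in stubs])   # [4, 8]: only the sharded one shrinks back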
colossalai/zero/init_ctx/init_context.py

@@ -135,8 +135,7 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
        super().__init__()
        self.shard_strategy = shard_strategy
-       self.sharded_param_list = []
-       self.unshard_param_list = []
+       self.param_list = []
        self.model_numel_tensor = model_numel_tensor
        self.seed = seed
        self.dp_process_group = gpc.get_group(ParallelMode.DATA)

@@ -210,19 +209,15 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
    def _post_context_exec(self):
        """The callback function when exiting context.
        """
-       for param in self.sharded_param_list:
-           assert hasattr(param, 'colo_attr')
-           param.colo_attr.remove_torch_payload()
-
-       del self.sharded_param_list
-
        # broadcast replicated no-shard parameters
        src_rank = gpc.get_ranks_in_group(ParallelMode.DATA)[0]
-       for param in self.unshard_param_list:
+       for param in self.param_list:
            assert hasattr(param, 'colo_attr')
-           if param.is_replicated:
+           if not param.colo_attr.param_is_sharded and param.is_replicated:
                dist.broadcast(tensor=param.data, src=src_rank, group=self.dp_process_group)
+           param.colo_attr.remove_torch_payload()

-       del self.unshard_param_list
+       del self.param_list

        nn.init._calculate_fan_in_and_fan_out = self.nn_fanin_fanout
        torch.set_rng_state(self.cpu_rng_state)

@@ -264,10 +259,9 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
            if self.shard_param:
                self.shard_strategy.shard([param.colo_attr.sharded_data_tensor], self.dp_process_group)
-               param.data = param.colo_attr.sharded_data_tensor.payload
-               self.sharded_param_list.append(param)
-           else:
-               self.unshard_param_list.append(param)
+           param.data = param.colo_attr.sharded_data_tensor.payload    # set param.data to payload
+           self.param_list.append(param)

            # We must cast buffers
            # If we use BN, buffers may be on CPU and Float
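The rewritten _post_context_exec walks a single param_list and broadcasts only parameters that are still unsharded and replicated, then drops their torch payload. A minimal runnable sketch of the broadcast step, assuming an initialized torch.distributed process group (the helper name broadcast_unsharded_replicated is illustrative, not ColossalAI's API):

import os
import torch
import torch.distributed as dist

def broadcast_unsharded_replicated(params, src_rank, group):
    # only parameters that were never sharded and are replicated across
    # data-parallel ranks need to be synchronized from the source rank
    for p in params:
        if not getattr(p, 'is_sharded', False) and getattr(p, 'is_replicated', True):
            dist.broadcast(tensor=p.data, src=src_rank, group=group)

if __name__ == '__main__':
    # single-process gloo group so the sketch runs without a launcher
    os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
    os.environ.setdefault('MASTER_PORT', '29511')
    dist.init_process_group('gloo', rank=0, world_size=1)
    w = torch.nn.Parameter(torch.randn(4))
    broadcast_unsharded_replicated([w], src_rank=0, group=dist.group.WORLD)
    print(w.data)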
colossalai/zero/sharded_model/sharded_model_v2.py

@@ -121,7 +121,7 @@ class ShardedModelV2(nn.Module):
        self._ophook_list = [
            ZeroHook(self.shard_strategy, self._memstats_collector, self._stateful_tensor_mgr, self.process_group)
        ]
-       register_ophooks_recursively(self.module, self._ophook_list, filter_fn=lambda m: not m.param_is_sharded)
+       register_ophooks_recursively(self.module, self._ophook_list)
        self.param_hook_mgr = BaseParamHookMgr(self.sharded_params)
        self.param_hook_mgr.register_backward_hooks(self._grad_post_backward_hook)

@@ -366,14 +366,12 @@ class ShardedModelV2(nn.Module):
    def state_dict(self, destination=None, prefix='', keep_vars=False) -> 'OrderedDict[str, torch.Tensor]':
        self.shard_strategy.gather([p.colo_attr.sharded_data_tensor for p in self.sharded_params], self.process_group)
-       prev_params = {}
        for p in self.sharded_params:
-           prev_params[p] = p.data
            p.data = p.colo_attr.sharded_data_tensor.payload
        gathered_state_dict = self.module.state_dict(destination, prefix, keep_vars)
        self.shard_strategy.shard([p.colo_attr.sharded_data_tensor for p in self.sharded_params], self.process_group)
        for p in self.sharded_params:
-           p.data = prev_params[p]
+           p.colo_attr.remove_torch_payload()
        return gathered_state_dict

    def load_state_dict(self, state_dict: 'OrderedDict[str, torch.Tensor]', strict: bool = True):
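The state_dict() change replaces the prev_params bookkeeping with simply dropping the torch payload after re-sharding. A simplified, runnable sketch of that flow (no real sharding; gathered_state_dict_sketch and the payloads dict are illustrative, not the ColossalAI API):

import torch
import torch.nn as nn

def gathered_state_dict_sketch(module: nn.Module, payloads: dict):
    # point each param's .data at its gathered full tensor for the copy
    for p in module.parameters():
        p.data = payloads[p]
    sd = module.state_dict()
    # afterwards the temporary full copy is dropped rather than restored
    for p in module.parameters():
        p.data = torch.empty(0, dtype=p.dtype)
    return sd

m = nn.Linear(2, 2, bias=False)
full = {p: p.data.clone() for p in m.parameters()}           # stand-in for gathered payloads
print(gathered_state_dict_sketch(m, full)['weight'].shape)   # torch.Size([2, 2])
print(next(m.parameters()).data.numel())                     # 0: payload removed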
colossalai/zero/sharded_optim/sharded_optim_v2.py

@@ -268,10 +268,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
                p.data = self.master_params[p].payload
                p.colo_attr.sharded_data_tensor.reset_payload(colo_model_tensor_clone(p.half(), torch.cuda.current_device()))
-               if not p.colo_attr.param_is_sharded:
-                   # FIXME(hhc): add hook for unsharded parameters
-                   p.data = p.colo_attr.sharded_data_tensor.payload
                p.colo_attr.remove_torch_payload()

    def sync_grad(self):
        pass

@@ -351,10 +348,11 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
            # TODO() optimize this line CPU (fp32) -> GPU (fp16)
            p.colo_attr.sharded_data_tensor.reset_payload(colo_model_tensor_clone(p.half(), p.colo_attr.sharded_data_tensor.device))
            p.colo_attr.remove_torch_payload()

            if not is_param_sharded and not self.keep_unshard:
                # We gather full fp16 param here
                self.shard_strategy.gather([p.colo_attr.sharded_data_tensor], self.dp_process_group)
                p.data = p.colo_attr.sharded_data_tensor.payload

            self.master_params[p].trans_state(TensorState.HOLD)
            p.colo_attr.saved_grad.set_null()
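Both optimizer hunks converge on the same post-step pattern: copy the stepped fp32 master weight back into the fp16 sharded payload, then drop the redundant torch payload. A minimal sketch of that write-back (write_back_fp16 is an illustrative helper, not ColossalAI's API):

import torch

def write_back_fp16(master_params, fp16_payloads):
    # after optimizer.step() on the fp32 masters, refresh the fp16 copies the model
    # computes with (the spirit of reset_payload(colo_model_tensor_clone(p.half(), ...)))
    for master, payload in zip(master_params, fp16_payloads):
        payload.copy_(master.detach().to(payload.device, payload.dtype))

master = [torch.randn(4, dtype=torch.float32, requires_grad=True)]
fp16 = [torch.zeros(4, dtype=torch.float16)]
# ... an optimizer step would update `master` here ...
write_back_fp16(master, fp16)
print(fp16[0].dtype, torch.allclose(fp16[0].float(), master[0], atol=1e-3))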
colossalai/zero/sharded_param/sharded_param.py

@@ -5,6 +5,11 @@ from colossalai.zero.shard_utils.tensor_utils import colo_tensor_mem_usage
from .tensorful_state import StatefulTensor, TensorState
from typing import List

+# use this tensor as empty data point for parameters
+# we do not want users use param.data when its torch payload is removed
+# empty tensor is expected to raise error when get used
+FAKE_EMPTY_TENSOR = torch.BoolTensor([], device='cpu')

class ShardedParamV2(object):

@@ -29,7 +34,7 @@ class ShardedParamV2(object):
        return [self._sharded_data_tensor]

    def remove_torch_payload(self):
-       self.param.data = torch.empty([], dtype=self.param.dtype, device=self.param.device)
+       self.param.data = FAKE_EMPTY_TENSOR.to(self._sharded_data_tensor.device, self._sharded_data_tensor.dtype)

    @property
    def sharded_data_tensor(self):
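remove_torch_payload() now swaps param.data for a shared zero-element tensor instead of a one-element torch.empty([]) scalar, which is why the test assertions below change from numel() == 1 to numel() == 0 and drop the extra 4 bytes of dummy storage. A quick self-contained illustration of that swap (illustrative only):

import torch

# mirrors the FAKE_EMPTY_TENSOR idea from the diff above
FAKE_EMPTY_TENSOR = torch.BoolTensor([], device='cpu')

param = torch.nn.Parameter(torch.randn(2, 3))
payload = param.data                     # pretend this is kept as the sharded payload
param.data = FAKE_EMPTY_TENSOR.to(payload.device, payload.dtype)

print(param.data.numel())                                # 0, matching the updated test
print(param.data.numel() * param.data.element_size())    # 0 bytes of dummy storage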
tests/test_moe/test_moe_zero_init.py

@@ -66,7 +66,6 @@ def run_moe_zero_init(init_device_type, shard_strategy_class):
        # the parameters in moe experts and its gate should not be sharded
        if ('experts' in name) or ('gate' in name) or ('residual_combine' in name):
            assert not param.colo_attr.sharded_data_tensor.is_sharded
-           assert param.colo_attr.sharded_data_tensor.data_ptr() == param.data.data_ptr()
        else:
            assert param.colo_attr.sharded_data_tensor.is_sharded
tests/test_moe/test_moe_zero_model.py

@@ -37,7 +37,7 @@ def run_model_test(enable_autocast, shard_strategy_class):
    # check whether parameters are identical in ddp
    for name, p in zero_model.named_parameters():
        if not p.colo_attr.param_is_sharded and p.is_replicated:
-           assert_equal_in_group(p.data)
+           assert_equal_in_group(p.colo_attr.sharded_data_tensor.payload)

    model = MoeModel().half()
    col_model_deepcopy(zero_model, model)
tests/test_moe/test_moe_zero_optim.py

@@ -74,7 +74,7 @@ def _run_test_sharded_optim_v2(cpu_offload, shard_strategy_class, use_cpuadam, g
    # check whether parameters are identical in ddp
    for name, p in zero_model.named_parameters():
        if not p.colo_attr.param_is_sharded and p.is_replicated:
-           assert_equal_in_group(p.data.to(get_current_device()))
+           assert_equal_in_group(p.colo_attr.sharded_data_tensor.payload.to(get_current_device()))

    model = MoeModel().half()
    col_model_deepcopy(zero_model, model)

@@ -99,7 +99,7 @@ def _run_test_sharded_optim_v2(cpu_offload, shard_strategy_class, use_cpuadam, g
    for (n, p), zp in zip(apex_model.named_parameters(), zero_model.parameters()):
        if 'gate' in n:
            p.data = p.float()
-       p.data.copy_(zp.data)
+       p.data.copy_(zp.colo_attr.sharded_data_tensor.payload)

    for i, (data, label) in enumerate(train_dataloader):
        if i > 5:
tests/test_zero_data_parallel/common.py

@@ -126,9 +126,6 @@ def check_sharded_model_params(model, zero_model, loose=False, reuse_fp16_shard=
    rank = dist.get_rank()
    for (name, p), (zero_name, zero_p) in zip(model.named_parameters(), zero_model.named_parameters()):
        if zero_p.colo_attr.param_is_sharded:
            if reuse_fp16_shard:
                zero_p = zero_p.data.to(p.device).float()
            else:
                zero_p = zero_p.colo_attr.sharded_data_tensor.payload.to(p.device).float()
            chunks = torch.flatten(p).chunk(dist.get_world_size())
            if rank >= len(chunks):

@@ -136,6 +133,8 @@ def check_sharded_model_params(model, zero_model, loose=False, reuse_fp16_shard=
            p = chunks[rank].float()
            if zero_p.size(0) > p.size(0):
                zero_p = zero_p[:p.size(0)]
+       else:
+           zero_p = zero_p.colo_attr.sharded_data_tensor.payload
        assert p.dtype == zero_p.dtype
        assert allclose(p, zero_p, loose=loose), f'{p} vs {zero_p}'
tests/test_zero_data_parallel/test_shard_model_v2.py

@@ -21,7 +21,7 @@ from common import CONFIG, check_grads_padding, run_fwd_bwd
@parameterize("enable_autocast", [True])
-@parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy])
+@parameterize("shard_strategy_class", [BucketTensorShardStrategy])
def run_model_test(enable_autocast, shard_strategy_class):
    test_models = ['repeated_computed_layers', 'resnet18', 'bert', 'no_leaf_module']
    shard_strategy = shard_strategy_class()
tests/test_zero_data_parallel/test_shard_param.py

@@ -58,15 +58,15 @@ def _run_shard_param_v2(rank, world_size, port):
    assert cpu_mem_use == 2 * 3 * 4 * 2, f"cpu_mem_use: {cpu_mem_use}"

    sparam.remove_torch_payload()
-   assert (param.data.numel() == 1)
+   assert (param.data.numel() == 0)
    cuda_mem_use, cpu_mem_use = sparam.get_memory_usage()
    # 4 is size of dummy tensor of param.data
-   assert cpu_mem_use == 2 * 3 * 4 * 2 + 4
+   assert cpu_mem_use == 2 * 3 * 4 * 2

    sparam.saved_grad = StatefulTensor(torch.randn(2, 3))
    sparam.remove_torch_payload()
    cuda_mem_use, cpu_mem_use = sparam.get_memory_usage()
-   assert cpu_mem_use == 2 * 3 * 4 * 2 + 4
+   assert cpu_mem_use == 2 * 3 * 4 * 2
    assert cuda_mem_use == 0

    # append a grad to torch param
tests/test_zero_data_parallel/test_state_dict.py

@@ -56,4 +56,4 @@ def test_zero_state_dict(world_size):
if __name__ == '__main__':
-   test_zero_state_dict(2, TensorShardStrategy)
+   test_zero_state_dict(2)