OpenDAS / ColossalAI · Commit 22c4b88d
"vscode:/vscode.git/clone" did not exist on "d3f5ce9efb35bf9e292aa041a3e98b737cbb68ee"
Unverified commit 22c4b88d, authored Apr 13, 2022 by HELSON, committed via GitHub on Apr 13, 2022.

[zero] refactor ShardedParamV2 for convenience (#742)

parent 340e59f9

Showing 16 changed files with 98 additions and 61 deletions (+98 / -61).
Files changed:

    colossalai/zero/init_ctx/init_context.py             +3   -3
    colossalai/zero/sharded_model/sharded_model_v2.py     +7   -7
    colossalai/zero/sharded_model/utils.py                +1   -1
    colossalai/zero/sharded_optim/sharded_optim_v2.py     +7   -8
    colossalai/zero/sharded_param/sharded_param.py        +42  -9
    colossalai/zero/sharded_param/sharded_tensor.py       +6   -0
    colossalai/zero/sharded_param/tensorful_state.py      +10  -11
    colossalai/zero/utils/zero_hook.py                    +4   -4
    tests/test_moe/test_moe_zero_init.py                  +3   -3
    tests/test_moe/test_moe_zero_model.py                 +1   -1
    tests/test_moe/test_moe_zero_optim.py                 +2   -2
    tests/test_zero/common.py                             +4   -4
    tests/test_zero/test_found_inf.py                     +1   -1
    tests/test_zero/test_init_context.py                  +2   -2
    tests/test_zero/test_shard_param.py                   +4   -4
    tests/test_zero/test_stateful_tensor_mgr.py           +1   -1
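Most of the changes below are mechanical renames on the ShardedParamV2 surface: remove_torch_payload() becomes set_data_none(), the rm_torch_payload constructor flag becomes set_data_none, and call sites read colo_attr.data_payload / colo_attr.grad_payload instead of reaching into colo_attr.sharded_data_tensor.payload / colo_attr.saved_grad.payload. A minimal sketch of the new surface (assuming ShardedParamV2 is re-exported from colossalai.zero.sharded_param; adjust the import to wherever the class lives in your checkout):

    import torch
    from colossalai.zero.sharded_param import ShardedParamV2

    param = torch.nn.Parameter(torch.randn(2, 3))
    colo_attr = ShardedParamV2(param, set_data_none=False)

    print(colo_attr.data_payload.shape)              # payload of the sharded data tensor, (2, 3) here
    colo_attr.reset_grad_payload(torch.zeros(2, 3))  # payload tensors must have requires_grad=False
    print(colo_attr.grad_payload.shape)              # payload of saved_grad
    colo_attr.set_data_none()                        # param.data now points at a cached empty tensor
    print(param.data.numel())                        # 0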
colossalai/zero/init_ctx/init_context.py

@@ -215,7 +215,7 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
            assert hasattr(param, 'colo_attr')
            if not param.colo_attr.param_is_sharded and param.colo_attr.is_replicated:
                dist.broadcast(tensor=param.data, src=src_rank, group=self.dp_process_group)
-               param.colo_attr.remove_torch_payload()
+               param.colo_attr.set_data_none()

        del self.param_list

@@ -252,11 +252,11 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
            if param.grad is not None:
                param.grad = param.grad.to(target_device)

-           param.colo_attr = ShardedParamV2(param, rm_torch_payload=False)
+           param.colo_attr = ShardedParamV2(param, set_data_none=False)

            if self.shard_param:
                self.shard_strategy.shard([param.colo_attr.sharded_data_tensor], self.dp_process_group)
-           param.data = param.colo_attr.sharded_data_tensor.payload    # set param.data to payload
+           param.data = param.colo_attr.data_payload    # set param.data to payload

            # mark whether the param is replicated
            param.colo_attr.is_replicated = self.is_replicated
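The two init-context hunks set the pattern for the rest of the commit: once a parameter is wrapped in ShardedParamV2, param.data is aliased to the shard through data_payload, and set_data_none() later detaches it so the original storage can be freed. A rough standalone illustration of that aliasing (plain tensors only; the shard strategy and process group are omitted, and `payload` here is a hypothetical stand-in for colo_attr.data_payload):

    import torch

    param = torch.nn.Parameter(torch.randn(4, 4))
    payload = param.data.detach()                       # what ShardedTensor keeps as its payload
    param.data = payload                                # "param.data = param.colo_attr.data_payload"
    param.data = torch.empty(0, dtype=payload.dtype)    # roughly what set_data_none() does
    assert param.data.numel() == 0 and payload.numel() == 16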
colossalai/zero/sharded_model/sharded_model_v2.py

@@ -260,7 +260,7 @@ class ShardedModelV2(nn.Module):
                if not p.colo_attr.param_is_sharded:
                    tensor_list.append(p.colo_attr.sharded_data_tensor)
                    p.colo_attr.sharded_data_tensor.trans_state(TensorState.HOLD_AFTER_BWD)
-                   p.colo_attr.remove_torch_payload()
+                   p.colo_attr.set_data_none()
            self.shard_strategy.shard(tensor_list, self.process_group)
        # 4. set all parameters' grad to None

@@ -357,8 +357,8 @@ class ShardedModelV2(nn.Module):
                assert param.colo_attr.saved_grad.is_null(), 'Gradien accumulation is not supported when reuse_fp16_shard=True'
-               param.colo_attr.saved_grad.reset_payload(grad)
-               param.colo_attr.sharded_data_tensor.reset_payload(grad)    # release the memory of param
+               param.colo_attr.reset_grad_payload(grad)
+               param.colo_attr.reset_grad_payload(grad)    # release the memory of param

                if param.colo_attr.is_replicated:
                    param.colo_attr.sharded_data_tensor.is_sharded = True

@@ -367,9 +367,9 @@ class ShardedModelV2(nn.Module):
                fp32_grad = cast_tensor_to_fp32(grad)
                if param.colo_attr.saved_grad.is_null():
-                   param.colo_attr.saved_grad.reset_payload(fp32_grad)
+                   param.colo_attr.reset_grad_payload(fp32_grad)
                else:
-                   param.colo_attr.saved_grad.payload.add_(fp32_grad.view_as(param.colo_attr.saved_grad.payload))
+                   param.colo_attr.grad_payload.add_(fp32_grad.view_as(param.colo_attr.grad_payload))

            # keep saved_grad in HOLD state
            param.colo_attr.saved_grad.trans_state(TensorState.HOLD)

@@ -377,11 +377,11 @@ class ShardedModelV2(nn.Module):
    def state_dict(self, destination=None, prefix='', keep_vars=False) -> 'OrderedDict[str, torch.Tensor]':
        self.shard_strategy.gather([p.colo_attr.sharded_data_tensor for p in self.sharded_params], self.process_group)
        for p in self.sharded_params:
-           p.data = p.colo_attr.sharded_data_tensor.payload
+           p.data = p.colo_attr.data_payload
        gathered_state_dict = self.module.state_dict(destination, prefix, keep_vars)
        self.shard_strategy.shard([p.colo_attr.sharded_data_tensor for p in self.sharded_params], self.process_group)
        for p in self.sharded_params:
-           p.colo_attr.remove_torch_payload()
+           p.colo_attr.set_data_none()
        return gathered_state_dict

    def load_state_dict(self, state_dict: 'OrderedDict[str, torch.Tensor]', strict: bool = True):
colossalai/zero/sharded_model/utils.py

@@ -14,6 +14,6 @@ def col_model_deepcopy(sharded_model: ShardedModelV2, other_model: torch.nn.Modu
        shard_flag = zero_param.colo_attr.sharded_data_tensor.is_sharded
        if shard_flag:
            sharded_model.shard_strategy.gather([zero_param.colo_attr.sharded_data_tensor])
-       param.data = copy.deepcopy(zero_param.colo_attr.sharded_data_tensor.payload)
+       param.data = copy.deepcopy(zero_param.colo_attr.data_payload)
        if shard_flag:
            sharded_model.shard_strategy.shard([zero_param.colo_attr.sharded_data_tensor])
colossalai/zero/sharded_optim/sharded_optim_v2.py

@@ -266,8 +266,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
                if shard_flag:
                    # we always shard replicated paramters
                    self.shard_strategy.shard([p.colo_attr.sharded_data_tensor], self.dp_process_group)
-               self.master_params[p] = StatefulTensor(cast_tensor_to_fp32(p.colo_attr.sharded_data_tensor.payload.to(self.device)))
+               self.master_params[p] = StatefulTensor(cast_tensor_to_fp32(p.colo_attr.data_payload.to(self.device)))
                if shard_flag:
                    # In this branch, there's no need to shard param
                    # So we gather here

@@ -296,10 +295,10 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
                # If we change p.grad directly
                # it may raise error because of different shape/dtype/device of p.data and p.grad
                # We just set p.data = p.colo_attr.saved_grad.payload here
-               p.data = p.colo_attr.saved_grad.payload
-               p.grad = p.colo_attr.saved_grad.payload
+               p.data = p.colo_attr.grad_payload
+               p.grad = p.colo_attr.grad_payload
                # Set p.data to empty tensor, in case of memory leaking
-               p.colo_attr.remove_torch_payload()
+               p.colo_attr.set_data_none()

    def _point_param_fp16_to_master_param(self):
        # assign master param pointers to p.data.

@@ -325,9 +324,9 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
                # TODO() optimize this line CPU (fp32) -> GPU (fp16)
                p.data = self.master_params[p].payload
-               p.colo_attr.sharded_data_tensor.reset_payload(colo_model_tensor_clone(p.half(), p.colo_attr.sharded_data_tensor.device))
-               p.colo_attr.remove_torch_payload()
+               p.colo_attr.reset_data_payload(colo_model_tensor_clone(p.half().detach(), p.colo_attr.sharded_data_tensor.device))
+               p.colo_attr.set_data_none()

                if p.colo_attr.keep_not_shard and p.colo_attr.is_replicated:
                    # We gather full fp16 param here
colossalai/zero/sharded_param/sharded_param.py

@@ -10,10 +10,20 @@ from typing import List
 # empty tensor is expected to raise error when get used
 FAKE_EMPTY_TENSOR = torch.BoolTensor([], device='cpu')

+EMPTY_TENSOR_DICT = {}
+
+
+def get_empty_tensor(device: torch.device, dtype: torch.dtype):
+    key = (device, dtype)
+    if key not in EMPTY_TENSOR_DICT:
+        EMPTY_TENSOR_DICT[key] = FAKE_EMPTY_TENSOR.to(device, dtype)
+    return EMPTY_TENSOR_DICT[key]
+

 class ShardedParamV2(object):

-    def __init__(self, param: torch.nn.Parameter, rm_torch_payload=False) -> None:
+    def __init__(self, param: torch.nn.Parameter, set_data_none: bool = False) -> None:
         self._sharded_data_tensor: ShardedTensor = ShardedTensor(param.data)
         self.saved_grad: StatefulTensor = StatefulTensor(None, TensorState.FREE)
         # This attribute must be initialized in ShardedModel

@@ -25,24 +35,47 @@ class ShardedParamV2(object):
         # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
         # So we can not empty the .data at this time
         self.param = param
-        if rm_torch_payload:
-            self.remove_torch_payload()
+        if set_data_none:
+            self.set_data_none()

     def get_payload_tensors(self) -> List[StatefulTensor]:
         """returns stateful tensors kept by this class.
         """
         return [self._sharded_data_tensor]

-    def remove_torch_payload(self):
-        self.param.data = FAKE_EMPTY_TENSOR.to(self._sharded_data_tensor.device, self._sharded_data_tensor.dtype)
+    def set_data_none(self):
+        self.param.data = get_empty_tensor(self.sharded_data_tensor.device, self.sharded_data_tensor.dtype)
+
+    def set_grad_none(self):
+        self.saved_grad.set_null()

     @property
     def sharded_data_tensor(self):
         return self._sharded_data_tensor

+    @property
+    def data_payload(self):
+        return self.sharded_data_tensor.payload
+
+    @property
+    def grad_payload(self):
+        assert not self.saved_grad.is_null()
+        return self.saved_grad.payload
+
     @property
     def param_is_sharded(self):
-        return self._sharded_data_tensor.is_sharded
+        return self.sharded_data_tensor.is_sharded
+
+    def reset_data_payload(self, tensor: torch.Tensor):
+        assert type(tensor) is torch.Tensor
+        assert tensor.requires_grad is False
+        self.sharded_data_tensor.reset_payload(tensor)
+        self.set_data_none()
+
+    def reset_grad_payload(self, tensor: torch.Tensor):
+        assert type(tensor) is torch.Tensor
+        assert tensor.requires_grad is False
+        self.saved_grad.reset_payload(tensor)

     def get_memory_usage(self) -> Tuple[int, int]:
         """

@@ -63,11 +96,11 @@ class ShardedParamV2(object):
             cpu_mem_use += t_cpu

         address_set = set()
-        _update_mem_use(self.sharded_data_tensor.payload)
-        address_set.add(self.sharded_data_tensor.payload.data_ptr())
+        _update_mem_use(self.data_payload)
+        address_set.add(self.data_payload.data_ptr())

         if not self.saved_grad.is_null() and self.saved_grad.data_ptr() not in address_set:
-            _update_mem_use(self.saved_grad.payload)
+            _update_mem_use(self.grad_payload)
             address_set.add(self.saved_grad.data_ptr())

         if self.param.data is not None and self.param.data.data_ptr() not in address_set:
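One behavioural detail in sharded_param.py worth calling out: set_data_none() no longer converts FAKE_EMPTY_TENSOR on every call; the new get_empty_tensor() caches a single zero-element placeholder per (device, dtype) pair. A self-contained restatement of the helper added above, plus a quick check of the caching:

    import torch

    FAKE_EMPTY_TENSOR = torch.BoolTensor([], device='cpu')
    EMPTY_TENSOR_DICT = {}

    def get_empty_tensor(device: torch.device, dtype: torch.dtype):
        key = (device, dtype)
        if key not in EMPTY_TENSOR_DICT:
            EMPTY_TENSOR_DICT[key] = FAKE_EMPTY_TENSOR.to(device, dtype)
        return EMPTY_TENSOR_DICT[key]

    a = get_empty_tensor(torch.device('cpu'), torch.half)
    b = get_empty_tensor(torch.device('cpu'), torch.half)
    assert a is b and a.numel() == 0    # one cached placeholder per (device, dtype)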
colossalai/zero/sharded_param/sharded_tensor.py

@@ -9,6 +9,7 @@ class ShardedTensor(StatefulTensor):
        r"""
        A tensor sharded in multiple processes. Constructed from an existing torch.Tensor instance.
        """
+       assert tensor.requires_grad is False
        super().__init__(tensor, state)

        # kept the shape, numel and dtype of the init tensor.

@@ -17,6 +18,11 @@ class ShardedTensor(StatefulTensor):
        self._origin_dtype = tensor.dtype
        self._is_sharded = False

+   @property
+   def dtype(self) -> torch.dtype:
+       assert self._payload.dtype == self._origin_dtype
+       return self._payload.dtype
+
    @property
    def origin_numel(self) -> int:
        return self._origin_numel
colossalai/zero/sharded_param/tensorful_state.py

@@ -19,11 +19,11 @@ class StatefulTensor(object):
    https://arxiv.org/abs/2108.05818
    """

-   def __init__(self, tensor: torch.Tensor, state: Optional[TensorState] = TensorState.HOLD) -> None:
+   def __init__(self, tensor: Optional[torch.Tensor], state: Optional[TensorState] = TensorState.HOLD) -> None:
        self._state = state
        self._payload = tensor
        if self._state == TensorState.FREE:
-           assert self._payload is None, f"payload has to None if {self._state}"
+           assert self._payload is None, f"payload has to None if state is {self._state}"

    def data_ptr(self):
        if self._payload is None:

@@ -50,13 +50,13 @@ class StatefulTensor(object):
        self._payload = None

    @property
-   def payload(self) -> int:
+   def payload(self) -> Optional[torch.Tensor]:
        return self._payload

-   def copy_payload(self, tensor) -> int:
+   def copy_payload(self, tensor) -> None:
        self._payload.view(-1).copy_(tensor.view(-1))

-   def reset_payload(self, tensor) -> int:
+   def reset_payload(self, tensor) -> None:
        del self._payload
        self._payload = tensor
        self.trans_state(TensorState.HOLD)

@@ -67,15 +67,14 @@ class StatefulTensor(object):
    @property
    def dtype(self) -> torch.dtype:
-       assert self._payload.dtype == self._origin_dtype
-       return self._origin_dtype
+       return self._payload.dtype
+
+   @property
+   def shape(self):
+       return self._payload.shape

    def to(self, device: torch.device):
        raise RuntimeError("Use colo_model_tensor_move install of call .to() on ShardedTensor")

    def to_(self, device: torch.device):
        raise RuntimeError("Use colo_model_tensor_move install of call .to_() on ShardedTensor")
-
-   @property
-   def shape(self):
-       return self._payload.shape
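The StatefulTensor changes are mostly annotations catching up with reality: the constructor takes an Optional[torch.Tensor] (a FREE tensor must have a None payload), payload returns Optional[torch.Tensor] rather than int, and copy_payload / reset_payload return None. A short lifecycle sketch (assuming the classes are importable from colossalai.zero.sharded_param.tensorful_state, the module shown above):

    import torch
    from colossalai.zero.sharded_param.tensorful_state import StatefulTensor, TensorState

    t = StatefulTensor(None, TensorState.FREE)   # FREE explicitly pairs with a None payload
    assert t.payload is None
    t.reset_payload(torch.zeros(2, 3))           # installs a payload and moves the state to HOLD
    t.copy_payload(torch.ones(2, 3))             # in-place copy into the existing payload
    assert bool(t.payload.all())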
colossalai/zero/utils/zero_hook.py

@@ -60,7 +60,7 @@ class ZeroHook(BaseOpHook):
            self._memstarts_collector.sample_memstats()

        for param in module.parameters(recurse=False):
-           param.data = param.colo_attr.sharded_data_tensor.payload
+           param.data = param.colo_attr.data_payload
            assert param.data.device.type == 'cuda', f"PRE FWD param.data must be on CUDA"

    def post_fwd_exec(self, module: torch.nn.Module, *args):

@@ -79,7 +79,7 @@ class ZeroHook(BaseOpHook):
        # remove torch payload
        for param in module.parameters(recurse=False):
-           param.colo_attr.remove_torch_payload()
+           param.colo_attr.set_data_none()

    def pre_bwd_exec(self, module: torch.nn.Module, input, output):

@@ -105,7 +105,7 @@ class ZeroHook(BaseOpHook):
            self._memstarts_collector.sample_memstats()

        for param in module.parameters(recurse=False):
-           param.data = param.colo_attr.sharded_data_tensor.payload
+           param.data = param.colo_attr.data_payload
            assert param.data.device.type == 'cuda', f"PRE BWD param.data must be on CUDA"

    def post_bwd_exec(self, module: torch.nn.Module, input):

@@ -124,7 +124,7 @@ class ZeroHook(BaseOpHook):
        # remove torch payload
        for param in module.parameters(recurse=False):
-           param.colo_attr.remove_torch_payload()
+           param.colo_attr.set_data_none()

    def pre_iter(self):
        pass
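All four zero_hook.py hunks apply the same idea: before an op runs, every parameter of the module is pointed at its gathered payload; after the op, set_data_none() drops that reference again. A miniature of the pattern on CPU (hand-written sketch, not the hook code itself; assumes ShardedParamV2 is importable from colossalai.zero.sharded_param):

    import torch
    from colossalai.zero.sharded_param import ShardedParamV2

    linear = torch.nn.Linear(4, 4)
    for p in linear.parameters():
        p.colo_attr = ShardedParamV2(p, set_data_none=True)   # params start detached

    for p in linear.parameters():          # pre_fwd_exec / pre_bwd_exec
        p.data = p.colo_attr.data_payload
    out = linear(torch.randn(2, 4))
    for p in linear.parameters():          # post_fwd_exec / post_bwd_exec
        p.colo_attr.set_data_none()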
tests/test_moe/test_moe_zero_init.py

@@ -77,10 +77,10 @@ def run_moe_zero_init(init_device_type, shard_strategy_class):
            assert param.colo_attr.is_replicated

        if param.colo_attr.param_is_sharded:
-           assert param.colo_attr.sharded_data_tensor.payload.device.type == init_device.type, \
-               f'{param.colo_attr.sharded_data_tensor.payload.device.type} vs. {init_device.type}'
+           assert param.colo_attr.data_payload.device.type == init_device.type, \
+               f'{param.colo_attr.data_payload.device.type} vs. {init_device.type}'
        else:
-           assert param.colo_attr.sharded_data_tensor.payload.device.type == 'cuda'
+           assert param.colo_attr.data_payload.device.type == 'cuda'


def _run_dist(rank, world_size, port):
tests/test_moe/test_moe_zero_model.py

@@ -37,7 +37,7 @@ def run_model_test(enable_autocast, shard_strategy_class):
    # check whether parameters are identical in ddp
    for name, p in zero_model.named_parameters():
        if not p.colo_attr.param_is_sharded and p.colo_attr.is_replicated:
-           assert_equal_in_group(p.colo_attr.sharded_data_tensor.payload)
+           assert_equal_in_group(p.colo_attr.data_payload)

    model = MoeModel(checkpoint=True).half()
    col_model_deepcopy(zero_model, model)
tests/test_moe/test_moe_zero_optim.py

@@ -76,7 +76,7 @@ def _run_test_sharded_optim_v2(cpu_offload,
    # check whether parameters are identical in ddp
    for name, p in zero_model.named_parameters():
        if not p.colo_attr.param_is_sharded and p.colo_attr.is_replicated:
-           assert_equal_in_group(p.colo_attr.sharded_data_tensor.payload.to(get_current_device()))
+           assert_equal_in_group(p.colo_attr.data_payload.to(get_current_device()))

    model = MoeModel(checkpoint=True).half()
    col_model_deepcopy(zero_model, model)

@@ -100,7 +100,7 @@ def _run_test_sharded_optim_v2(cpu_offload,
    for (n, p), zp in zip(apex_model.named_parameters(), zero_model.parameters()):
        if 'gate' in n:
            p.data = p.float()
-           p.data.copy_(zp.colo_attr.sharded_data_tensor.payload)
+           p.data.copy_(zp.colo_attr.data_payload)

    for i, (data, label) in enumerate(train_dataloader):
        if i > 5:
tests/test_zero/common.py

@@ -94,7 +94,7 @@ def check_grads_padding(model, zero_model, loose=False):
    for (name, p), (zero_name, zero_p) in zip(model.named_parameters(), zero_model.named_parameters()):
        # zero_grad = zero_p.grad.clone().to(p.device)
        if zero_p.colo_attr.is_replicated:
-           zero_grad = zero_p.colo_attr.saved_grad.payload.clone().to(p.device)
+           zero_grad = zero_p.colo_attr.grad_payload.clone().to(p.device)
            chunks = torch.flatten(p.grad).chunk(dist.get_world_size())
            if rank >= len(chunks):
                continue

@@ -102,7 +102,7 @@ def check_grads_padding(model, zero_model, loose=False):
            if zero_grad.size(0) > grad.size(0):
                zero_grad = zero_grad[:grad.size(0)]
        else:
-           zero_grad = zero_p.colo_attr.saved_grad.payload
+           zero_grad = zero_p.colo_attr.grad_payload

        grad = p.grad.to(zero_grad.dtype)
        assert grad.dtype == zero_grad.dtype

@@ -127,7 +127,7 @@ def check_sharded_model_params(model, zero_model, loose=False, reuse_fp16_shard=
    rank = dist.get_rank()
    for (name, p), (zero_name, zero_p) in zip(model.named_parameters(), zero_model.named_parameters()):
        if zero_p.colo_attr.param_is_sharded:
-           zero_p = zero_p.colo_attr.sharded_data_tensor.payload.to(p.device).float()
+           zero_p = zero_p.colo_attr.data_payload.to(p.device).float()
            chunks = torch.flatten(p).chunk(dist.get_world_size())
            if rank >= len(chunks):
                continue

@@ -135,7 +135,7 @@ def check_sharded_model_params(model, zero_model, loose=False, reuse_fp16_shard=
            if zero_p.size(0) > p.size(0):
                zero_p = zero_p[:p.size(0)]
        else:
-           zero_p = zero_p.colo_attr.sharded_data_tensor.payload.to(p.device)
+           zero_p = zero_p.colo_attr.data_payload.to(p.device)

        assert p.dtype == zero_p.dtype
        assert allclose(p, zero_p, loose=loose), f'{p} vs {zero_p}'
tests/test_zero/test_found_inf.py

@@ -55,7 +55,7 @@ def _run_test_found_inf(cpu_offload, shard_strategy_class, gpu_margin_mem_ratio)
        data, label = data.cuda(), label.cuda()
        _run_step(zero_model, sharded_optim, data, label, criterion, False)
        for param in zero_model.parameters():
-           assert not has_inf_or_nan(param.colo_attr.sharded_data_tensor.payload)
+           assert not has_inf_or_nan(param.colo_attr.data_payload)


def _run_dist(rank, world_size, port):
tests/test_zero/test_init_context.py

@@ -46,8 +46,8 @@ def run_model_test(init_device_type, shard_strategy_class):
            assert hasattr(param, 'colo_attr')
            assert param.colo_attr.sharded_data_tensor.dtype == torch.half
            assert param.colo_attr.sharded_data_tensor.is_sharded
-           assert param.colo_attr.sharded_data_tensor.payload.device.type == init_device.type, \
-               f'{param.colo_attr.sharded_data_tensor.payload.device.type} vs. {init_device.type}'
+           assert param.colo_attr.data_payload.device.type == init_device.type, \
+               f'{param.colo_attr.data_payload.device.type} vs. {init_device.type}'

        cuda_mem_use, _ = colo_model_mem_usage(model)
        model_data_cuda_mem_MB = cuda_mem_use / 1e6
tests/test_zero/test_shard_param.py

@@ -50,27 +50,27 @@ def _run_shard_param_v2(rank, world_size, port):
    param_ref = deepcopy(param)
    sparam = ShardedParamV2(param=param)

-   allclose(sparam.sharded_data_tensor.payload, param_ref.data)
+   allclose(sparam.data_payload, param_ref.data)

    # Test get memory usage
    sparam.saved_grad = StatefulTensor(torch.randn(2, 3))
    cuda_mem_use, cpu_mem_use = sparam.get_memory_usage()
    assert cpu_mem_use == 2 * 3 * 4 * 2, f"cpu_mem_use: {cpu_mem_use}"

-   sparam.remove_torch_payload()
+   sparam.set_data_none()
    assert (param.data.numel() == 0)
    cuda_mem_use, cpu_mem_use = sparam.get_memory_usage()
    # 4 is size of dummy tensor of param.data
    assert cpu_mem_use == 2 * 3 * 4 * 2

    sparam.saved_grad = StatefulTensor(torch.randn(2, 3))
-   sparam.remove_torch_payload()
+   sparam.set_data_none()
    cuda_mem_use, cpu_mem_use = sparam.get_memory_usage()
    assert cpu_mem_use == 2 * 3 * 4 * 2
    assert cuda_mem_use == 0

    # append a grad to torch param
-   param.data = sparam.sharded_data_tensor.payload
+   param.data = sparam.data_payload
    param.grad = torch.randn(2, 3)
    cuda_mem_use, cpu_mem_use = sparam.get_memory_usage()
    assert cpu_mem_use == 2 * 3 * 4 * 2 + 2 * 3 * 4, f"cpu_mem_use {cpu_mem_use}"
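The byte counts asserted in test_shard_param.py follow directly from the shapes: a 2 x 3 fp32 tensor occupies 2 * 3 * 4 = 24 bytes, get_memory_usage() counts both the data payload and the saved grad (48 bytes), and attaching a torch-level param.grad adds another 24 bytes; the data payload itself is not double-counted once param.data aliases it. As a quick sanity check of the arithmetic:

    # 2 x 3 fp32 payload                  -> 24 bytes
    # payload + saved_grad                -> 48 bytes (the 2 * 3 * 4 * 2 above)
    # payload + saved_grad + param.grad   -> 72 bytes (2 * 3 * 4 * 2 + 2 * 3 * 4)
    assert 2 * 3 * 4 == 24
    assert 2 * 3 * 4 * 2 == 48
    assert 2 * 3 * 4 * 2 + 2 * 3 * 4 == 72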
tests/test_zero/test_stateful_tensor_mgr.py

@@ -34,7 +34,7 @@ def run_stm():
    colo_set_process_memory_fraction(fraction)
    model = Net()
    for p in model.parameters():
-       p.colo_attr = ShardedParamV2(p, rm_torch_payload=True)
+       p.colo_attr = ShardedParamV2(p, set_data_none=True)
    GLOBAL_MODEL_DATA_TRACER.register_model(model)
    mem_collector = MemStatsCollector()
    stateful_tensor_mgr = StatefulTensorMgr(mem_collector)