OpenDAS / ColossalAI / Commits

Commit 7c6c427d (unverified), authored Mar 31, 2022 by ver217, committed by GitHub on Mar 31, 2022

[zero] trace states of fp16/32 grad and fp32 param (#571)

parent 7675366f

Showing 5 changed files with 69 additions and 73 deletions (+69 -73)
colossalai/utils/memory_utils/utils.py              +3   -3
colossalai/zero/sharded_model/sharded_model_v2.py   +14  -23
colossalai/zero/sharded_optim/sharded_optim_v2.py   +52  -36
colossalai/zero/sharded_param/sharded_param.py      +0   -5
tests/test_zero_data_parallel/test_shard_param.py   +0   -6
colossalai/utils/memory_utils/utils.py (+3 -3)

```diff
@@ -51,9 +51,9 @@ def colo_model_data_tensor_move(src_t: Union[StatefulTensor, torch.Tensor], tgt_
     """
     A colossal API for model data tensor move.
     The src and target tensors could be resident on both CPU and GPU.

    NOTE() The source tensor payload will be removed after this function.

    The function will record the communication volume between CPU and GPU.
    Args:
        t_src (Union[StatefulTensor, torch.Tensor]): source tensor
@@ -93,7 +93,7 @@ def colo_model_data_tensor_move_inline(t: Union[StatefulTensor, torch.Tensor], t
         raise TypeError('colo_model_data_move_to_cpu dose not accept type {type(t)}')
     if isinstance(target_device, int):
-        target_device = torch.cuda(f'device"{target_device}')
+        target_device = torch.device(f'cuda:{target_device}')
     # deal with torch.device('cpu') and torch.device('cpu:0)
     if t_payload.device.type == target_device.type:
```
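The one-line fix is worth unpacking: `torch.cuda` is a module, not a callable, and the old f-string also had mismatched quotes, so any integer `target_device` raised a `TypeError` before the tensor could move. Note that the `raise TypeError(...)` line above it keeps its own latent bug, since the message is not an f-string, `{type(t)}` prints literally ("dose" is a typo the commit also leaves in place). A minimal sketch of the corrected normalization; `normalize_device` is our name, not a ColossalAI API:

```python
import torch


def normalize_device(target_device):
    # The old code called `torch.cuda(...)`, but torch.cuda is a module, not a
    # callable, so an int device index always raised TypeError at this line.
    if isinstance(target_device, int):
        target_device = torch.device(f'cuda:{target_device}')
    return target_device


# Constructing a CUDA device object does not require a GPU to be present.
assert normalize_device(1) == torch.device('cuda:1')
assert normalize_device(torch.device('cpu')) == torch.device('cpu')
```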
colossalai/zero/sharded_model/sharded_model_v2.py (+14 -23)

```diff
@@ -18,7 +18,7 @@ from colossalai.utils.memory_tracer.model_data_memtracer import \
 from colossalai.utils.memory_utils.utils import (colo_cuda_memory_capacity, colo_model_data_move_to_cpu)
 from colossalai.zero.shard_utils import BaseShardStrategy
 from colossalai.zero.sharded_model.reduce_scatter import ReduceScatterBucketer
-from colossalai.zero.sharded_param.tensorful_state import (StatefulTensor, TensorState)
+from colossalai.zero.sharded_param.tensorful_state import TensorState
 from torch.distributed import ProcessGroup
 from torch.nn.parameter import Parameter
@@ -245,27 +245,7 @@ class ShardedModelV2(nn.Module):
             # We also allows to interleave no-sync pass with sync passes, if desired.
             if not self._require_backward_grad_sync:
                 continue
-            # Reduced grad is saved in `p.colo_attr.saved_grad`
-            # It can be on CPU or CUDA
-            # It can be fp16 or fp32
-            # We set `p.grad` to None here and ShardedOptimizer will prepare `p.grad` before `step()`.
-            if self.reuse_fp16_shard:
-                grad_fp16_payload = p.colo_attr.sharded_data_tensor.payload
-            else:
-                grad_fp16_payload = cast_tensor_to_fp32(p.colo_attr.fp16_grad.payload)
-            assert isinstance(grad_fp16_payload, torch.Tensor)
-            if p.colo_attr.offload_grad:
-                colo_model_data_move_to_cpu(grad_fp16_payload)
-            if not p.colo_attr.saved_grad.is_null():
-                assert not self.reuse_fp16_shard, 'Gradien accumulation is not supported when reuse_fp16_shard=True'
-                # Accumulate grad, saved grad must be fp32
-                p.colo_attr.saved_grad.reset_payload(cast_tensor_to_fp32(p.colo_attr.saved_grad.payload))
-                p.colo_attr.saved_grad.payload.add_(grad_fp16_payload.view_as(p.colo_attr.saved_grad.payload))
-            else:
-                p.colo_attr.saved_grad.reset_payload(grad_fp16_payload)
             p.grad = None
-            p.colo_attr.fp16_grad.set_null()

     @torch.no_grad()
     def _grad_post_backward_hook(self, param: Parameter, grad: torch.Tensor) -> Optional[torch.Tensor]:
@@ -322,11 +302,22 @@ class ShardedModelV2(nn.Module):
         if self.gradient_postdivide_factor > 1:
             # Average grad by world_size for consistency with PyTorch DDP.
             reduced_grad.data.div_(self.gradient_postdivide_factor)
+        # FIXME(ver217): remove the below line when impl eviction policy
+        if param.colo_attr.offload_grad:
+            colo_model_data_move_to_cpu(reduced_grad)
         if self.reuse_fp16_shard:
-            param.colo_attr.sharded_data_tensor.reset_payload(reduced_grad.data)
+            assert param.colo_attr.saved_grad.is_null(
+            ), 'Gradien accumulation is not supported when reuse_fp16_shard=True'
+            param.colo_attr.sharded_data_tensor.reset_payload(reduced_grad)
             param.colo_attr.sharded_data_tensor.is_sharded = True
+            param.colo_attr.saved_grad.reset_payload(param.colo_attr.sharded_data_tensor.payload)
         else:
-            param.colo_attr.fp16_grad = StatefulTensor(reduced_grad.data)
+            reduced_grad = cast_tensor_to_fp32(reduced_grad)
+            if param.colo_attr.saved_grad.is_null():
+                param.colo_attr.saved_grad.reset_payload(reduced_grad)
+            else:
+                param.colo_attr.saved_grad.payload.add_(reduced_grad.view_as(param.colo_attr.saved_grad.payload))
+        param.colo_attr.saved_grad.trans_state(TensorState.HOLD)

     def state_dict(self, destination=None, prefix='', keep_vars=False) -> 'OrderedDict[str, torch.Tensor]':
         self.shard_strategy.gather([p.colo_attr.sharded_data_tensor for p in self.module.parameters()],
```
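Both branches of the rewritten reduce-scatter callback now funnel the reduced gradient into `param.colo_attr.saved_grad` and end with `trans_state(TensorState.HOLD)`, which is the "trace states" part of the commit title. A minimal sketch of the `StatefulTensor` contract the diff relies on; the real class lives in `colossalai.zero.sharded_param.tensorful_state` and also reports transitions to the memory tracer, and the enum members here are illustrative rather than the library's actual definition:

```python
from enum import Enum, auto
from typing import Optional

import torch


class TensorState(Enum):
    FREE = auto()
    COMPUTE = auto()
    HOLD = auto()


class StatefulTensor:
    """Sketch: a tensor payload plus an explicit lifecycle state."""

    def __init__(self, tensor: Optional[torch.Tensor], state: TensorState = TensorState.HOLD) -> None:
        self._payload = tensor
        self._state = TensorState.FREE if tensor is None else state

    def is_null(self) -> bool:
        return self._payload is None

    def reset_payload(self, tensor: torch.Tensor) -> None:
        # Swap in a new backing tensor and hold it.
        self._payload = tensor
        self._state = TensorState.HOLD

    def set_null(self) -> None:
        self._payload = None
        self._state = TensorState.FREE

    def trans_state(self, state: TensorState) -> None:
        # Transitions like saved_grad.trans_state(TensorState.HOLD) above are
        # what lets the memory tracer follow a gradient across the step.
        self._state = state

    @property
    def payload(self) -> Optional[torch.Tensor]:
        return self._payload
```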
colossalai/zero/sharded_optim/sharded_optim_v2.py (+52 -36)

```diff
@@ -12,11 +12,12 @@ from colossalai.logging import get_dist_logger
 from colossalai.nn.optimizer import ColossalaiOptimizer
 from colossalai.utils.memory_tracer.model_data_memtracer import \
     GLOBAL_MODEL_DATA_TRACER
-from colossalai.utils.memory_utils.utils import (colo_model_data_tensor_move, colo_model_tensor_clone,
+from colossalai.utils.memory_utils.utils import (colo_model_data_tensor_move_inline, colo_model_tensor_clone,
                                                  colo_tensor_mem_usage)
 from colossalai.zero.sharded_model import ShardedModelV2
 from colossalai.zero.sharded_model._utils import cast_tensor_to_fp32
 from colossalai.zero.sharded_optim._utils import has_inf_or_nan
+from colossalai.zero.sharded_param.tensorful_state import (StatefulTensor, TensorState)
 from torch import Tensor
 from torch.distributed import ProcessGroup
 from torch.nn.parameter import Parameter
@@ -112,7 +113,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
         self._logger = get_dist_logger("ShardedOptimizerV2")

         # Store fp32 param shards
-        self.master_params: Dict[Parameter, Tensor] = {}
+        self.master_params: Dict[Parameter, StatefulTensor] = {}

         for group in self.optim.param_groups:
             for p in group['params']:
@@ -123,7 +124,8 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
                     # Param is no sharded, which means we use ZeRO-2 here
                     # As we only store param shard, we shard it here
                     self.shard_strategy.shard([p.colo_attr.sharded_data_tensor], self.dp_process_group)
-                self.master_params[p] = cast_tensor_to_fp32(p.colo_attr.sharded_data_tensor.payload).to(self.device)
+                self.master_params[p] = StatefulTensor(
+                    cast_tensor_to_fp32(p.colo_attr.sharded_data_tensor.payload).to(self.device))
                 if not is_param_sharded:
                     # In this branch, there's no need to shard param
                     # So we gather here
@@ -184,13 +186,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
             self.zero_grad()
             return

-        # assign master param pointers to p.data.
-        # We will not trigger data copy here.
-        for group in self.optim.param_groups:
-            for p in group['params']:
-                p.data = self.master_params[p]
-                # Now p.data is sharded
-                # So optimizer states are sharded naturally
+        self._prepare_data()

         self._logger.debug(
             f"Before step ShardedOptimizerV2 consumes {self.get_memory_usage()[0] / 1e6} MB CUDA Memory, {self.get_memory_usage()[1] / 1e6} MB CUDA Memory!",
@@ -201,30 +197,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
         self._logger.debug(
             f"After step ShardedOptimizerV2 consumes {self.get_memory_usage()[0] / 1e6} MB CUDA Memory, {self.get_memory_usage()[1] / 1e6} MB CUDA Memory!",
             ranks=[0])
-        # Copy master param data (fp32) to payload of colo_attr (fp16)
-        # TODO() improve efficiency by gathering tensors into a chunk and transfering
-        # a chunk.
-        for group in self.optim.param_groups:
-            for p in group['params']:
-                is_param_sharded = p.colo_attr.sharded_data_tensor.is_sharded
-                if not is_param_sharded:
-                    # We use ZeRO-2 here
-                    # The `p.colo_attr.sharded_data_tensor` saves full fp16 param
-                    # But we only have updated fp32 param shard here
-                    # So we first shard full fp16 param and copy fp32 param shard to it
-                    # Then we will gather them
-                    self.shard_strategy.shard([p.colo_attr.sharded_data_tensor], self.dp_process_group)
-                # We have to use `copy_payload` instead of `reset_payload`
-                # Since p.data is fp32 and p.colo_attr.sharded_data_tensor is fp16
-                # TODO() optimize this line CPU (fp32) -> GPU (fp16)
-                p.colo_attr.sharded_data_tensor.reset_payload(
-                    colo_model_tensor_clone(p.half(), torch.cuda.current_device()))
-                if not is_param_sharded:
-                    # We gather full fp16 param here
-                    self.shard_strategy.gather([p.colo_attr.sharded_data_tensor], self.dp_process_group)
-                p.data = p.colo_attr.sharded_data_tensor.payload
+        self._write_back_data()
         return ret

     def backward(self, loss: Tensor) -> None:
@@ -276,6 +249,9 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
         # Because we will judge whether local grad accumulation
         # is enabled by wheter grad is None
         self.optim.zero_grad(set_to_none=True)
+        for group in self.optim.param_groups:
+            for p in group['params']:
+                p.colo_attr.saved_grad.set_null()

     def sync_grad(self):
         pass
@@ -288,9 +264,9 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
         fp32_shards_used_cuda_margin_mem = 0
         for group in self.optim.param_groups:
             for p in group['params']:
-                shard_mem = self.master_params[p].numel() * self.master_params[p].element_size()
+                shard_mem = self.master_params[p].payload.numel() * self.master_params[p].payload.element_size()
                 if fp32_shards_used_cuda_margin_mem + shard_mem < fp32_shards_available_cuda_margin_mem:
-                    self.master_params[p] = self.master_params[p].to(torch.cuda.current_device())
+                    colo_model_data_tensor_move_inline(self.master_params[p], torch.cuda.current_device())
                     p.grad.data = p.grad.data.to(torch.cuda.current_device())
                     p.colo_attr.offload_grad = False
                     fp32_shards_used_cuda_margin_mem += shard_mem
@@ -298,10 +274,50 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
     def _prepare_grads(self):
         for group in self.optim.param_groups:
             for p in group['params']:
+                p.colo_attr.saved_grad.trans_state(TensorState.COMPUTE)
                 # FIXME(ver217): p.data here is an empty tensor on CUDA and has no useful infomation
                 # If we change p.grad directly
                 # it may raise error because of different shape/dtype/device of p.data and p.grad
                 # We just set p.data = p.colo_attr.saved_grad.payload here
                 p.data = p.colo_attr.saved_grad.payload
                 p.grad = p.colo_attr.saved_grad.payload
-                p.colo_attr.saved_grad.set_null()
+                # Set p.data to empty tensor, in case of memory leaking
+                p.colo_attr.remove_torch_payload()
+
+    def _prepare_data(self):
+        # assign master param pointers to p.data.
+        # We will not trigger data copy here.
+        for group in self.optim.param_groups:
+            for p in group['params']:
+                self.master_params[p].trans_state(TensorState.COMPUTE)
+                p.data = self.master_params[p].payload
+                # Now p.data is sharded
+                # So optimizer states are sharded naturally
+
+    def _write_back_data(self):
+        # Copy master param data (fp32) to payload of colo_attr (fp16)
+        # TODO() improve efficiency by gathering tensors into a chunk and transfering
+        # a chunk.
+        for group in self.optim.param_groups:
+            for p in group['params']:
+                is_param_sharded = p.colo_attr.sharded_data_tensor.is_sharded
+                if not is_param_sharded:
+                    # We use ZeRO-2 here
+                    # The `p.colo_attr.sharded_data_tensor` saves full fp16 param
+                    # But we only have updated fp32 param shard here
+                    # So we first shard full fp16 param and copy fp32 param shard to it
+                    # Then we will gather them
+                    self.shard_strategy.shard([p.colo_attr.sharded_data_tensor], self.dp_process_group)
+                # We have to use `copy_payload` instead of `reset_payload`
+                # Since p.data is fp32 and p.colo_attr.sharded_data_tensor is fp16
+                # TODO() optimize this line CPU (fp32) -> GPU (fp16)
+                p.colo_attr.sharded_data_tensor.reset_payload(
+                    colo_model_tensor_clone(p.half(), torch.cuda.current_device()))
+                if not is_param_sharded:
+                    # We gather full fp16 param here
+                    self.shard_strategy.gather([p.colo_attr.sharded_data_tensor], self.dp_process_group)
+                p.data = p.colo_attr.sharded_data_tensor.payload
+                self.master_params[p].trans_state(TensorState.HOLD)
+                p.colo_attr.saved_grad.set_null()
```
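The extracted `_prepare_data` / `_write_back_data` helpers hinge on one trick: repointing `p.data` at the rank-local fp32 master shard before the wrapped optimizer runs, so optimizer states (momentum, variance, etc.) are allocated at shard size, exactly as the "So optimizer states are sharded naturally" comment says. A self-contained sketch of the trick with plain SGD and no ColossalAI types:

```python
import torch

p = torch.nn.Parameter(torch.empty(1, dtype=torch.half), requires_grad=False)
master_shard = torch.randn(8)    # this rank's fp32 master shard
grad_shard = torch.randn(8)      # fp32 grad shard, as prepared by _prepare_grads

# Repoint p.data first: PyTorch only lets us assign p.grad once p.data already
# has the grad's shape/dtype/device (this is the FIXME in _prepare_grads).
p.data = master_shard
p.grad = grad_shard

opt = torch.optim.SGD([p], lr=0.1, momentum=0.9)
opt.step()

# The momentum buffer matches the shard, not the full parameter.
assert opt.state[p]['momentum_buffer'].shape == master_shard.shape
```

After `step()`, `_write_back_data` undoes the aliasing: it casts the updated fp32 shard back to fp16 (`p.half()`), re-gathers the full fp16 parameter in the unsharded ZeRO-2 case, and nulls `saved_grad`.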
colossalai/zero/sharded_param/sharded_param.py (+0 -5)

```diff
@@ -10,7 +10,6 @@ class ShardedParamV2(object):
     def __init__(self, param: torch.nn.Parameter, rm_torch_payload=False) -> None:
         self._sharded_data_tensor: ShardedTensor = ShardedTensor(param.data)
-        self.fp16_grad: StatefulTensor = StatefulTensor(None, TensorState.FREE)
         self.saved_grad: StatefulTensor = StatefulTensor(None, TensorState.FREE)
         # This attribute must be initialized in ShardedModel
         self.offload_grad: bool = False
@@ -57,10 +56,6 @@ class ShardedParamV2(object):
         _update_mem_use(self.sharded_data_tensor.payload)
         address_set.add(self.sharded_data_tensor.payload.data_ptr())

-        if not self.fp16_grad.is_null() and self.fp16_grad.data_ptr() not in address_set:
-            _update_mem_use(self.fp16_grad.payload)
-            address_set.add(self.fp16_grad.data_ptr())
-
         if not self.saved_grad.is_null() and self.saved_grad.data_ptr() not in address_set:
             _update_mem_use(self.saved_grad.payload)
             address_set.add(self.saved_grad.data_ptr())
```
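With `fp16_grad` gone, `get_memory_usage` only visits the sharded data tensor and `saved_grad`, deduplicating payloads that share storage via `data_ptr()`; that matters because the reuse-fp16-shard path in sharded_model_v2.py makes `saved_grad` alias the fp16 data tensor. A sketch of that accounting under our own names (`colo_tensor_mem_usage_sketch` stands in for `colo_tensor_mem_usage`):

```python
from typing import Iterable, Optional, Tuple

import torch


def colo_tensor_mem_usage_sketch(t: torch.Tensor) -> Tuple[int, int]:
    # (cuda_bytes, cpu_bytes) occupied by one tensor.
    mem = t.numel() * t.element_size()
    return (mem, 0) if t.is_cuda else (0, mem)


def dedup_mem_usage(payloads: Iterable[Optional[torch.Tensor]]) -> Tuple[int, int]:
    cuda_use = cpu_use = 0
    address_set = set()    # storage addresses already counted
    for t in payloads:
        if t is not None and t.data_ptr() not in address_set:
            cuda, cpu = colo_tensor_mem_usage_sketch(t)
            cuda_use += cuda
            cpu_use += cpu
            address_set.add(t.data_ptr())
    return cuda_use, cpu_use


data = torch.randn(2, 3)
alias = data.view(6)    # same storage, so it must not be double-counted
assert dedup_mem_usage([data, alias]) == (0, 2 * 3 * 4)
```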
tests/test_zero_data_parallel/test_shard_param.py (+0 -6)

```diff
@@ -63,12 +63,6 @@ def _run_shard_param_v2(rank, world_size, port):
     # 4 is size of dummy tensor of param.data
     assert cpu_mem_use == 2 * 3 * 4 * 2 + 4

-    sparam.fp16_grad = StatefulTensor(torch.randn(2, 3).cuda().half())
-    cuda_mem_use, cpu_mem_use = sparam.get_memory_usage()
-    assert cpu_mem_use == 2 * 3 * 4 * 2 + 4
-    assert cuda_mem_use == 2 * 3 * 2
-    sparam.fp16_grad = StatefulTensor(None)
-
     sparam.saved_grad = StatefulTensor(torch.randn(2, 3))
     sparam.remove_torch_payload()

     cuda_mem_use, cpu_mem_use = sparam.get_memory_usage()
```
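The deleted assertions read more easily with the byte arithmetic spelled out (fp32 is 4 bytes per element, fp16 is 2; the trailing `+ 4` is the one-element dummy tensor that `remove_torch_payload` leaves in `param.data`, per the test's own comment):

```python
fp32_2x3 = 2 * 3 * 4       # 24 bytes: one 2x3 fp32 payload
dummy = 4                  # placeholder left in param.data
cpu_expected = 2 * fp32_2x3 + dummy    # two fp32 payloads tracked on CPU, plus the dummy
cuda_fp16_2x3 = 2 * 3 * 2  # 12 bytes: the CUDA fp16 grad the old test tracked

assert cpu_expected == 2 * 3 * 4 * 2 + 4 == 52
assert cuda_fp16_2x3 == 12
```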