Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
37cb70fe
"...git@developer.sourcefind.cn:OpenDAS/dynamo.git" did not exist on "5f1af25a29d5918e8ba27ad65f87cca63a215799"
Unverified
Commit
37cb70fe
authored
Mar 28, 2022
by
Jiarui Fang
Committed by
GitHub
Mar 28, 2022
Browse files
[zero] get memory usage for sharded param (#536)
parent
56ad9457
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
45 additions
and
2 deletions
+45
-2
colossalai/zero/sharded_param/sharded_param.py
colossalai/zero/sharded_param/sharded_param.py
+26
-1
tests/test_zero_data_parallel/test_shard_param.py
tests/test_zero_data_parallel/test_shard_param.py
+19
-1
No files found.
colossalai/zero/sharded_param/sharded_param.py
View file @
37cb70fe
import
torch
import
torch
import
torch.distributed
as
dist
import
torch.distributed
as
dist
from
colossalai.zero.sharded_param
import
ShardedTensor
from
colossalai.zero.sharded_param
import
ShardedTensor
from
typing
import
Optional
from
typing
import
Optional
,
Tuple
class
ShardedParamV2
(
object
):
class
ShardedParamV2
(
object
):
...
@@ -40,3 +40,28 @@ class ShardedParamV2(object):
...
@@ -40,3 +40,28 @@ class ShardedParamV2(object):
@property
def param_is_sharded(self):
    """Report whether the wrapped data tensor is currently sharded.

    Returns:
        bool-like: the ``is_sharded`` flag of the underlying sharded data tensor.
    """
    sharded_flag = self._sharded_data_tensor.is_sharded
    return sharded_flag
def get_memory_usage(self) -> Tuple[int, int]:
    """Get the memory usage of the param, including data and grad.

    Sums ``numel() * element_size()`` over the data payload and both grad
    tensors, bucketed by device type; tensors that are ``None`` are skipped.
    Tensors on devices other than 'cpu' or 'cuda' are not counted (matches
    the original two-branch check).

    Returns:
        Tuple[int, int]: cuda mem usage in Byte, cpu memory usage in Byte
    """
    cuda_total, cpu_total = 0, 0

    # Candidate tensors: the sharded data payload plus fp16/fp32 gradients.
    for tensor in (self.sharded_data_tensor.payload, self.fp16_grad, self.fp32_grad):
        if tensor is None:
            continue
        assert isinstance(tensor, torch.Tensor)
        nbytes = tensor.numel() * tensor.element_size()
        if tensor.device.type == 'cpu':
            cpu_total += nbytes
        elif tensor.device.type == 'cuda':
            cuda_total += nbytes

    return cuda_total, cpu_total
tests/test_zero_data_parallel/test_shard_param.py
View file @
37cb70fe
...
@@ -54,6 +54,24 @@ def _run_shard_param_v2(rank, world_size, port):
...
@@ -54,6 +54,24 @@ def _run_shard_param_v2(rank, world_size, port):
sparam
.
remove_torch_payload
()
sparam
.
remove_torch_payload
()
assert
(
param
.
data
.
numel
()
==
1
)
assert
(
param
.
data
.
numel
()
==
1
)
# Test get memory usage
sparam
.
fp32_grad
=
torch
.
randn
(
2
,
3
)
cuda_mem_use
,
cpu_mem_use
=
sparam
.
get_memory_usage
()
assert
cpu_mem_use
==
2
*
3
*
4
*
2
sparam
.
fp16_grad
=
torch
.
randn
(
2
,
3
).
cuda
().
half
()
cuda_mem_use
,
cpu_mem_use
=
sparam
.
get_memory_usage
()
assert
cpu_mem_use
==
2
*
3
*
4
*
2
assert
cuda_mem_use
==
2
*
3
*
2
sparam
.
fp16_grad
=
None
sparam
.
fp32_grad
=
torch
.
randn
(
2
,
3
)
sparam
.
remove_torch_payload
()
cuda_mem_use
,
cpu_mem_use
=
sparam
.
get_memory_usage
()
assert
cpu_mem_use
==
2
*
3
*
4
*
2
assert
cuda_mem_use
==
0
print
(
f
'cuda_mem_use
{
cuda_mem_use
}
cpu_mem_use
{
cpu_mem_use
}
'
)
@
pytest
.
mark
.
dist
@
pytest
.
mark
.
dist
@
pytest
.
mark
.
parametrize
(
"world_size"
,
[
1
,
2
])
@
pytest
.
mark
.
parametrize
(
"world_size"
,
[
1
,
2
])
...
@@ -64,5 +82,5 @@ def test_shard_param_v2(world_size):
...
@@ -64,5 +82,5 @@ def test_shard_param_v2(world_size):
# Script entry point: run the sharded-parameter test directly with 2 ranks
# when this file is executed as a script (rather than under pytest).
if __name__ == '__main__':
    # test_shard_tensor(2)
    test_shard_param_v2(2)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment