OpenDAS / ColossalAI · Commits

Commit a590ed0b (unverified)
Authored Mar 28, 2022 by Jiarui Fang; committed by GitHub on Mar 28, 2022

[zero] improve the accuracy of get_memory_usage of sharded param (#538)
Parent: 37cb70fe

Changes: 2 changed files with 39 additions and 8 deletions (+39 -8)

  colossalai/zero/sharded_param/sharded_param.py       +18 -2
  tests/test_zero_data_parallel/test_shard_param.py    +21 -6
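The commit message targets the accuracy of get_memory_usage: counting every tensor attribute unconditionally over-reports memory when several attributes alias the same storage, and the old code also ignored the torch param's own .data and .grad. A minimal sketch of the aliasing issue in plain PyTorch (tensor names are illustrative, not the ColossalAI API):

    import torch

    a = torch.randn(2, 3)
    b = a.view(6)                      # b aliases a's storage
    assert a.data_ptr() == b.data_ptr()

    seen, cpu_bytes = set(), 0
    for t in (a, b):
        if t.data_ptr() not in seen:   # count each distinct storage address once
            cpu_bytes += t.numel() * t.element_size()
            seen.add(t.data_ptr())
    assert cpu_bytes == 2 * 3 * 4      # 24 bytes, not 48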
colossalai/zero/sharded_param/sharded_param.py  (view file @ a590ed0b)

...
@@ -60,8 +60,24 @@ class ShardedParamV2(object):
             elif t.device.type == 'cuda':
                 cuda_mem_use += t.numel() * t.element_size()

+        address_set = set()
         _update_mem_use(self.sharded_data_tensor.payload)
-        _update_mem_use(self.fp16_grad)
-        _update_mem_use(self.fp32_grad)
+        address_set.add(self.sharded_data_tensor.payload.data_ptr())
+
+        if self.fp16_grad is not None and self.fp16_grad.data_ptr() not in address_set:
+            _update_mem_use(self.fp16_grad)
+            address_set.add(self.fp16_grad.data_ptr())
+
+        if self.fp32_grad is not None and self.fp32_grad.data_ptr() not in address_set:
+            _update_mem_use(self.fp32_grad)
+            address_set.add(self.fp32_grad.data_ptr())
+
+        if self.param.data is not None and self.param.data.data_ptr() not in address_set:
+            _update_mem_use(self.param.data)
+            address_set.add(self.param.data.data_ptr())
+
+        if self.param.grad is not None and self.param.grad.data_ptr() not in address_set:
+            _update_mem_use(self.param.grad)
+            address_set.add(self.param.grad.data_ptr())

         return cuda_mem_use, cpu_mem_use
...
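Taken out of the class, the accounting pattern above amounts to the following; a rough sketch with a hypothetical dedup_memory_usage helper iterating over candidate tensors, not the ShardedParamV2 method itself:

    import torch

    def dedup_memory_usage(tensors):
        # Sum CUDA and CPU bytes, counting each distinct storage address once.
        cuda_mem_use, cpu_mem_use = 0, 0
        address_set = set()
        for t in tensors:
            if t is None or t.data_ptr() in address_set:
                continue
            nbytes = t.numel() * t.element_size()
            if t.device.type == 'cuda':
                cuda_mem_use += nbytes
            else:
                cpu_mem_use += nbytes
            address_set.add(t.data_ptr())
        return cuda_mem_use, cpu_mem_use

    grad = torch.randn(2, 3)
    print(dedup_memory_usage([grad, grad]))   # (0, 24): shared storage is counted once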
tests/test_zero_data_parallel/test_shard_param.py  (view file @ a590ed0b)

...
@@ -51,26 +51,41 @@ def _run_shard_param_v2(rank, world_size, port):
     allclose(sparam.sharded_data_tensor.payload, param_ref.data)
-    sparam.remove_torch_payload()
-    assert (param.data.numel() == 1)

     # Test get memory usage
     sparam.fp32_grad = torch.randn(2, 3)
     cuda_mem_use, cpu_mem_use = sparam.get_memory_usage()
-    assert cpu_mem_use == 2 * 3 * 4 * 2
+    assert cpu_mem_use == 2 * 3 * 4 * 2, f"cpu_mem_use: {cpu_mem_use}"
+
+    sparam.remove_torch_payload()
+    assert (param.data.numel() == 1)
+    cuda_mem_use, cpu_mem_use = sparam.get_memory_usage()
+    # 4 is size of dummy tensor of param.data
+    assert cpu_mem_use == 2 * 3 * 4 * 2 + 4

     sparam.fp16_grad = torch.randn(2, 3).cuda().half()
     cuda_mem_use, cpu_mem_use = sparam.get_memory_usage()
-    assert cpu_mem_use == 2 * 3 * 4 * 2
+    assert cpu_mem_use == 2 * 3 * 4 * 2 + 4
     assert cuda_mem_use == 2 * 3 * 2

     sparam.fp16_grad = None
     sparam.fp32_grad = torch.randn(2, 3)
     sparam.remove_torch_payload()
     cuda_mem_use, cpu_mem_use = sparam.get_memory_usage()
-    assert cpu_mem_use == 2 * 3 * 4 * 2
+    assert cpu_mem_use == 2 * 3 * 4 * 2 + 4
     assert cuda_mem_use == 0
+
+    # append a grad to torch param
+    param.data = sparam.sharded_data_tensor.payload
+    param.grad = torch.randn(2, 3)
+    cuda_mem_use, cpu_mem_use = sparam.get_memory_usage()
+    assert cpu_mem_use == 2 * 3 * 4 * 2 + 2 * 3 * 4, f"cpu_mem_use {cpu_mem_use}"
+    assert cuda_mem_use == 0
+
+    # reuse torch grad for sparam
+    sparam.fp32_grad = param.grad
+    cuda_mem_use, cpu_mem_use = sparam.get_memory_usage()
+    assert cpu_mem_use == 2 * 3 * 4 * 2
+    assert cuda_mem_use == 0
     print(f'cuda_mem_use {cuda_mem_use} cpu_mem_use {cpu_mem_use}')


 @pytest.mark.dist
...
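The constants in the asserts come straight from tensor shapes and dtypes: the repeated 2 * 3 * 4 * 2 totals appear to be two distinct 2x3 fp32 buffers on CPU (the sharded payload plus fp32_grad), the one-element dummy that remove_torch_payload() leaves in param.data (the test asserts param.data.numel() == 1) accounts for the extra 4 bytes, and 2 * 3 * 2 is a 2x3 fp16 tensor on CUDA. A quick sanity check of that arithmetic in plain PyTorch, without ColossalAI:

    import torch

    fp32 = torch.randn(2, 3)
    dummy = torch.empty(1, dtype=torch.float32)
    fp16 = torch.randn(2, 3).half()   # dtype check only; .cuda() not needed for byte counts

    assert fp32.numel() * fp32.element_size() == 2 * 3 * 4    # 24 bytes per fp32 buffer
    assert dummy.numel() * dummy.element_size() == 4          # dummy placeholder tensor
    assert fp16.numel() * fp16.element_size() == 2 * 3 * 2    # 12 bytes for the fp16 grad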