OpenDAS / ColossalAI · Commits

Commit 1b178593 (unverified)
Authored Jun 07, 2022 by ver217, committed by GitHub on Jun 07, 2022
[tensor] chunk manager monitor mem usage (#1076)
Parent: 98cdbf49

Changes: 2 changed files, with 35 additions and 1 deletion (+35 -1)

  colossalai/tensor/chunk.py         +21 -1
  tests/test_tensor/test_chunk.py    +14 -0
colossalai/tensor/chunk.py

@@ -54,6 +54,7 @@ class Chunk:
         if not self.is_src_rank:
             self.data.storage().resize_(0)
         self.tensors_info: Dict[torch.Tensor, TensorInfo] = {}
+        self.mem = self.size * self.data.element_size()
 
     def append(self, tensor: torch.Tensor) -> None:
         assert tensor.dtype == self.dtype
@@ -167,6 +168,10 @@ class Chunk:
         self.data.copy_(dest_chunk.data)
         self._update_tensors_ptr()
 
+    @property
+    def device_type(self) -> str:
+        return self.data.device.type
+
 
 class ChunkManager:
@@ -184,6 +189,7 @@ class ChunkManager:
         self.lazy_release_tensors: List[torch.Tensor] = []
         if enable_distributed_storage and chunk_size is None:
            self.rank_load: Dict[str, torch.Tensor] = {}
+        self.total_mem: Dict[str, int] = {'cpu': 0, 'cuda': 0}
 
     def append_tensor(self, tensor: torch.Tensor, group_name: str) -> None:
         assert tensor not in self.tensor_chunk_map
@@ -202,6 +208,8 @@ class ChunkManager:
                 self.rank_load[group_name][src_rank] += chunk_size
             self.chunk_groups[group_name].append(chunk)
         chunk.append(tensor)
+        if not chunk.is_free:
+            self.total_mem[chunk.device_type] += chunk.mem
         self.tensor_chunk_map[tensor] = self.chunk_groups[group_name][-1]
         if not self.enable_distributed_storage:
             self.accessed_chunks.add(self.chunk_groups[group_name][-1])
@@ -222,8 +230,11 @@ class ChunkManager:
         chunk = self.tensor_chunk_map[tensor]
         if chunk in self.accessed_chunks:
             return
+        if not chunk.is_free:
+            self.total_mem[chunk.device_type] -= chunk.mem
         chunk.access()
         self.accessed_chunks.add(chunk)
+        self.total_mem[chunk.device_type] += chunk.mem
 
     def release_chunk(self, tensor: torch.Tensor) -> None:
         if not self.enable_distributed_storage:
@@ -234,11 +245,17 @@ class ChunkManager:
         if chunk.can_release:
             chunk.release()
             self.accessed_chunks.remove(chunk)
+            if chunk.is_free:
+                self.total_mem[chunk.device_type] -= chunk.mem
 
     def move_chunk(self, tensor: torch.Tensor, device: torch.device) -> None:
         chunk = self.tensor_chunk_map[tensor]
-        if chunk.can_move_device:
+        if chunk.data.device == device:
+            return
+        if chunk.can_move_device and not chunk.is_free:
+            self.total_mem[chunk.device_type] -= chunk.mem
             chunk.move_device(device)
+            self.total_mem[chunk.device_type] += chunk.mem
 
     def trans_tensor_state(self, tensor: torch.Tensor, state: TensorState) -> None:
         chunk = self.tensor_chunk_map[tensor]
@@ -248,7 +265,9 @@ class ChunkManager:
         chunk = self.tensor_chunk_map[tensor]
         if not chunk.can_reduce:
             return False
+        self.total_mem[chunk.device_type] -= chunk.mem
         chunk.reduce(is_all_reduce=not self.enable_distributed_storage)
+        self.total_mem[chunk.device_type] += chunk.mem
         return True
 
     def copy_tensor_to_chunk_slice(self, tensor: torch.Tensor, data: torch.Tensor) -> None:
@@ -272,6 +291,7 @@ class ChunkManager:
 
     def __repr__(self) -> str:
         msg = f'Rank {gpc.get_local_rank(ParallelMode.DATA)}:\n'
+        msg += 'Total memory: ' + ', '.join([f'{k}={v}B' for k, v in self.total_mem.items()]) + '\n'
         for group_name, group in self.chunk_groups.items():
             msg += f'Group {group_name}:\n'
             for i, chunk in enumerate(group):
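The accounting added above is symmetric: whenever an operation can change a chunk's device or free its payload, the chunk's byte size is subtracted from the counter of its current device before the operation and added back under its (possibly new) device afterwards, so total_mem always reflects live storage per device. A minimal standalone sketch of that pattern, where ToyChunk and ToyMemTracker are hypothetical stand-ins and not part of ColossalAI:

from typing import Dict


class ToyChunk:
    """Hypothetical stand-in for Chunk: a payload size plus a device label."""

    def __init__(self, num_elements: int, element_size: int = 4, device_type: str = 'cuda'):
        self.mem = num_elements * element_size    # bytes, mirrors Chunk.mem
        self.device_type = device_type
        self.is_free = False                      # True once storage is released


class ToyMemTracker:
    """Per-device byte counters updated around every state change, as ChunkManager does."""

    def __init__(self):
        self.total_mem: Dict[str, int] = {'cpu': 0, 'cuda': 0}

    def on_alloc(self, chunk: ToyChunk) -> None:
        # count the chunk under its current device once it holds real storage
        if not chunk.is_free:
            self.total_mem[chunk.device_type] += chunk.mem

    def on_release(self, chunk: ToyChunk) -> None:
        # storage is dropped, so stop counting it
        chunk.is_free = True
        self.total_mem[chunk.device_type] -= chunk.mem

    def on_move(self, chunk: ToyChunk, device_type: str) -> None:
        # subtract under the old device key, move, then add under the new one
        if not chunk.is_free:
            self.total_mem[chunk.device_type] -= chunk.mem
        chunk.device_type = device_type
        if not chunk.is_free:
            self.total_mem[chunk.device_type] += chunk.mem


tracker = ToyMemTracker()
c = ToyChunk(2048)                      # 2048 fp32 elements -> 8192 B
tracker.on_alloc(c)
assert tracker.total_mem['cuda'] == 8192
tracker.on_move(c, 'cpu')
assert tracker.total_mem == {'cpu': 8192, 'cuda': 0}

Keeping the subtract/add pair tight around each transition is what keeps the counter consistent even when the chunk ends the operation on a different device than it started on.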
tests/test_tensor/test_chunk.py

@@ -32,6 +32,8 @@ HAS_TENSORS = {
     }
 }
 
+TOTAL_MEM = {True: {True: [8192, 8192], False: [16384, 16384]}, False: {True: [8192, 4096], False: [12288, 12288]}}
+
 
 @parameterize('use_chunk', [False, True])
 @parameterize('use_zero', [False, True])
@@ -42,15 +44,27 @@ def run_chunk_zero(use_chunk, use_zero):
     params = [torch.rand(32, 32) for _ in range(3)]
     chunk_size = 2048 if use_chunk else None
     chunk_manager = ChunkManager(chunk_size, enable_distributed_storage=use_zero)
+    assert chunk_manager.total_mem['cpu'] == 0
+    assert chunk_manager.total_mem['cuda'] == 0
     for p in params:
         chunk_manager.append_tensor(p, 'param')
     check_has_params(params, HAS_TENSORS[use_chunk][use_zero][rank])
+    assert chunk_manager.total_mem['cpu'] == 0
+    assert chunk_manager.total_mem['cuda'] == TOTAL_MEM[use_chunk][use_zero][rank]
     for p in params:
         chunk_manager.access_chunk(p)
     check_has_params(params, [True, True, True])
+    assert chunk_manager.total_mem['cpu'] == 0
+    assert chunk_manager.total_mem['cuda'] == TOTAL_MEM[use_chunk][False][rank]
     for p in params:
         chunk_manager.release_chunk(p)
     check_has_params(params, HAS_TENSORS[use_chunk][use_zero][rank])
+    assert chunk_manager.total_mem['cpu'] == 0
+    assert chunk_manager.total_mem['cuda'] == TOTAL_MEM[use_chunk][use_zero][rank], chunk_manager.total_mem['cuda']
+    for p in params:
+        chunk_manager.move_chunk(p, torch.device('cpu'))
+    assert chunk_manager.total_mem['cpu'] == TOTAL_MEM[use_chunk][use_zero][rank], chunk_manager.total_mem['cuda']
+    assert chunk_manager.total_mem['cuda'] == 0
 
 
 def run_dist(rank, world_size, port):
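The TOTAL_MEM table can be derived from the test setup, assuming fp32 parameters and the two ranks implied by the two-entry lists: each 32x32 parameter occupies 4096 B, and a 2048-element chunk occupies 8192 B whether or not it is full. A small sketch of that arithmetic; the 2-plus-1 split of parameters across ranks is an inference from the constants, not stated in the diff:

ELEM_SIZE = 4                        # fp32 bytes per element (assumption)
PARAM_BYTES = 32 * 32 * ELEM_SIZE    # 4096 B per parameter, 3 parameters total

# use_chunk=True: 3 params of 1024 elements pack into 2 chunks of 2048 elements
chunk_bytes = 2048 * ELEM_SIZE       # 8192 B per chunk, full or not
assert 2 * chunk_bytes == 16384      # TOTAL_MEM[True][False]: both chunks live on every rank
assert chunk_bytes == 8192           # TOTAL_MEM[True][True]: one chunk per rank under distributed storage

# use_chunk=False: each param becomes its own chunk of exactly its own size
assert 3 * PARAM_BYTES == 12288                        # TOTAL_MEM[False][False]
assert [2 * PARAM_BYTES, PARAM_BYTES] == [8192, 4096]  # TOTAL_MEM[False][True]: 2 + 1 params split across ranks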