OpenDAS / ColossalAI · Commits

Commit b8034016 (unverified) · authored Aug 11, 2022 by HELSON, committed by GitHub on Aug 11, 2022
Parent: 3b26516c

[zero] add chunk_managerV2 for all-gather chunk (#1441)

Showing 3 changed files, with 298 additions and 0 deletions:

- colossalai/gemini/update/__init__.py (+1, -0)
- colossalai/gemini/update/chunk_mgrv2.py (+221, -0)
- tests/test_gemini/update/test_chunk_mgrv2.py (+76, -0)
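For context before the file contents: ZeRO-style chunking packs many small parameter tensors into a handful of fixed-size flat buffers ("chunks"), so a single all-gather on one buffer services a whole group of tensors instead of issuing one collective per tensor. A rough, self-contained sketch of the packing idea, independent of the ColossalAI API (all names here are illustrative):

```python
import torch

# Pack several small tensors into one flat chunk; each tensor becomes a view
# into the chunk, so a single collective on `chunk` services all of them.
tensors = [torch.rand(8, 8) for _ in range(3)]
chunk = torch.empty(sum(t.numel() for t in tensors))

offset, views = 0, []
for t in tensors:
    n = t.numel()
    chunk[offset:offset + n].copy_(t.flatten())        # copy payload into the chunk
    views.append(chunk[offset:offset + n].view_as(t))  # tensor is now a view of the chunk
    offset += n
```

`ChunkManagerV2` below adds the bookkeeping around this idea: grouping, closing, gathering, releasing and moving such buffers while tracking their memory footprint.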
colossalai/gemini/update/__init__.py (view file @ b8034016)

```python
from .chunkv2 import ChunkV2
from .chunk_mgrv2 import ChunkManagerV2
from .search_utils import clasify_params, search_chunk_configuration
```

The one added line is the `ChunkManagerV2` import; the surrounding imports are unchanged context.
colossalai/gemini/update/chunk_mgrv2.py (new file, mode 100644, view @ b8034016)

```python
import torch
from typing import Optional, Dict, Deque, Set, List, Tuple, Iterable
from collections import deque

from colossalai.utils import get_current_device
from colossalai.tensor import ColoTensor
from colossalai.gemini.chunk import ChunkFullError, TensorState
from colossalai.gemini.update import ChunkV2 as Chunk


class ChunkManagerV2:
    """
    A manager class to manipulate the tensors in chunks.

    Args:
        chunk_configuration (Dict[int, Dict]): the configuration dictionary of this chunk manager.
        init_device (torch.device): optional, the device on which the chunk is initialized. The default is None.
        pin_memory (bool): if true, all chunks have a piece of pinned memory in CPU.
    """

    def __init__(self,
                 chunk_configuration: Dict[int, Dict],
                 init_device: Optional[torch.device] = None,
                 pin_memory: bool = False) -> None:

        self.device = init_device or get_current_device()
        self.size_config: Dict[int, int] = dict()
        self.kwargs_config = chunk_configuration
        for k, v in self.kwargs_config.items():
            self.size_config[k] = v.pop('chunk_size')
            v['init_device'] = self.device
            v['pin_memory'] = pin_memory

        self.chunk_groups: Dict[str, Deque] = dict()
        self.tensor_chunk_map: Dict[torch.Tensor, Chunk] = dict()
        self.accessed_chunks: Set[Chunk] = set()
        self.lazy_release_tensors: List[torch.Tensor] = list()
        self.total_mem: Dict[str, int] = {'cpu': 0, 'cuda': 0}
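
    # For illustration (example not taken from the commit): a configuration like
    #     {2: dict(chunk_size=128, keep_gathered=False)}
    # registers chunk size 128 under key 2. 'chunk_size' is popped above, and the
    # remaining kwargs, together with the 'init_device' and 'pin_memory' entries
    # set here, are forwarded to every Chunk created under that key.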

    def append_tensor(self, tensor: ColoTensor, group_type: str, config_key: int) -> None:
        """Append a tensor to a chunk."""
        assert tensor not in self.tensor_chunk_map
        assert isinstance(tensor, ColoTensor), "Please feed ColoTensor to this ChunkManager"
        assert config_key in self.size_config

        chunk_size = self.size_config[config_key]
        chunk_kwargs = self.kwargs_config[config_key]
        group_name = "{}_{}".format(group_type, config_key)
        chunk_group = self.__get_chunk_group(group_name)

        try:
            # append the tensor to the last chunk
            chunk_group[-1].append_tensor(tensor)
        except (IndexError, ChunkFullError):
            # this branch is taken when there is no chunk yet or the last chunk
            # in the group is full; it creates a new chunk and allocates it to
            # the corresponding process
            if chunk_group:
                # the chunk group is not empty: close the last chunk
                self.__close_one_chunk(chunk_group[-1])

            if tensor.numel() > chunk_size:
                chunk_size = tensor.numel()
            chunk = Chunk(
                chunk_size=chunk_size,
                process_group=tensor.process_group,
                dtype=tensor.dtype,
                **chunk_kwargs
            )

            chunk_group.append(chunk)
            chunk.append_tensor(tensor)
            self.__add_memory_usage(chunk.memory_usage)

        self.tensor_chunk_map[tensor] = chunk_group[-1]

    def close_all_groups(self):
        """Close the last chunk of every group."""
        for group_name in self.chunk_groups:
            self.__close_one_chunk(self.chunk_groups[group_name][-1])

    def access_chunk(self, chunk: Chunk) -> None:
        """Make the chunk usable for computation."""
        if chunk in self.accessed_chunks:
            return
        self.__sub_memroy_usage(chunk.memory_usage)
        chunk.access_chunk()
        self.__add_memory_usage(chunk.memory_usage)
        self.accessed_chunks.add(chunk)

    def release_chunk(self, chunk: Chunk) -> None:
        """Scatter the chunk back to its shards in CUDA."""
        if chunk not in self.accessed_chunks:
            return
        if chunk.can_release:
            self.__sub_memroy_usage(chunk.memory_usage)
            chunk.release_chunk()
            self.__add_memory_usage(chunk.memory_usage)
            self.accessed_chunks.remove(chunk)

    def move_chunk(self, chunk: Chunk, device: torch.device) -> None:
        """Move the shard of the chunk to the target device."""
        if not chunk.can_move or chunk.device_type == device.type:
            return
        self.__sub_memroy_usage(chunk.memory_usage)
        chunk.shard_move(device)
        self.__add_memory_usage(chunk.memory_usage)

    def trans_tensor_state(self, tensor: torch.Tensor, state: TensorState) -> None:
        """Transition the tensor's state according to the pre-defined state machine."""
        chunk = self.tensor_chunk_map[tensor]
        chunk.tensor_trans_state(tensor, state)

    def reduce_chunk(self, chunk: Chunk) -> bool:
        """Reduce or all-reduce the chunk."""
        if not chunk.can_reduce:
            return False
        self.__sub_memroy_usage(chunk.memory_usage)
        chunk.release_chunk()
        self.__add_memory_usage(chunk.memory_usage)
        return True

    def copy_tensor_to_chunk_slice(self, tensor: torch.Tensor, data: torch.Tensor) -> None:
        """
        Copy data to the chunk.

        Args:
            tensor (torch.Tensor): the tensor used to retrieve meta information
            data (torch.Tensor): the tensor to be copied to the chunk
        """
        chunk = self.tensor_chunk_map[tensor]
        chunk.copy_tensor_to_chunk_slice(tensor, data)

    def get_chunk(self, tensor: torch.Tensor) -> Chunk:
        """
        Return the chunk owning the tensor.

        Args:
            tensor (torch.Tensor): a torch tensor object
        """
        return self.tensor_chunk_map[tensor]

    def add_lazy_release_tensors(self, tensors: List[torch.Tensor]) -> None:
        """
        Add tensors to the buffer for lazy release.

        Args:
            tensors (List[torch.Tensor]): the tensors to be released lazily
        """
        self.lazy_release_tensors.extend(tensors)

    def exec_lazy_release(self) -> None:
        """Execute release for tensors added to the lazy release buffer."""
        for chunk in self.get_chunks(self.lazy_release_tensors):
            self.release_chunk(chunk)
        self.lazy_release_tensors.clear()

    def __repr__(self) -> str:
        msg = [
            'Chunk Manager Information:\n',
            'Total memory: ' + ', '.join([f'{k}={v}B' for k, v in self.total_mem.items()]) + '\n'
        ]
        for group_name, group in self.chunk_groups.items():
            msg.append(f'Group {group_name}:\n')
            for i, chunk in enumerate(group):
                msg.append(f'[{i}] {chunk}\n')
        return ''.join(msg)

    def get_chunks(self, tensors: Iterable[torch.Tensor]) -> Tuple[Chunk, ...]:
        """
        Get all chunks owning the input tensors.

        Args:
            tensors (Iterable[torch.Tensor]): the tensors used to look for chunks
        """
        chunks = []
        for tensor in tensors:
            chunk = self.get_chunk(tensor)
            if chunk not in chunks:
                chunks.append(chunk)
        return tuple(chunks)

    def add_extern_static_tensor(self, tensor: torch.Tensor) -> None:
        """Add an extern static tensor to the chunk manager.
        Such tensors are not managed by the chunk manager, but their memory usage is still monitored.
        They are "static": their shape, dtype and device never change, so their memory usage never changes.

        Args:
            tensor (torch.Tensor): an extern static tensor, e.g. optimizer state.
        """
        assert tensor not in self.tensor_chunk_map
        self.total_mem[tensor.device.type] += tensor.numel() * tensor.element_size()

    def __get_chunk_group(self, group_name: str) -> Deque:
        """Register a chunk group if absent, then return it."""
        if group_name not in self.chunk_groups:
            self.chunk_groups[group_name] = deque()
        return self.chunk_groups[group_name]

    def __close_one_chunk(self, chunk: Chunk):
        self.__sub_memroy_usage(chunk.memory_usage)
        chunk.close_chunk(self.device)
        self.__add_memory_usage(chunk.memory_usage)

    def __sub_memroy_usage(self, usage: Dict[str, int]):
        for k, v in usage.items():
            self.total_mem[k] -= v

    def __add_memory_usage(self, usage: Dict[str, int]):
        for k, v in usage.items():
            self.total_mem[k] += v
```
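A note on the growth strategy in `append_tensor` above: a tensor is first offered to the open (last) chunk of its group; if the group is empty (`IndexError`) or the chunk is full (`ChunkFullError`), the last chunk is closed and a new one is opened, sized at `chunk_size` or at the tensor's own `numel()` if larger. A minimal self-contained sketch of just that strategy (`FlatBucket` and `BucketFullError` are illustrative stand-ins, not ColossalAI classes):

```python
from collections import deque

class BucketFullError(Exception):
    pass

class FlatBucket:
    """Stand-in for ChunkV2: tracks only how many elements it holds."""
    def __init__(self, capacity: int) -> None:
        self.capacity, self.used = capacity, 0

    def append(self, numel: int) -> None:
        if self.used + numel > self.capacity:
            raise BucketFullError
        self.used += numel

group = deque()

def append(numel: int, default_capacity: int = 128) -> None:
    try:
        group[-1].append(numel)                # try the open bucket first
    except (IndexError, BucketFullError):      # no bucket yet, or it is full
        # oversized tensors get a dedicated bucket of exactly their own size
        group.append(FlatBucket(max(default_capacity, numel)))
        group[-1].append(numel)

for n in (64, 64, 64, 300):                    # the last one exceeds default capacity
    append(n)
assert [b.capacity for b in group] == [128, 128, 300]
```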
tests/test_gemini/update/test_chunk_mgrv2.py (new file, mode 100644, view @ b8034016)

```python
import torch
import colossalai
import pytest
import torch.multiprocessing as mp
from functools import partial

from colossalai.gemini.update import ChunkManagerV2
from colossalai.testing import rerun_if_address_is_in_use, parameterize
from colossalai.utils import free_port
from colossalai.tensor import ProcessGroup, ColoTensor, ColoTensorSpec
from tests.test_tensor.common_utils import debug_print

# expected memory footprints in bytes, keyed by keep_gathered (and pin_memory for CPU)
CUDA_MEM_0 = {False: 512, True: 1024}
CUDA_MEM_1 = {False: 0, True: 1024}
CPU_MEM = {True: {True: 0, False: 0}, False: {True: 512, False: 0}}


@parameterize('keep_gathered', [True, False])
@parameterize('pin_memory', [True, False])
def exam_chunk_memory(keep_gathered, pin_memory):
    pg = ProcessGroup()

    debug_print([0], "keep_gathered: {}, pin_memory: {}".format(keep_gathered, pin_memory))

    params = [ColoTensor(torch.rand(8, 8), spec=ColoTensorSpec(pg)) for _ in range(3)]
    config = {2: dict(chunk_size=128, keep_gathered=keep_gathered)}

    chunk_manager = ChunkManagerV2(config, pin_memory=pin_memory)
    assert chunk_manager.total_mem['cpu'] == 0
    assert chunk_manager.total_mem['cuda'] == 0

    for p in params:
        chunk_manager.append_tensor(p, 'param', 2)
    chunk_manager.close_all_groups()
    assert chunk_manager.total_mem['cpu'] == CPU_MEM[keep_gathered][pin_memory]
    assert chunk_manager.total_mem['cuda'] == CUDA_MEM_0[keep_gathered]

    chunks = chunk_manager.get_chunks(params)

    for chunk in chunks:
        chunk_manager.access_chunk(chunk)
    assert chunk_manager.total_mem['cpu'] == CPU_MEM[keep_gathered][pin_memory]
    assert chunk_manager.total_mem['cuda'] == CUDA_MEM_0[True]

    for chunk in chunks:
        chunk_manager.release_chunk(chunk)
    assert chunk_manager.total_mem['cpu'] == CPU_MEM[keep_gathered][pin_memory]
    assert chunk_manager.total_mem['cuda'] == CUDA_MEM_0[keep_gathered]

    for chunk in chunks:
        chunk_manager.move_chunk(chunk, torch.device('cpu'))
    assert chunk_manager.total_mem['cpu'] == CPU_MEM[keep_gathered][True]
    assert chunk_manager.total_mem['cuda'] == CUDA_MEM_1[keep_gathered]


def run_dist(rank, world_size, port):
    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    exam_chunk_memory()


@pytest.mark.dist
@pytest.mark.parametrize('world_size', [2])
@rerun_if_address_is_in_use()
def test_chunk_manager(world_size):
    run_func = partial(run_dist, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_chunk_manager(2)
```
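The expected constants in this test follow from simple arithmetic, sketched below (assuming fp32 tensors and the test's world size of 2; the CPU line reflects a reading of `pin_memory` as keeping a pinned CPU copy of each rank's shard):

```python
# A sketch of where the constants come from (not part of the test):
elems = 3 * 8 * 8               # 192 elements across the three 8x8 ColoTensors
num_chunks = 2                  # chunk_size=128: tensors 1-2 fill chunk 1, tensor 3 opens chunk 2
assert elems <= num_chunks * 128
gathered_bytes = num_chunks * 128 * 4   # 1024 B on CUDA while gathered (CUDA_MEM_0[True])
shard_bytes = gathered_bytes // 2       # 512 B per rank when scattered (CUDA_MEM_0[False])
# CPU_MEM[False][True] == 512 matches a pinned CPU copy of one rank's shard,
# and CUDA_MEM_1[True] == 1024 because a keep_gathered chunk cannot move off CUDA.
```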