OpenDAS / ColossalAI · commit c577ed01 (unverified)

[zero] add AgChunk (#1417)
Authored Aug 09, 2022 by HELSON; committed via GitHub on Aug 09, 2022
Parent: d209aff6
1 changed file with 366 additions and 0 deletions

colossalai/gemini/ag_chunk.py (new file, mode 100644, +366 / -0)
import torch
import torch.distributed as dist
from typing import Optional, Dict

from colossalai.utils import get_current_device
from colossalai.tensor import ProcessGroup as ColoProcessGroup
from colossalai.gemini.chunk import TensorState, STATE_TRANS, TensorInfo, ChunkFullError, \
    free_storage, alloc_storage

class AgChunk:

    def __init__(self,
                 chunk_size: int,
                 process_group: ColoProcessGroup,
                 dtype: torch.dtype,
                 init_device: Optional[torch.device] = None,
                 keep_gathered: bool = False,
                 pin_memory: bool = False) -> None:
        """
        Chunk: a container owning a piece of contiguous memory space for tensors.
        AgChunk is a kind of chunk which uses an all-gather operation to gather the whole chunk.
        This kind of chunk is used exclusively for DDP and ZeRO DDP.
        It is designed to make full use of communication and PCIe bandwidth.

        Args:
            chunk_size (int): the number of elements in a chunk
            process_group (ColoProcessGroup): the process group of this chunk
            dtype (torch.dtype): the data type of the chunk
            init_device (torch.device): optional, the device on which the tensor is initialized.
                The default value is None, which means the current GPU.
            keep_gathered (bool): optional, if True, this chunk is always kept gathered in CUDA memory
            pin_memory (bool): optional, if True, this chunk always keeps a shard copy in pinned CPU memory
        """
        self.chunk_size = chunk_size
        self.utilized_size = 0
        # here we use the torch process group,
        # since ColoProcessGroup might get deprecated soon
        self.torch_pg = process_group.dp_process_group
        self.pg_size = dist.get_world_size(self.torch_pg)
        self.pg_rank = dist.get_rank(self.torch_pg)

        # the chunk size should be divisible by the size of the process group
        assert chunk_size % self.pg_size == 0
        self.shard_size = chunk_size // self.pg_size
        self.shard_begin = self.shard_size * self.pg_rank
        self.shard_end = self.shard_begin + self.shard_size

        self.dtype = dtype
        device = init_device or get_current_device()
        self.chunk_temp = torch.empty(chunk_size, dtype=dtype, device=device)
        self.chunk_total = None    # we force chunk_total to reside in CUDA memory
        self.cuda_shard = None    # we use two attributes for clearer interpretation
        self.cpu_shard = None
        self.is_gathered = True

        self.chunk_mem = self.chunk_size * self.chunk_temp.element_size()
        self.shard_mem = self.chunk_mem // self.pg_size

        # each tensor is associated with a TensorInfo to track its meta info
        self.tensors_info: Dict[torch.Tensor, TensorInfo] = {}
        # the total number of tensors in the chunk
        self.num_tensors = 0
        # monitor the states of all tensors in the chunk
        self.tensors_state_monitor: Dict[TensorState, int] = dict()
        for state in TensorState:
            self.tensors_state_monitor[state] = 0

        # some chunks can be kept gathered all the time,
        # so their computation patterns are the same as those of the parameters in DDP
        self.keep_gathered = keep_gathered
        # if pin_memory is True, we allocate a piece of pinned CPU memory
        # for this chunk all the time
        self.pin_memory = pin_memory

        # we introduce the paired chunk here:
        # it refers to another chunk holding the same parameters
        # but with a different dtype (such as fp16_chunk.paired_chunk -> fp32_chunk)
        self.paired_chunk = None
        # if the gradient of this chunk has been reduced, the flag is True;
        # so the flag stays False for unused parameters
        self.grad_reduced_flag = False
        # if this chunk is synchronized with the optimizer, the flag is True
        self.optim_sync_flag = True
        # if the cpu_shard has been visited during this training step, the flag is True
        self.cpu_vis_flag = False

    @property
    def memory_usage(self):
        cuda_memory = 0
        cpu_memory = 0

        if self.chunk_temp is not None:
            # this chunk is not closed
            if self.chunk_temp.device.type == 'cuda':
                cuda_memory += self.chunk_mem
            else:
                cpu_memory += self.chunk_mem
        else:
            if self.is_gathered:
                cuda_memory += self.chunk_mem
            if self.cuda_shard is not None:
                cuda_memory += self.shard_mem
            if self.cpu_shard is not None:
                cpu_memory += self.shard_mem

        return dict(cuda=cuda_memory, cpu=cpu_memory)

    @property
    def device_type(self):
        if self.chunk_temp is not None:
            return self.chunk_temp.device.type
        else:
            if self.chunk_total is not None:
                return 'cuda'
            elif self.cuda_shard is not None:
                return 'cuda'
            else:
                return 'cpu'

    def append_tensor(self, tensor: torch.Tensor):
        """Add a tensor to the chunk.

        Args:
            tensor (torch.Tensor): a tensor to be added to the chunk
        """
        # sanity check
        assert self.chunk_temp is not None
        assert tensor.dtype == self.dtype

        new_utilized_size = self.utilized_size + tensor.numel()
        # raise an exception when the chunk size is exceeded
        if new_utilized_size > self.chunk_size:
            raise ChunkFullError

        self.chunk_temp[self.utilized_size:new_utilized_size].copy_(tensor.flatten())
        assert type(self.chunk_temp) == torch.Tensor, "copy_tensor_to_chunk_slice must use a torch tensor"
        tensor.data = self.chunk_temp[self.utilized_size:new_utilized_size].view(tensor.shape)

        # record all the information about the tensor
        self.num_tensors += 1
        tensor_state = TensorState.HOLD
        self.tensors_info[tensor] = TensorInfo(tensor_state, self.utilized_size, new_utilized_size)
        self.tensors_state_monitor[tensor_state] += 1
        self.utilized_size = new_utilized_size

    def close_chunk(self, shard_dev: torch.device):
        """Close the chunk. No tensor can be appended to a closed chunk.
        """
        # sanity check
        assert self.chunk_temp is not None

        if self.chunk_temp.device.type == 'cpu':
            self.chunk_total = self.chunk_temp.to(get_current_device())
        else:
            self.chunk_total = self.chunk_temp
        self.chunk_temp = None

        self.__scatter()

        if self.pin_memory or shard_dev.type == 'cpu':
            self.cpu_shard = torch.empty(self.shard_size, dtype=self.dtype, pin_memory=self.pin_memory)
            self.cpu_shard.copy_(self.cuda_shard)
            self.cpu_vis_flag = True    # cpu_shard has been visited

        if shard_dev.type == 'cpu':
            self.cuda_shard = None

    def shard_move(self, device: torch.device, force_copy: bool = False):
        # sanity check
        assert not self.is_gathered
        # when the current chunk is not synchronized with the optimizer,
        # the movement has to go through the paired chunk instead
        if not self.optim_sync_flag:
            assert device.type == 'cuda', "each chunk should first be moved to CUDA"
            self.__paired_shard_move()
            self.optim_sync_flag = True
            return

        if device.type == 'cuda':
            assert device == get_current_device(), "can't move chunk to another device"

            if self.cuda_shard is not None:
                return

            self.cuda_shard = self.cpu_shard.to(get_current_device())

            if not self.pin_memory:
                self.cpu_shard = None
        elif device.type == 'cpu':
            if self.cuda_shard is None:
                return

            if self.pin_memory:
                if force_copy or not self.cpu_vis_flag:
                    self.cpu_shard.copy_(self.cuda_shard)
                # if cpu_shard has already been visited,
                # the copy operation is not needed
            else:
                self.cpu_shard = self.cuda_shard.cpu()
            self.cpu_vis_flag = True
            self.cuda_shard = None
        else:
            raise NotImplementedError

    def access_chunk(self):
        """Make the chunk usable for the parameters inside it.
        This operation is performed on CUDA.
        """
        # sanity check
        assert self.chunk_temp is None

        if not self.is_gathered:
            self.__gather()
        self.__update_tensors_ptr()

    def release_chunk(self):
        """Release the usable chunk.
        This operation is performed on CUDA.
        """
        # sanity check
        assert self.chunk_temp is None

        if self.is_gathered:
            self.__scatter()

    def reduce(self):
        """Reduce-scatter all the gradients.
        This operation is performed on CUDA.
        """
        # sanity check
        assert self.is_gathered

        if self.pg_size == 1:
            # tricky code here:
            # just move chunk_total to cuda_shard;
            # the communication is not necessary
            self.__scatter()
        elif self.keep_gathered:
            # we use all-reduce here
            dist.all_reduce(self.chunk_total, group=self.torch_pg)
        else:
            self.cuda_shard = torch.empty(self.shard_size, dtype=self.dtype, device=get_current_device())

            input_list = list(torch.chunk(self.chunk_total, chunks=self.pg_size, dim=0))
            dist.reduce_scatter(self.cuda_shard, input_list, group=self.torch_pg)

            free_storage(self.chunk_total)
            self.is_gathered = False
        self.__update_tensors_state(TensorState.HOLD)
        self.grad_reduced_flag = True
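    # For example, with pg_size = 4 and chunk_size = 1024, the reduce-scatter branch above
    # leaves a 256-element cuda_shard on each rank: rank r owns the reduced slice
    # [shard_begin, shard_end) = [256 * r, 256 * (r + 1)) of the gradients.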

    def tensor_trans_state(self, tensor: torch.Tensor, tensor_state: TensorState) -> None:
        """
        Make a transition of the tensor into the next state.

        Args:
            tensor (torch.Tensor): a torch Tensor object.
            tensor_state (TensorState): the target state for the transition.
        """
        # As the gradient hook can be triggered either before or after post-backward,
        # a tensor's state can go compute -> hold_after_bwd -> ready_for_reduce
        # or compute -> ready_for_reduce -> hold_after_bwd.
        # The second path is invalid, so we just ignore ready_for_reduce -> hold_after_bwd.
        # This function only applies valid state transitions;
        # invalid calls are ignored and nothing changes.
        if (self.tensors_info[tensor].state, tensor_state) not in STATE_TRANS:
            # print(
            #     f'WARNING: Rank{self.process_group.rank()} apply invalid state trans: {self.tensors_info[tensor].state} to {tensor_state}'
            # )
            return
        self.__update_one_tensor_info(self.tensors_info[tensor], tensor_state)

    def copy_tensor_to_chunk_slice(self, tensor: torch.Tensor, data_slice: torch.Tensor) -> None:
        """
        Copy a data slice to the memory space indexed by the input tensor in the chunk.

        Args:
            tensor (torch.Tensor): the tensor used to retrieve meta information
            data_slice (torch.Tensor): the tensor to be copied to the chunk
        """
        # sanity check
        assert self.is_gathered

        tensor_info = self.tensors_info[tensor]
        self.chunk_total[tensor_info.offset:tensor_info.end].copy_(data_slice.flatten())
        tensor.data = self.chunk_total[tensor_info.offset:tensor_info.end].view(tensor.shape)

    @property
    def can_release(self) -> bool:
        return self.tensors_state_monitor[TensorState.HOLD] == self.num_tensors

    @property
    def can_reduce(self):
        return self.tensors_state_monitor[TensorState.READY_FOR_REDUCE] == self.num_tensors

    def __gather(self):
        if not self.is_gathered:
            # sanity check
            assert self.cuda_shard is not None

            if self.pg_size == 1:
                self.chunk_total = self.cuda_shard
            else:
                alloc_storage(self.chunk_total)
                gather_list = list(torch.chunk(input=self.chunk_total, chunks=self.pg_size, dim=0))
                dist.all_gather(gather_list, self.cuda_shard, self.torch_pg)

            self.cuda_shard = None
            self.is_gathered = True

    def __scatter(self):
        if self.keep_gathered:
            return

        if self.is_gathered:
            # sanity check
            assert self.cuda_shard is None

            self.cuda_shard = torch.empty(self.shard_size, dtype=self.dtype, device=self.chunk_total.device)
            self.cuda_shard.copy_(self.chunk_total[self.shard_begin:self.shard_end])

            free_storage(self.chunk_total)
            self.is_gathered = False
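    # __gather and __scatter are duals: __gather rebuilds chunk_total and drops cuda_shard,
    # while __scatter copies this rank's slice into cuda_shard and frees chunk_total's storage,
    # so only one of the two holds chunk memory at a time (unless keep_gathered is True,
    # in which case __scatter is a no-op).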

    def __paired_shard_move(self):
        assert self.paired_chunk is not None, "chunks should be paired before training"
        optim_chunk = self.paired_chunk
        assert self.chunk_size == optim_chunk.chunk_size

        # this function is only called when the optimizer state is in CPU memory;
        # the grad and param should be on the same device
        assert self.cuda_shard is None
        temp = optim_chunk.cpu_shard.to(get_current_device())
        # cast on CUDA, to avoid doing the FP32 conversion on the CPU
        self.cuda_shard = temp.to(self.dtype)

        if not self.pin_memory:
            self.cpu_shard = None

    def __update_tensors_ptr(self) -> None:
        # sanity check
        assert self.is_gathered
        assert type(self.chunk_total) == torch.Tensor

        for tensor, tensor_info in self.tensors_info.items():
            tensor.data = self.chunk_total[tensor_info.offset:tensor_info.end].view(tensor.shape)

    def __update_one_tensor_info(self, tensor_info: TensorInfo, next_state: TensorState):
        self.tensors_state_monitor[tensor_info.state] -= 1
        tensor_info.state = next_state
        self.tensors_state_monitor[tensor_info.state] += 1

    def __update_tensors_state(self, next_state: TensorState, prev_state: Optional[TensorState] = None):
        for tensor_info in self.tensors_info.values():
            if prev_state is None or tensor_info.state == prev_state:
                self.__update_one_tensor_info(tensor_info, next_state)
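
Below is a minimal usage sketch of the lifecycle implemented above; it is not part of the commit. It assumes torch.distributed has already been initialized, a GPU is available, and a default ColoProcessGroup spanning all ranks can be constructed with no arguments; the tensor shapes and chunk size are made up for illustration.

import torch
from colossalai.tensor import ProcessGroup as ColoProcessGroup
from colossalai.utils import get_current_device
from colossalai.gemini.ag_chunk import AgChunk

pg = ColoProcessGroup()    # assumption: the default group covers all data-parallel ranks
chunk = AgChunk(chunk_size=1024, process_group=pg, dtype=torch.half, pin_memory=True)

# fill the open chunk; each parameter becomes a view into chunk_temp
params = [torch.randn(256, dtype=torch.half, device=get_current_device()) for _ in range(3)]
for p in params:
    chunk.append_tensor(p)

chunk.close_chunk(shard_dev=torch.device('cpu'))    # shard the chunk, offloading the shard to CPU
chunk.shard_move(get_current_device())              # bring this rank's shard back to CUDA
chunk.access_chunk()                                # all-gather so the parameters are usable again
chunk.release_chunk()                               # scatter back into per-rank shards
print(chunk.memory_usage)                           # e.g. {'cuda': ..., 'cpu': ...}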