OpenDAS / ColossalAI · Commits

Commit 08f2920e, authored Apr 23, 2023 by zhuwenwen

    init colossalai, support dtk2304

Parent: da3f0934 · Pipeline #237 failed in 0 seconds · Changes: 380 · Pipelines: 1

Showing 20 changed files with 2325 additions and 0 deletions (+2325 −0):
- colossalai/gemini/chunk/chunk.py (+576 −0)
- colossalai/gemini/chunk/manager.py (+239 −0)
- colossalai/gemini/chunk/search_utils.py (+140 −0)
- colossalai/gemini/chunk/utils.py (+59 −0)
- colossalai/gemini/gemini_context.py (+48 −0)
- colossalai/gemini/gemini_mgr.py (+156 −0)
- colossalai/gemini/memory_tracer/__init__.py (+11 −0)
- colossalai/gemini/memory_tracer/chunk_memstats_collector.py (+36 −0)
- colossalai/gemini/memory_tracer/memory_monitor.py (+147 −0)
- colossalai/gemini/memory_tracer/memory_stats.py (+135 −0)
- colossalai/gemini/memory_tracer/memstats_collector.py (+104 −0)
- colossalai/gemini/memory_tracer/param_runtime_order.py (+42 −0)
- colossalai/gemini/memory_tracer/runtime_mem_tracer.py (+99 −0)
- colossalai/gemini/memory_tracer/static_memstats_collector.py (+105 −0)
- colossalai/gemini/memory_tracer/utils.py (+59 −0)
- colossalai/gemini/ophooks/__init__.py (+3 −0)
- colossalai/gemini/ophooks/_shard_grad_ophook.py (+32 −0)
- colossalai/gemini/ophooks/_shard_param_ophook.py (+47 −0)
- colossalai/gemini/ophooks/runtime_mem_tracer_hook.py (+145 −0)
- colossalai/gemini/ophooks/utils.py (+142 −0)
colossalai/gemini/chunk/chunk.py (new file, mode 100644)

```python
from dataclasses import dataclass
from enum import Enum
from typing import Dict, List, Optional

import torch
import torch.distributed as dist

from colossalai.tensor import ProcessGroup as ColoProcessGroup
from colossalai.utils import get_current_device


class TensorState(Enum):
    FREE = 0
    COMPUTE = 1
    HOLD = 2
    HOLD_AFTER_BWD = 3
    READY_FOR_REDUCE = 4


STATE_TRANS = ((TensorState.FREE, TensorState.HOLD), (TensorState.FREE, TensorState.COMPUTE),
               (TensorState.HOLD, TensorState.FREE), (TensorState.HOLD, TensorState.COMPUTE),
               (TensorState.COMPUTE, TensorState.HOLD), (TensorState.COMPUTE, TensorState.HOLD_AFTER_BWD),
               (TensorState.COMPUTE, TensorState.READY_FOR_REDUCE), (TensorState.HOLD_AFTER_BWD, TensorState.COMPUTE),
               (TensorState.HOLD_AFTER_BWD, TensorState.READY_FOR_REDUCE),
               (TensorState.READY_FOR_REDUCE, TensorState.HOLD))


@dataclass
class TensorInfo:
    state: TensorState
    offset: int
    end: int


class ChunkFullError(Exception):
    pass


def is_storage_empty(tensor: torch.Tensor) -> bool:
    return tensor.storage().size() == 0


def free_storage(tensor: torch.Tensor) -> None:
    if not is_storage_empty(tensor):
        tensor.storage().resize_(0)


def alloc_storage(tensor: torch.Tensor) -> None:
    if is_storage_empty(tensor):
        tensor.storage().resize_(tensor.numel())


class Chunk:
    _total_number = 0

    def __init__(self,
                 chunk_size: int,
                 process_group: ColoProcessGroup,
                 dtype: torch.dtype,
                 init_device: Optional[torch.device] = None,
                 cpu_shard_init: bool = False,
                 keep_gathered: bool = False,
                 pin_memory: bool = False) -> None:
        """
        Chunk: A container owning a piece of contiguous memory space for tensors.
        Here we use an all-gather operation to gather the whole chunk.
        Currently, Chunk is exclusively used for DDP and ZeRO DDP, and it doesn't support unused parameters.
        It is designed to make full use of communication and PCIe bandwidth.

        Args:
            chunk_size (int): the number of elements in the chunk
            process_group (ColoProcessGroup): the process group of this chunk
            dtype (torch.dtype): the data type of the chunk
            init_device (torch.device): optional, the device where the tensor is stored during chunk construction.
                Defaults to None, which means the current GPU.
            cpu_shard_init (bool): a flag indicating whether the local chunk shard resides on CPU.
            keep_gathered (bool): optional, if True, this chunk is always gathered in CUDA memory
            pin_memory (bool): optional, if True, this chunk always keeps a shard copy in pinned CPU memory
        """
        self.count_id = Chunk._total_number
        Chunk._total_number += 1

        self.chunk_size = chunk_size
        self.utilized_size = 0

        self.torch_pg = process_group.dp_process_group()
        self.pg_size = dist.get_world_size(self.torch_pg)
        self.pg_rank = dist.get_rank(self.torch_pg)

        # the chunk size should be divisible by the dp degree
        if not keep_gathered:
            assert chunk_size % self.pg_size == 0
        self.shard_size = chunk_size // self.pg_size
        self.shard_begin = self.shard_size * self.pg_rank
        self.shard_end = self.shard_begin + self.shard_size
        self.valid_end = self.shard_size

        self.dtype = dtype
        device = init_device or get_current_device()

        # chunk_temp is a global chunk, which only exists during building the chunks.
        self.chunk_temp = torch.zeros(chunk_size, dtype=dtype, device=device)    # keep all zero

        self.cuda_global_chunk = None    # we force cuda_global_chunk to be located in CUDA

        # cuda local chunk, which is sharded on GPUs
        self.cuda_shard = None
        # cpu local chunk, which is sharded on CPUs
        self.cpu_shard = None
        # whether the chunk is gathered, i.e. duplicated on each process;
        # if so, cuda_global_chunk should be used.
        self.is_gathered = True

        # configure the init device of the shard
        # no-offload default: fp16, fp32 -> CUDA
        # offload default: fp16, fp32 -> CPU
        self.shard_device = torch.device("cpu") if cpu_shard_init else get_current_device()

        self.chunk_mem = self.chunk_size * self.chunk_temp.element_size()
        self.shard_mem = self.chunk_mem // self.pg_size

        # each tensor is associated with a TensorInfo to track its meta info
        # (state, offset, end)
        self.tensors_info: Dict[torch.Tensor, TensorInfo] = {}
        # the total number of tensors in the chunk
        self.num_tensors = 0
        # record the number of tensors in different states
        self.tensor_state_cnter: Dict[TensorState, int] = dict()
        for state in TensorState:
            self.tensor_state_cnter[state] = 0

        # If a chunk is kept gathered, its tensors are treated
        # the same as the parameters in DDP during training.
        self.keep_gathered = keep_gathered
        if self.keep_gathered:
            pin_memory = False    # since this chunk is gathered, it doesn't need to pin

        # if pin_memory is True, we allocate a piece of CPU pinned memory
        # for it all the time
        self.pin_memory = pin_memory

        # we introduce the paired chunk here:
        # it refers to another chunk holding the same parameters
        # but with a different dtype (such as fp16_chunk.paired_chunk -> fp32_chunk)
        self.paired_chunk = None
        # if this chunk is synchronized with the optimizer, the flag is True
        self.optim_sync_flag = True
        # if the cpu_shard has been visited during the training step, the flag is True
        self.cpu_vis_flag = False

        # whether to record the l2 norm for the gradient clipping calculation
        self.l2_norm_flag = False
        self.l2_norm = None

    @property
    def memory_usage(self) -> Dict[str, int]:
        cuda_memory = 0
        cpu_memory = 0

        if self.chunk_temp is not None:
            # this chunk is not closed
            if self.chunk_temp.device.type == 'cuda':
                cuda_memory += self.chunk_mem
            else:
                cpu_memory += self.chunk_mem
        else:
            if self.is_gathered:
                cuda_memory += self.chunk_mem
            if self.cuda_shard is not None:
                cuda_memory += self.shard_mem
            if self.cpu_shard is not None:
                cpu_memory += self.shard_mem

        return dict(cuda=cuda_memory, cpu=cpu_memory)

    @property
    def device_type(self) -> str:
        if self.chunk_temp is not None:
            return self.chunk_temp.device.type
        else:
            if self.is_gathered:
                return 'cuda'
            elif self.cuda_shard is not None:
                return 'cuda'
            else:
                return 'cpu'

    @property
    def payload(self) -> torch.Tensor:
        # sanity check
        assert self.chunk_temp is None

        if self.is_gathered:
            return self.cuda_global_chunk
        elif self.cuda_shard is not None:
            return self.cuda_shard
        else:
            return self.cpu_shard

    @property
    def payload_mem(self) -> int:
        # sanity check
        assert self.chunk_temp is None

        if self.is_gathered:
            return self.chunk_mem
        else:
            return self.shard_mem

    @property
    def can_move(self) -> bool:
        return not self.is_gathered

    @property
    def can_release(self) -> bool:
        if self.keep_gathered:
            return False
        else:
            return self.tensor_state_cnter[TensorState.HOLD] + \
                self.tensor_state_cnter[TensorState.HOLD_AFTER_BWD] == self.num_tensors

    @property
    def can_reduce(self):
        return self.tensor_state_cnter[TensorState.READY_FOR_REDUCE] == self.num_tensors

    @property
    def has_inf_or_nan(self) -> bool:
        """Check if the chunk has inf or nan values on CUDA.
        """
        if self.is_gathered:
            valid_tensor = self.cuda_global_chunk[:self.utilized_size]
        else:
            assert self.cuda_shard is not None    # only check on CUDA
            valid_tensor = self.cuda_shard[:self.valid_end]

        return torch.isinf(valid_tensor).any().item() | torch.isnan(valid_tensor).any().item()

    def set_l2_norm(self) -> None:
        """Record the l2 norm of this chunk on CUDA.
        """
        assert self.l2_norm is None, "you are calculating the l2 norm twice"
        if self.is_gathered:
            valid_tensor = self.cuda_global_chunk[:self.utilized_size]
        else:
            assert self.cuda_shard is not None    # calculate on CUDA
            valid_tensor = self.cuda_shard[:self.valid_end]
        chunk_l2_norm = valid_tensor.data.float().norm(2)
        self.l2_norm = chunk_l2_norm.item()**2

    def append_tensor(self, tensor: torch.Tensor):
        """Add a tensor to the chunk.

        Args:
            tensor (torch.Tensor): a tensor to be added to the chunk
        """
        # sanity check
        assert self.chunk_temp is not None
        assert tensor.dtype == self.dtype

        new_utilized_size = self.utilized_size + tensor.numel()
        # raise an exception when the chunk size is exceeded
        if new_utilized_size > self.chunk_size:
            raise ChunkFullError

        self.chunk_temp[self.utilized_size:new_utilized_size].copy_(tensor.data.flatten())
        assert type(self.chunk_temp) == torch.Tensor, "copy_tensor_to_chunk_slice must use a torch tensor"
        tensor.data = self.chunk_temp[self.utilized_size:new_utilized_size].view(tensor.shape)

        # record all the information about the tensor
        self.num_tensors += 1
        tensor_state = TensorState.HOLD
        self.tensors_info[tensor] = TensorInfo(tensor_state, self.utilized_size, new_utilized_size)
        self.tensor_state_cnter[tensor_state] += 1
        self.utilized_size = new_utilized_size

    def close_chunk(self):
        """Close the chunk. No tensor can be appended to a closed chunk.
        """
        # sanity check
        assert self.chunk_temp is not None

        # calculate the valid end for each shard
        if self.utilized_size <= self.shard_begin:
            self.valid_end = 0
        elif self.utilized_size < self.shard_end:
            self.valid_end = self.utilized_size - self.shard_begin

        if self.chunk_temp.device.type == 'cpu':
            self.cuda_global_chunk = self.chunk_temp.to(get_current_device())
            self.__update_tensors_ptr()
        else:
            self.cuda_global_chunk = self.chunk_temp
        self.chunk_temp = None

        self.__scatter()
        # a gathered chunk never has a shard attribute
        if self.keep_gathered:
            return

        if self.pin_memory or self.shard_device.type == 'cpu':
            self.cpu_shard = torch.empty(self.shard_size, dtype=self.dtype, pin_memory=self.pin_memory)
            self.cpu_shard.copy_(self.cuda_shard)
            self.cpu_vis_flag = True    # cpu_shard has been visited

        if self.shard_device.type == 'cpu':
            self.cuda_shard = None

    def shard_move(self, device: torch.device, force_copy: bool = False):
        """Move the shard tensor in the chunk.

        Args:
            device: the device to which the shard will move
            force_copy: if True, the copy is performed unconditionally
        """
        # sanity check
        assert not self.is_gathered
        # when the current chunk is not synchronized with the optimizer,
        # just use another way for the movement
        if not self.optim_sync_flag:
            assert device.type == 'cuda', "each chunk should first be moved to CUDA"
            self.__paired_shard_move()
            self.optim_sync_flag = True
            return

        if device.type == 'cuda':
            assert device == get_current_device(), "can't move chunk to another device"

            if self.cuda_shard:
                return

            self.cuda_shard = self.cpu_shard.to(get_current_device())

            if not self.pin_memory:
                self.cpu_shard = None
        elif device.type == 'cpu':
            if self.cuda_shard is None:
                return

            if self.pin_memory:
                if force_copy or not self.cpu_vis_flag:
                    self.cpu_shard.copy_(self.cuda_shard)
                # if cpu_shard has already been visited, no copy is needed
            else:
                self.cpu_shard = self.cuda_shard.cpu()
            self.cpu_vis_flag = True
            self.cuda_shard = None
        else:
            raise NotImplementedError

    def access_chunk(self):
        """Make the chunk usable for the parameters inside it. It's an operation done in CUDA.
        """
        # sanity check
        assert self.chunk_temp is None

        if not self.is_gathered:
            self.__gather()
        self.__update_tensors_ptr()

    def release_chunk(self):
        """Release the usable chunk. It's an operation done in CUDA.
        """
        # sanity check
        assert self.chunk_temp is None

        if self.is_gathered:
            self.__scatter()

    def reduce(self):
        """Reduce-scatter all the gradients. It's an operation done in CUDA.
        """
        # sanity check
        assert self.is_gathered

        if self.pg_size == 1:
            # tricky code here:
            # just move cuda_global_chunk to cuda_shard;
            # the communication is not necessary
            self.__scatter()
        elif self.keep_gathered:
            # we use all-reduce here
            dist.all_reduce(self.cuda_global_chunk, group=self.torch_pg)
        else:
            self.cuda_shard = torch.empty(self.shard_size, dtype=self.dtype, device=get_current_device())

            input_list = list(torch.chunk(self.cuda_global_chunk, chunks=self.pg_size, dim=0))
            dist.reduce_scatter(self.cuda_shard, input_list, group=self.torch_pg)

            free_storage(self.cuda_global_chunk)
            self.is_gathered = False
        self.__update_tensors_state(TensorState.HOLD)

    def tensor_trans_state(self, tensor: torch.Tensor, tensor_state: TensorState) -> None:
        """
        Make a transition of the tensor into the next state.

        Args:
            tensor (torch.Tensor): a torch Tensor object.
            tensor_state (TensorState): the target state for transition.
        """
        # As the gradient hook can be triggered either before or after post-backward,
        # a tensor's state can go compute -> hold_after_bwd -> ready_for_reduce,
        # or compute -> ready_for_reduce -> hold_after_bwd.
        # The second one is invalid, so we just ignore ready_for_reduce -> hold_after_bwd.
        # This function only applies valid state transitions;
        # invalid calls are ignored and nothing changes.
        if (self.tensors_info[tensor].state, tensor_state) not in STATE_TRANS:
            return
        self.__update_one_tensor_info(self.tensors_info[tensor], tensor_state)

    def copy_tensor_to_chunk_slice(self, tensor: torch.Tensor, data_slice: torch.Tensor) -> None:
        """
        Copy a data slice to the memory space indexed by the input tensor in the chunk.

        Args:
            tensor (torch.Tensor): the tensor used to retrieve meta information
            data_slice (torch.Tensor): the tensor to be copied to the chunk
        """
        # sanity check
        assert self.is_gathered

        tensor_info = self.tensors_info[tensor]
        self.cuda_global_chunk[tensor_info.offset:tensor_info.end].copy_(data_slice.data.flatten())
        tensor.data = self.cuda_global_chunk[tensor_info.offset:tensor_info.end].view(tensor.shape)

    def get_valid_length(self) -> int:
        """Get the valid length of the chunk's payload.
        """
        if self.keep_gathered:
            return self.utilized_size
        else:
            return self.valid_end

    def init_pair(self, friend_chunk: 'Chunk') -> None:
        """Initialize the paired chunk.
        """
        if self.paired_chunk is None and friend_chunk.paired_chunk is None:
            self.paired_chunk = friend_chunk
            friend_chunk.paired_chunk = self
        else:
            assert self.paired_chunk is friend_chunk
            assert friend_chunk.paired_chunk is self

    def optim_update(self) -> None:
        """Update the fp16 chunks via their fp32 chunks. It's used by the optimizer.
        """
        # sanity check
        assert self.paired_chunk is not None

        friend_chunk = self.paired_chunk
        if self.is_gathered is True:
            assert friend_chunk.is_gathered is True
            self.cuda_global_chunk.copy_(friend_chunk.cuda_global_chunk)
            self.optim_sync_flag = True
        elif friend_chunk.device_type == 'cuda' and self.device_type == 'cuda':
            self.cuda_shard.copy_(friend_chunk.cuda_shard)
            self.optim_sync_flag = True
            self.cpu_vis_flag = False
        else:
            # optim_sync_flag is set to False;
            # see the shard_move function for more details
            assert friend_chunk.device_type == 'cpu'
            assert self.device_type == 'cpu'
            self.optim_sync_flag = False
            self.cpu_vis_flag = False

    def get_tensors(self) -> List[torch.Tensor]:
        return list(self.tensors_info.keys())

    def __gather(self):
        if not self.is_gathered:
            # sanity check
            assert self.cuda_shard is not None

            alloc_storage(self.cuda_global_chunk)
            gather_list = list(torch.chunk(input=self.cuda_global_chunk, chunks=self.pg_size, dim=0))
            dist.all_gather(gather_list, self.cuda_shard, self.torch_pg)

            self.cuda_shard = None
            self.is_gathered = True

    def __scatter(self):
        if self.keep_gathered:
            return

        if self.is_gathered:
            # sanity check
            assert self.cuda_shard is None

            self.cuda_shard = torch.empty(self.shard_size, dtype=self.dtype, device=self.cuda_global_chunk.device)

            self.cuda_shard.copy_(self.cuda_global_chunk[self.shard_begin:self.shard_end])

            free_storage(self.cuda_global_chunk)
            self.is_gathered = False

    def __paired_shard_move(self):
        assert self.paired_chunk is not None, "chunks should be paired before training"
        optim_chunk = self.paired_chunk
        assert self.chunk_size == optim_chunk.chunk_size

        # only called when the optimizer state is in CPU memory;
        # the grad and param should be on the same device
        assert self.cuda_shard is None
        temp = optim_chunk.cpu_shard.to(get_current_device())
        # avoid converting FP32 on CPU
        self.cuda_shard = temp.to(self.dtype)

        if not self.pin_memory:
            self.cpu_shard = None

    def __update_tensors_ptr(self) -> None:
        # sanity check
        assert self.is_gathered
        assert type(self.cuda_global_chunk) == torch.Tensor

        for tensor, tensor_info in self.tensors_info.items():
            tensor.data = self.cuda_global_chunk[tensor_info.offset:tensor_info.end].view(tensor.shape)

    def __update_one_tensor_info(self, tensor_info: TensorInfo, next_state: TensorState):
        self.tensor_state_cnter[tensor_info.state] -= 1
        tensor_info.state = next_state
        self.tensor_state_cnter[tensor_info.state] += 1

    def __update_tensors_state(self, next_state: TensorState, prev_state: Optional[TensorState] = None):
        for tensor_info in self.tensors_info.values():
            if prev_state is None or tensor_info.state == prev_state:
                self.__update_one_tensor_info(tensor_info, next_state)

    def __hash__(self) -> int:
        return hash(id(self))

    def __eq__(self, __o: object) -> bool:
        return self is __o

    def __repr__(self, detailed: bool = True):
        output = [
            "Chunk Information:\n",
            "\tchunk size: {}, chunk dtype: {}, process group size: {}\n".format(self.chunk_size, self.dtype,
                                                                                 self.pg_size),
            "\t# of tensors: {}, utilized size: {}, utilized percentage: {:.2f}\n".format(
                self.num_tensors, self.utilized_size, self.utilized_size / self.chunk_size)
        ]

        def print_tensor(tensor, prefix=''):
            output.append("{}shape: {}, dtype: {}, device: {}\n".format(prefix, tensor.shape, tensor.dtype,
                                                                        tensor.device))

        if self.chunk_temp is not None:
            output.append("\tchunk temp:\n")
            print_tensor(tensor=self.chunk_temp, prefix='\t\t')

        if self.cuda_global_chunk is not None and self.cuda_global_chunk.storage().size() > 0:
            output.append("\tchunk total:\n")
            print_tensor(tensor=self.cuda_global_chunk, prefix='\t\t')

        if self.cuda_shard is not None:
            output.append("\tcuda shard:\n")
            print_tensor(tensor=self.cuda_shard, prefix='\t\t')

        if self.cpu_shard is not None:
            output.append("\tcpu shard:\n")
            print_tensor(tensor=self.cpu_shard, prefix='\t\t')

        memory_info = self.memory_usage
        output.append("\tmemory usage: cuda {}, cpu {}\n".format(memory_info['cuda'], memory_info['cpu']))

        if detailed:
            output.append("\ttensor state monitor:\n")
            for st in TensorState:
                output.append("\t\t# of {}: {}\n".format(st, self.tensor_state_cnter[st]))

        return ''.join(output)
```
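The free_storage/alloc_storage helpers above are what let a released chunk keep its tensor views while occupying no memory: resizing a storage to zero frees the bytes but preserves the tensor metadata, and resizing it back re-allocates. A minimal standalone check of that behavior (plain PyTorch, runnable on CPU):

```python
import torch

from colossalai.gemini.chunk.chunk import alloc_storage, free_storage, is_storage_empty

t = torch.zeros(1024)
free_storage(t)    # storage resized to 0 bytes; tensor metadata survives
assert is_storage_empty(t)
alloc_storage(t)    # storage re-allocated to numel() elements (contents undefined)
assert t.storage().size() == 1024
```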
colossalai/gemini/chunk/manager.py (new file, mode 100644)

```python
from collections import deque
from typing import Deque, Dict, Iterable, List, Optional, Set, Tuple

import torch

from colossalai.gemini.chunk import Chunk, ChunkFullError, TensorState
from colossalai.tensor import ColoTensor
from colossalai.utils import get_current_device


class ChunkManager:
    """
    A manager class to manipulate the tensors in chunks.

    Args:
        chunk_configuration (Dict[int, Dict]): the configuration dictionary of this chunk manager.
        init_device (torch.device): optional, the device on which the chunk is initialized. The default is None.
    """

    def __init__(self, chunk_configuration, init_device: Optional[torch.device] = None) -> None:

        self.device = init_device or get_current_device()
        self.dp_degree_chunk_size_dict: Dict[int, int] = dict()
        self.kwargs_config = chunk_configuration
        for k, v in self.kwargs_config.items():
            self.dp_degree_chunk_size_dict[k] = v.pop('chunk_size')
            v['init_device'] = self.device

        self.chunk_groups: Dict[str, Deque] = dict()
        self.tensor_chunk_map: Dict[torch.Tensor, Chunk] = dict()
        self.accessed_chunks: Set[Chunk] = set()
        self.accessed_mem: int = 0
        self.total_mem: Dict[str, int] = {'cpu': 0, 'cuda': 0}

    def register_tensor(self,
                        tensor: ColoTensor,
                        group_type: str,
                        config_key: int,
                        cpu_offload: bool = False,
                        pin_memory: bool = False) -> None:
        """
        Register a tensor to the chunk manager.
        Afterwards, the tensor should be accessed via `get_chunks`.

        Args:
            tensor: the tensor appended to the chunk
            group_type: the data type of the group
            config_key: the key of the group's name, i.e. the size of the dp world
            cpu_offload: if True, the chunk will be closed on CPU
            pin_memory: whether the chunk is pinned in CPU memory
        """
        assert tensor not in self.tensor_chunk_map
        assert isinstance(tensor, ColoTensor), "Please feed ColoTensor to this ChunkManager"
        assert config_key in self.dp_degree_chunk_size_dict

        chunk_size = self.dp_degree_chunk_size_dict[config_key]
        chunk_kwargs = self.kwargs_config[config_key]
        group_name = "{}_{}".format(group_type, config_key)
        chunk_group = self.__get_chunk_group(group_name)

        try:
            # append the tensor to the last chunk
            chunk_group[-1].append_tensor(tensor)
        except (IndexError, ChunkFullError):
            # the except branch is triggered when there is no chunk yet or
            # the last chunk in the chunk group is full;
            # this creates a new chunk and allocates it to its corresponding process
            if chunk_group:
                # the chunk group is not empty:
                # close the last chunk
                self.__close_one_chunk(chunk_group[-1])

            if tensor.numel() > chunk_size:
                chunk_size = tensor.numel()
            chunk = Chunk(
                chunk_size=chunk_size,
                process_group=tensor.process_group,
                dtype=tensor.dtype,
                cpu_shard_init=cpu_offload,
                pin_memory=pin_memory,
                **chunk_kwargs,
            )

            chunk_group.append(chunk)
            chunk.append_tensor(tensor)
            self.__add_memory_usage(chunk.memory_usage)

        self.tensor_chunk_map[tensor] = chunk_group[-1]

    def close_all_groups(self):
        """Close the chunks of all groups.
        """
        for group_name in self.chunk_groups:
            self.__close_one_chunk(self.chunk_groups[group_name][-1])

    def access_chunk(self, chunk: Chunk) -> None:
        """Make the chunk usable for calculation.
        """
        if chunk in self.accessed_chunks:
            return
        self.__sub_memory_usage(chunk.memory_usage)
        if chunk.device_type == 'cpu':
            chunk.shard_move(get_current_device())
        self.__add_accessed_chunk(chunk)
        self.__add_memory_usage(chunk.memory_usage)

    def release_chunk(self, chunk: Chunk) -> None:
        """Scatter the chunk in CUDA.
        """
        if chunk not in self.accessed_chunks:
            return
        if chunk.can_release:
            self.__sub_memory_usage(chunk.memory_usage)
            self.__sub_accessed_chunk(chunk)
            self.__add_memory_usage(chunk.memory_usage)

    def move_chunk(self, chunk: Chunk, device: torch.device, force_copy: bool = False) -> None:
        """Move the shard of the chunk to the target device.
        """
        if not chunk.can_move or chunk.device_type == device.type:
            return
        self.__sub_memory_usage(chunk.memory_usage)
        chunk.shard_move(device, force_copy)
        self.__add_memory_usage(chunk.memory_usage)

    def trans_tensor_state(self, tensor: torch.Tensor, state: TensorState) -> None:
        """Transit the tensor state according to the pre-defined state machine.
        """
        chunk = self.tensor_chunk_map[tensor]
        chunk.tensor_trans_state(tensor, state)

    def reduce_chunk(self, chunk: Chunk) -> bool:
        """Reduce or all-reduce the chunk.
        """
        if not chunk.can_reduce:
            return False
        self.__sub_memory_usage(chunk.memory_usage)
        chunk.reduce()
        self.__sub_accessed_chunk(chunk)
        self.__add_memory_usage(chunk.memory_usage)
        return True

    def copy_tensor_to_chunk_slice(self, tensor: torch.Tensor, data: torch.Tensor) -> None:
        """
        Copy data to the chunk.

        Args:
            tensor (torch.Tensor): the tensor used to retrieve meta information
            data (torch.Tensor): the tensor to be copied to the chunk
        """
        chunk = self.tensor_chunk_map[tensor]
        chunk.copy_tensor_to_chunk_slice(tensor, data)

    def get_chunk(self, tensor: torch.Tensor) -> Chunk:
        """
        Return the chunk owning the tensor.

        Args:
            tensor (torch.Tensor): a torch tensor object
        """
        return self.tensor_chunk_map[tensor]

    def get_cuda_movable_chunks(self) -> List[Chunk]:
        """
        Get all chunks that can be moved.
        """
        chunk_list = []
        for chunk in self.accessed_chunks:
            if chunk.can_release:
                chunk_list.append(chunk)
        chunk_list.sort(key=lambda x: x.count_id)
        return chunk_list

    def get_chunks(self, tensors: Iterable[torch.Tensor]) -> Tuple[Chunk, ...]:
        """
        Get all chunks owning the input tensors.

        Args:
            tensors (Iterable[torch.Tensor]): the tensors used to look for chunks
        """
        chunks = []
        for tensor in tensors:
            chunk = self.get_chunk(tensor)
            if chunk not in chunks:
                chunks.append(chunk)
        return tuple(chunks)

    def add_extern_static_tensor(self, tensor: torch.Tensor) -> None:
        """Add an extern static tensor to the chunk manager.
        Those tensors won't be managed by the chunk manager, but we still want to monitor their memory usage.
        They are "static", which means their shape, dtype and device never change.
        Thus, their memory usage never changes.

        Args:
            tensor (torch.Tensor): An extern static tensor, e.g. optimizer state.
        """
        assert tensor not in self.tensor_chunk_map
        self.total_mem[tensor.device.type] += tensor.numel() * tensor.element_size()

    def __repr__(self) -> str:
        msg = [
            'Chunk Manager Information:\n',
            'Total memory: ' + ', '.join([f'{k}={v}B' for k, v in self.total_mem.items()]) + '\n'
        ]
        for group_name, group in self.chunk_groups.items():
            msg.append(f'Group {group_name}:\n')
            for i, chunk in enumerate(group):
                msg.append(f'[{i}] {chunk}\n')
        return ''.join(msg)

    def __get_chunk_group(self, group_name: str) -> Deque:
        """Register a chunk group.
        """
        if group_name not in self.chunk_groups:
            self.chunk_groups[group_name] = deque()
        return self.chunk_groups[group_name]

    def __close_one_chunk(self, chunk: Chunk):
        self.__sub_memory_usage(chunk.memory_usage)
        chunk.close_chunk()
        self.__add_memory_usage(chunk.memory_usage)

    def __sub_memory_usage(self, usage: Dict[str, int]):
        for k, v in usage.items():
            self.total_mem[k] -= v

    def __add_memory_usage(self, usage: Dict[str, int]):
        for k, v in usage.items():
            self.total_mem[k] += v

    def __add_accessed_chunk(self, chunk: Chunk):
        chunk.access_chunk()
        self.accessed_chunks.add(chunk)
        self.accessed_mem += chunk.chunk_mem

    def __sub_accessed_chunk(self, chunk: Chunk):
        chunk.release_chunk()
        self.accessed_chunks.remove(chunk)
        self.accessed_mem -= chunk.chunk_mem
```
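For reference, the `chunk_configuration` argument consumed by `__init__` has the shape produced by `search_chunk_configuration` (next file): a mapping from dp degree to `Chunk` keyword arguments, with `'chunk_size'` popped out by the manager. A hypothetical single-group config (values illustrative, not from the commit):

```python
# dp_degree -> {'chunk_size': ..., **kwargs forwarded to Chunk}
chunk_config = {
    1: dict(chunk_size=32 * 1024**2, keep_gathered=False),
}
# ChunkManager(chunk_config) additionally stamps 'init_device' into each entry;
# instantiating it is only meaningful once torch.distributed is initialized.
```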
colossalai/gemini/chunk/search_utils.py (new file, mode 100644)

```python
import math
from typing import Dict, List, Optional, Tuple

import numpy as np
import torch.nn as nn

from colossalai.gemini.memory_tracer import MemStats, OrderedParamGenerator
from colossalai.tensor import ColoParameter


def in_ddp(param: nn.Parameter) -> bool:
    return not getattr(param, '_ddp_to_ignore', False)


def _filter_exlarge_params(model: nn.Module, size_dict: Dict[int, List[int]]) -> None:
    """
    Filter out those parameters whose size is too large (more than 3 standard deviations above the mean).
    """
    params_size = [p.numel() for p in model.parameters() if in_ddp(p)]
    params_size_arr = np.array(params_size)

    std = np.std(params_size_arr)
    mean = np.mean(params_size_arr)
    upper_limit = mean + 3 * std

    for key in size_dict:
        org_list = size_dict[key]
        size_dict[key] = list(filter(lambda x: x <= upper_limit, org_list))


def _get_unused_byte(size_list: List[int], chunk_size: int) -> int:
    """Get the number of unused bytes for a certain chunk size.
    """
    acc = 0
    left = 0
    for s in size_list:
        if s > left:
            acc += left
            left = chunk_size
        left -= s
    return left + acc


def classify_params_by_dp_degree(param_order: OrderedParamGenerator) -> Dict[int, List[ColoParameter]]:
    """classify_params_by_dp_degree

    Classify the parameters by their dp degree.

    Args:
        param_order (OrderedParamGenerator): the order in which the parameters are visited

    Returns:
        Dict[int, List[ColoParameter]]: a dict containing the classification results.
        The keys are dp degrees and the values are parameters.
    """
    params_dict: Dict[int, List[ColoParameter]] = dict()
    for param in param_order.generate():
        assert isinstance(param, ColoParameter), "please init model in the ColoInitContext"
        if not in_ddp(param):
            continue

        param_key = param.process_group.dp_world_size()

        if param_key not in params_dict:
            params_dict[param_key] = []
        params_dict[param_key].append(param)

    return params_dict


def search_chunk_configuration(
        model: nn.Module,
        search_range_mb: float,
        search_interval_byte: int,    # the hidden size is the best value for the interval
        min_chunk_size_mb: float = 32,
        filter_exlarge_params: bool = True,
        memstas: Optional[MemStats] = None) -> Tuple[Dict, int]:
    """search_chunk_configuration

    Args:
        model (nn.Module): torch module
        search_range_mb (float): search range in megabytes.
        search_interval_byte (int): search interval in bytes.
        filter_exlarge_params (bool, optional): filter extremely large parameters. Defaults to True.

    Returns:
        Tuple[Dict, int]: chunk config (a dict of dp_degree -> chunk init args) and its chunk memory waste in bytes.
    """

    if memstas is not None:
        param_order = memstas.param_order()
    else:
        # build the param visit order right now
        param_order = OrderedParamGenerator()
        for p in model.parameters():
            param_order.append(p)

    search_range_byte = round(search_range_mb * 1024**2)
    min_chunk_size_byte = round(min_chunk_size_mb * 1024**2)
    assert search_range_byte >= 0

    params_dict = classify_params_by_dp_degree(param_order)
    config_dict: Dict[int, Dict] = dict()

    size_dict: Dict[int, List[int]] = dict()
    for dp_degree in params_dict:
        params_list = params_dict[dp_degree]
        # NOTE: size_list holds element counts (numel), which are compared
        # against the *_byte thresholds directly.
        size_list = [p.numel() for p in params_list]
        # let small parameter groups stay gathered in CUDA all the time
        total_size = sum(size_list)
        if total_size < min_chunk_size_byte:
            config_dict[dp_degree] = dict(chunk_size=total_size, keep_gathered=True)
        else:
            size_dict[dp_degree] = size_list

    if filter_exlarge_params:
        _filter_exlarge_params(model, size_dict)

    max_size = min_chunk_size_byte
    for key in size_dict:
        max_size = max(max_size, max(size_dict[key]))
    start_size = int(math.ceil(max_size / search_interval_byte) * search_interval_byte)

    min_chunk_waste = float('+inf')
    best_chunk_size = start_size

    for chunk_size in range(start_size, start_size + search_range_byte + 1, search_interval_byte):
        temp_waste = 0
        for key in size_dict:
            temp_waste += _get_unused_byte(size_dict[key], chunk_size)
        if temp_waste < min_chunk_waste:
            min_chunk_waste = temp_waste
            best_chunk_size = chunk_size

    for dp_degree in params_dict:
        if dp_degree in config_dict:
            continue
        config_dict[dp_degree] = dict(chunk_size=best_chunk_size, keep_gathered=False)

    return config_dict, min_chunk_waste
```
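`_get_unused_byte` greedily packs the size list into chunks of the candidate size and returns the total slack, which is the quantity the search minimizes. A small worked example (pure Python, runnable):

```python
from colossalai.gemini.chunk.search_utils import _get_unused_byte

# Packing sizes 300, 500, 400 into 1024-element chunks: 300 + 500 share the
# first chunk, which has 224 slack when 400 overflows into a second chunk,
# and 400 leaves 624 unused there, so the waste is 224 + 624 = 848.
assert _get_unused_byte([300, 500, 400], 1024) == 848
```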
colossalai/gemini/chunk/utils.py (new file, mode 100644)

```python
from time import time
from typing import Optional

import torch
import torch.distributed as dist
import torch.nn as nn

from colossalai.gemini.chunk import ChunkManager
from colossalai.gemini.chunk.search_utils import in_ddp, search_chunk_configuration
from colossalai.gemini.memory_tracer import MemStats


def init_chunk_manager(model: nn.Module,
                       init_device: Optional[torch.device] = None,
                       hidden_dim: Optional[int] = None,
                       search_range_mb: Optional[float] = None,
                       min_chunk_size_mb: Optional[float] = None,
                       filter_exlarge_params: Optional[bool] = None) -> ChunkManager:

    kwargs_dict = dict()

    if hidden_dim:
        search_interval_byte = hidden_dim
    else:
        search_interval_byte = 1024    # 1kb
    kwargs_dict["search_interval_byte"] = search_interval_byte

    if search_range_mb:
        kwargs_dict["search_range_mb"] = search_range_mb

    if min_chunk_size_mb:
        kwargs_dict["min_chunk_size_mb"] = min_chunk_size_mb

    if filter_exlarge_params:
        kwargs_dict["filter_exlarge_params"] = filter_exlarge_params

    params_sizes = [p.numel() for p in model.parameters() if in_ddp(p)]
    total_size = sum(params_sizes) / 1024**2

    dist.barrier()
    begin = time()

    config_dict, wasted_size = search_chunk_configuration(model, **kwargs_dict)

    dist.barrier()
    end = time()
    span_s = end - begin
    wasted_size /= 1024**2

    if dist.get_rank() == 0:
        print("searching chunk configuration is completed in {:.2f} s.\n".format(span_s),
              "used number: {:.2f} MB, wasted number: {:.2f} MB\n".format(total_size, wasted_size),
              "total wasted percentage is {:.2f}%".format(100 * wasted_size / (total_size + wasted_size)),
              sep='',
              flush=True)
    dist.barrier()

    chunk_manager = ChunkManager(config_dict, init_device)
    return chunk_manager
```
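A hedged sketch of calling `init_chunk_manager`: it must run inside an already-launched distributed program (the function calls `dist.barrier()`), and the model's parameters must be ColoParameters for the configuration search to accept them. The `ColoInitContext` import path is an assumption based on the surrounding codebase, not part of this diff:

```python
import torch

from colossalai.gemini.chunk.utils import init_chunk_manager
from colossalai.utils.model.colo_init_context import ColoInitContext    # assumed import path

# assumes torch.distributed is initialized, e.g. via colossalai.launch
with ColoInitContext(device=torch.device('cuda')):    # params become ColoParameters
    model = torch.nn.Linear(1024, 1024)

chunk_manager = init_chunk_manager(model, hidden_dim=1024, search_range_mb=32.0)
```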
colossalai/gemini/gemini_context.py (new file, mode 100644)

```python
from enum import EnumMeta


class GeminiMemoryManager(object):

    def __init__(self, states_cls: EnumMeta):
        super().__init__()
        self.states_cls = states_cls
        self._cnter = 0    # the counter of instances

        self.total_mem = dict()
        self.state_mem = dict()
        self.state_mem['cpu'] = dict()
        self.state_mem['cuda'] = dict()

        self.reset()

    @property
    def total_number(self):
        return self._cnter

    def reset(self):
        self._cnter = 0    # the counter of instances

        self.total_mem['cpu'] = 0    # memory occupation of instances on cpu
        self.total_mem['cuda'] = 0    # memory occupation of instances on cuda

        # memory conditions for all states
        for state in self.states_cls:
            self.state_mem['cpu'][state] = 0
            self.state_mem['cuda'][state] = 0

    def register_new_instance(self):
        self._cnter += 1

    def delete_instance(self):
        self._cnter -= 1

    def print_info(self):
        print(f"Total number: {self.total_number}",
              f"Total CPU memory occupation: {self.total_mem['cpu']}",
              f"Total CUDA memory occupation: {self.total_mem['cuda']}\n",
              sep='\n')

        for state in self.states_cls:
            print(f"{state}: CPU memory occupation: {self.state_mem['cpu'][state]}",
                  f"{state}: CUDA memory occupation: {self.state_mem['cuda'][state]}\n",
                  sep='\n')
```
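`GeminiMemoryManager` is generic over any state Enum. A minimal runnable sketch with a hypothetical two-state machine (the `ParamState` enum and byte counts are illustrative, not from the commit):

```python
from enum import Enum

from colossalai.gemini.gemini_context import GeminiMemoryManager

class ParamState(Enum):    # hypothetical states for illustration
    HOLD = 0
    COMPUTE = 1

mgr = GeminiMemoryManager(ParamState)
mgr.register_new_instance()
mgr.total_mem['cuda'] += 4 * 1024**2    # account 4 MiB of instance memory
mgr.state_mem['cuda'][ParamState.HOLD] += 4 * 1024**2
mgr.print_info()
```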
colossalai/gemini/gemini_mgr.py (new file, mode 100644)

```python
import functools
from time import time
from typing import List, Optional, Tuple

import torch

from colossalai.gemini.chunk import Chunk, ChunkManager
from colossalai.gemini.memory_tracer import MemStats

from .memory_tracer import ChunkMemStatsCollector
from .placement_policy import PlacementPolicyFactory


class GeminiManager:
    """
    Stateful Tensor Manager, inspired by PatrickStar.

    PatrickStar: Parallel Training of Pre-trained Models via Chunk-based Memory Management
    https://arxiv.org/abs/2108.05818

    Args:
        placement_policy (str): Which device to place *held* tensors. It can be 'cpu', 'cuda' or 'auto'.
            If it's 'cpu', parameters, gradients and optimizer states are offloaded to CPU, which means minimal CUDA memory is used.
            If it's 'cuda', they won't be offloaded, which means maximal CUDA memory is used.
            If it's 'auto', they are moved dynamically based on CPU and CUDA memory usage, utilizing heterogeneous memory space evenly.
            Note that the 'auto' policy can only work well when no other processes use CUDA during your training.
        chunk_manager (ChunkManager): A ``ChunkManager`` instance.
        memstats (MemStats, optional): memory statistics collected by a runtime memory tracer.
            If None, the GeminiManager collects them during a warmup iteration.
    """

    def __init__(self, placement_policy: str, chunk_manager: ChunkManager, memstats: Optional[MemStats] = None) -> None:

        # NOTE: the spelling of get_polocy_names follows PlacementPolicyFactory's API
        assert placement_policy in PlacementPolicyFactory.get_polocy_names()
        self.policy_name = placement_policy
        policy_cls = PlacementPolicyFactory.create(placement_policy)
        self._chunk_manager = chunk_manager

        self._premade_memstats_ = memstats is not None
        self._memstats = memstats
        self._mem_stats_collector = ChunkMemStatsCollector(chunk_manager,
                                                           self._memstats) if policy_cls.need_mem_stats else None
        self._placement_policy = policy_cls(chunk_manager, self._mem_stats_collector)
        self._compute_list: List[Tuple[Chunk, ...]] = []
        self._compute_idx: int = -1

        self._h2d_volume = 0
        self._d2h_volume = 0
        self._layout_time = 0
        self._evict_time = 0
        self._warmup = True
        self._comp_cuda_demand_time = 0

    def memstats(self):
        """memstats

        Get the memory statistics during training.
        The stats can be collected by a runtime memory tracer, or by the GeminiManager itself.
        Note that in the latter case the memstats are not accessible before the warmup iteration finishes.
        """
        if self._premade_memstats_:
            return self._memstats
        else:
            assert not self._warmup, "Gemini Manager only has memstats after warmup; it is still warming up."
            return self._mem_stats_collector._memstats

    def pre_iter(self, *args):
        if self._mem_stats_collector and self._warmup:
            self._mem_stats_collector.start_collection()

    def post_iter(self):
        """This function must be called when each iteration finishes.
        """
        if self._mem_stats_collector and self._warmup:
            self._mem_stats_collector.finish_collection()
        self._warmup = False
        self._compute_idx = -1
        self._h2d_volume = 0
        self._d2h_volume = 0
        self._layout_time = 0
        self._evict_time = 0
        self._comp_cuda_demand_time = 0

    def adjust_layout(self, chunks: Tuple[Chunk, ...]) -> None:
        """Adjust the layout of stateful tensors according to the information provided
        by mem_stats_collector, which should belong to a sharded model.
        """
        # find stateful tensors in state COMPUTE
        start = time()
        self._record_chunks_order(chunks)
        cuda_demand, hold_cuda_tensor_list = self._get_layout_info(self._compute_idx, self._warmup, chunks)
        self._layout_time += time() - start

        vol, evict_time = self._placement_policy.evict_tensors(can_evict_chunks=hold_cuda_tensor_list,
                                                               cuda_demand=cuda_demand,
                                                               warmup=self._warmup,
                                                               compute_list=self._compute_list,
                                                               compute_idx=self._compute_idx)

        self._d2h_volume += vol
        self._evict_time += evict_time
        # move COMPUTE tensors to CUDA
        self._h2d_volume += cuda_demand

    @functools.lru_cache(maxsize=None)
    def _get_layout_info(self, compute_idx: int, warmup: bool, chunks: Tuple[Chunk, ...]):
        start = time()
        cuda_demand = 0
        for chunk in chunks:
            if chunk.device_type == 'cuda':
                if chunk.is_gathered:
                    pass
                else:
                    cuda_demand += chunk.chunk_mem - chunk.shard_mem
            elif chunk.device_type == 'cpu':
                cuda_demand += chunk.chunk_mem
            else:
                raise RuntimeError
        self._comp_cuda_demand_time += time() - start

        can_evict_chunks = self._chunk_manager.get_cuda_movable_chunks()
        return cuda_demand, can_evict_chunks

    def _record_chunks_order(self, chunks: Tuple[Chunk, ...]) -> None:
        self._compute_idx += 1
        if self._warmup and self._placement_policy.need_mem_stats:
            self._compute_list.append(chunks)

    @property
    def default_device(self):
        return self._placement_policy.get_default_device()

    def sample_overall_data(self):
        if self._mem_stats_collector:
            self._mem_stats_collector.sample_overall_data()

    def record_model_data_volume(self):
        if self._mem_stats_collector:
            self._mem_stats_collector.record_model_data_volume()

    @property
    def chunk_manager(self):
        return self._chunk_manager

    @property
    def cuda_margin_mem(self) -> Optional[float]:
        if self._mem_stats_collector:
            return self._mem_stats_collector.cuda_margin_mem
        return None

    @property
    def is_cuda_margin_mem_avail(self) -> bool:
        return self._placement_policy.need_mem_stats

    @staticmethod
    def get_default_device(policy_name: str) -> torch.device:
        return PlacementPolicyFactory.get_default_device(policy_name)
```
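`_get_layout_info` is memoized with `functools.lru_cache` keyed on `(compute_idx, warmup, chunks)`; this only works because `Chunk` defines `__hash__` and `__eq__` by identity, so an unchanged chunks tuple after warmup is a cache hit. A standalone illustration of that design choice (names hypothetical, pure Python):

```python
import functools

class Obj:
    """Stands in for Chunk: identity-based hash and equality."""

    def __hash__(self) -> int:
        return hash(id(self))

    def __eq__(self, other: object) -> bool:
        return self is other

calls = []

@functools.lru_cache(maxsize=None)
def layout_info(compute_idx, warmup, chunks):
    calls.append(compute_idx)    # only runs on a cache miss
    return len(chunks)

pair = (Obj(), Obj())
layout_info(0, False, pair)
layout_info(0, False, pair)    # cache hit: identical arguments
assert calls == [0]
```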
colossalai/gemini/memory_tracer/__init__.py (new file, mode 100644)

```python
from .param_runtime_order import OrderedParamGenerator    # isort:skip
from .memory_stats import MemStats    # isort:skip
from .memory_monitor import AsyncMemoryMonitor, SyncCudaMemoryMonitor    # isort:skip
from .memstats_collector import MemStatsCollector    # isort:skip
from .chunk_memstats_collector import ChunkMemStatsCollector    # isort:skip
from .static_memstats_collector import StaticMemStatsCollector    # isort:skip

__all__ = [
    'AsyncMemoryMonitor', 'SyncCudaMemoryMonitor', 'MemStatsCollector', 'ChunkMemStatsCollector',
    'StaticMemStatsCollector', 'MemStats', 'OrderedParamGenerator'
]
```
colossalai/gemini/memory_tracer/chunk_memstats_collector.py (new file, mode 100644)

```python
from typing import Optional

from colossalai.gemini.chunk import ChunkManager
from colossalai.gemini.memory_tracer import MemStats
from colossalai.utils import get_current_device
from colossalai.utils.memory import colo_device_memory_capacity

from .memstats_collector import MemStatsCollector


class ChunkMemStatsCollector(MemStatsCollector):

    def __init__(self, chunk_manager: ChunkManager, memstats: Optional[MemStats] = None) -> None:
        """
        Memory Statistics Collector for Chunks.

        Args:
            chunk_manager (ChunkManager): the chunk manager.
            memstats (Optional[MemStats], optional): memory statistics collected by RMT. Defaults to None.
        """
        super().__init__(memstats)
        self._chunk_manager = chunk_manager

    # override
    def record_model_data_volume(self) -> None:
        """
        Record the model data volume on cuda and cpu.
        """
        if self._start_flag and not self.use_outside_memstats:
            cuda_mem = self._chunk_manager.total_mem['cuda']
            self._memstats.record_max_cuda_model_data(cuda_mem)

    @property
    def cuda_margin_mem(self) -> float:
        return colo_device_memory_capacity(get_current_device()) - self._memstats.max_overall_cuda('cuda')
```
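`cuda_margin_mem` is simply the device capacity minus the peak overall CUDA usage observed during collection; that margin is what the optimizer can later spend. Illustrative arithmetic with made-up numbers:

```python
# e.g. a 16 GiB device whose traced peak usage was 10 GiB leaves a 6 GiB margin
capacity = 16 * 1024**3
max_overall_cuda = 10 * 1024**3
cuda_margin_mem = capacity - max_overall_cuda
assert cuda_margin_mem == 6 * 1024**3
```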
colossalai/gemini/memory_tracer/memory_monitor.py (new file, mode 100644)

```python
import json
from abc import abstractmethod
from concurrent.futures import ThreadPoolExecutor
from time import sleep, time

import torch

from colossalai.utils import colo_device_memory_used, get_current_device


class MemoryMonitor:
    """Base class for all types of memory monitor.
    All monitors should have a list called `time_stamps` and a list called `mem_stats`.
    """

    def __init__(self):
        self.time_stamps = []
        self.mem_stats = []

    def __len__(self):
        return len(self.mem_stats)

    @abstractmethod
    def start(self):
        pass

    @abstractmethod
    def finish(self):
        pass

    def state_dict(self):
        return {
            "time_stamps": self.time_stamps,
            "mem_stats": self.mem_stats,
        }

    def save(self, filename):
        with open(filename, "w") as f:
            json.dump(self.state_dict(), f)

    def clear(self):
        self.mem_stats.clear()
        self.time_stamps.clear()


class AsyncMemoryMonitor(MemoryMonitor):
    """
    An async memory monitor running during computation, sampling the memory usage
    of the current GPU at an interval of `1/(10**power)` sec.

    The idea comes from the Runtime Memory Tracer of PatrickStar
    `PatrickStar: Parallel Training of Pre-trained Models via Chunk-based Memory Management`_

    Usage::

        async_mem_monitor = AsyncMemoryMonitor()
        input = torch.randn(2, 20).cuda()
        OP1 = torch.nn.Linear(20, 30).cuda()
        OP2 = torch.nn.Linear(30, 40).cuda()

        async_mem_monitor.start()
        output = OP1(input)
        async_mem_monitor.finish()
        async_mem_monitor.start()
        output = OP2(output)
        async_mem_monitor.finish()
        async_mem_monitor.save('log.pkl')

    Args:
        power (int, optional): the power of the time interval. Defaults to 10.

    .. _PatrickStar: Parallel Training of Pre-trained Models via Chunk-based Memory Management:
        https://arxiv.org/abs/2108.05818
    """

    def __init__(self, power: int = 10):
        super().__init__()
        self.keep_measuring = False

        current_device = get_current_device()

        def _set_cuda_device():
            torch.cuda.set_device(current_device)

        self.executor = ThreadPoolExecutor(max_workers=1, initializer=_set_cuda_device)
        self.monitor_thread = None
        self.interval = 1 / (10**power)

    def set_interval(self, power: int):
        self.clear()
        self.interval = 1 / (10**power)

    def is_measuring(self):
        return self.keep_measuring

    def start(self):
        self.keep_measuring = True
        self.monitor_thread = self.executor.submit(self._measure_usage)

    def finish(self):
        if self.keep_measuring is False:
            return 0

        self.keep_measuring = False
        max_usage = self.monitor_thread.result()

        self.monitor_thread = None
        self.time_stamps.append(time())
        self.mem_stats.append(max_usage)
        return max_usage

    def _measure_usage(self):
        max_usage = 0
        while self.keep_measuring:
            max_usage = max(
                max_usage,
                colo_device_memory_used(get_current_device()),
            )
            sleep(self.interval)
        return max_usage


class SyncCudaMemoryMonitor(MemoryMonitor):
    """
    A synchronized cuda memory monitor.
    It only records the maximum allocated cuda memory from the start point to the finish point.
    """

    def __init__(self, power: int = 10):
        super().__init__()

    def start(self):
        torch.cuda.synchronize()
        torch.cuda.reset_peak_memory_stats()

    def finish(self) -> int:
        """
        Return the max GPU memory used since the latest `start()`.

        Returns:
            int: max GPU memory
        """
        torch.cuda.synchronize()
        self.time_stamps.append(time())
        max_usage = torch.cuda.max_memory_allocated()
        self.mem_stats.append(max_usage)
        return max_usage
```
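A minimal use of `SyncCudaMemoryMonitor`, assuming a CUDA device is available; `finish()` returns the peak bytes allocated since `start()`, and `save()` writes the JSON state dict inherited from `MemoryMonitor`:

```python
import torch

from colossalai.gemini.memory_tracer import SyncCudaMemoryMonitor

monitor = SyncCudaMemoryMonitor()
monitor.start()
x = torch.randn(1024, 1024, device='cuda') @ torch.randn(1024, 1024, device='cuda')
peak_bytes = monitor.finish()    # peak allocated bytes since start()
monitor.save('mem_stats.json')    # dumps time_stamps / mem_stats as JSON
```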
colossalai/gemini/memory_tracer/memory_stats.py (new file, mode 100644)

```python
from typing import Any, Dict, List, Optional

import torch

from colossalai.gemini.memory_tracer import OrderedParamGenerator


class MemStats(object):

    def __init__(self) -> None:
        """
        Store the non model data statistics used for Gemini and ZeroOptimizer.
        """
        # (preop_step, List[param])
        self._step_param_dict = dict()
        # (param, List[preop_step])
        self._param_step_dict = dict()
        # (preop_step, non_model_data) non model data used during preop_step ~ (preop_step+1)
        self._step_nmd_dict = dict()
        self._param_runtime_order = OrderedParamGenerator()

        self._preop_step = 0

        self._prev_overall_cuda = -1
        self._max_overall_cuda = 0
        self._prev_md_cuda = -1

        # old version
        self._model_data_cuda_list = []
        self._model_data_cpu_list = []

        self._overall_cuda_list = []
        self._overall_cpu_list = []

        self._non_model_data_cuda_list = []
        self._non_model_data_cpu_list = []

    def calc_max_cuda_non_model_data(self):
        if self._prev_overall_cuda != -1 and self._prev_md_cuda != -1:
            max_cuda_non_model_data = self._prev_overall_cuda - self._prev_md_cuda
            self._step_nmd_dict[self._preop_step - 1] = max_cuda_non_model_data
            # compatibility with the old version
            self._non_model_data_cuda_list.append(max_cuda_non_model_data)

    def record_max_cuda_model_data(self, val):
        self._prev_md_cuda = val

    def record_max_cuda_overall_data(self, val):
        self._prev_overall_cuda = val
        self._max_overall_cuda = max(self._max_overall_cuda, val)

    @property
    def max_overall_cuda(self):
        return self._max_overall_cuda

    def increase_preop_step(self, param_list: List[torch.nn.Parameter]):
        """
        Increase the time step. The param list is used between the current and the next
        time step.

        Args:
            param_list (List[torch.nn.Parameter]): a list of torch parameters.
        """
        for p in param_list:
            if p not in self._param_step_dict:
                self._param_step_dict[p] = [self._preop_step]
            else:
                self._param_step_dict[p].append(self._preop_step)
            self._param_runtime_order.append(p)
        self._step_param_dict[self._preop_step] = param_list
        self._preop_step += 1

    def param_used_step(self, param: torch.nn.Parameter) -> Optional[List[int]]:
        """param_used_step
        Get the list of time steps using the param.

        Args:
            param (torch.nn.Parameter): a torch param

        Returns:
            Optional[List[int]]: a list of ints indicating the time steps of the pre-op hook.
        """
        if param not in self._param_step_dict:
            return None
        else:
            return self._param_step_dict[param]

    def param_order(self):
        if self._param_runtime_order.is_empty():
            raise RuntimeError
        else:
            return self._param_runtime_order

    def non_model_data_list(self, device_type: str) -> List[int]:
        if device_type == 'cuda':
            return self._non_model_data_cuda_list
        elif device_type == 'cpu':
            return self._non_model_data_cpu_list
        else:
            raise TypeError

    def max_non_model_data(self, device_type: str) -> float:
        if device_type == 'cuda':
            return max(self._non_model_data_cuda_list)
        elif device_type == 'cpu':
            return max(self._non_model_data_cpu_list)
        else:
            raise TypeError

    # NOTE: this method shadows the property of the same name defined above,
    # so callers always go through this device_type-keyed version.
    def max_overall_cuda(self, device_type: str) -> float:
        if device_type == 'cuda':
            return max(self._overall_cuda_list)
        elif device_type == 'cpu':
            return max(self._overall_cpu_list)
        else:
            raise TypeError

    def clear(self):
        self._model_data_cuda_list = []
        self._overall_cuda_list = []

        self._model_data_cpu_list = []
        self._overall_cpu_list = []

        self._non_model_data_cpu_list = []
        self._non_model_data_cuda_list = []

        self._param_runtime_order.clear()
        self._step_param_dict.clear()
        self._param_step_dict.clear()
        self._step_nmd_dict.clear()
        self._preop_step = 0

        self._prev_overall_cuda = -1
        self._prev_md_cuda = -1
```
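The intended call pattern, sketched with synthetic byte counts: record the overall peak and the model-data peak around a pre-op step, then derive the non-model share as their difference. Runnable given the package is importable:

```python
import torch

from colossalai.gemini.memory_tracer import MemStats

stats = MemStats()
p = torch.nn.Parameter(torch.zeros(4))
stats.increase_preop_step([p])    # advance to step 1; records runtime order
stats.record_max_cuda_overall_data(1000)    # synthetic byte counts
stats.record_max_cuda_model_data(600)
stats.calc_max_cuda_non_model_data()    # stores 1000 - 600 for step 0
assert stats.non_model_data_list('cuda') == [400]
assert stats.param_used_step(p) == [0]
```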
colossalai/gemini/memory_tracer/memstats_collector.py (new file, mode 100644)

```python
import time
from typing import List, Optional

import torch

from colossalai.gemini.memory_tracer import SyncCudaMemoryMonitor
from colossalai.gemini.stateful_tensor import StatefulTensor
from colossalai.utils.memory import colo_device_memory_used

from .memory_stats import MemStats


class MemStatsCollector:
    """
    A memory statistics collector.
    It works in two phases.
    Phase 1. Collection phase: collect memory usage statistics of CPU and GPU.
        The first iteration of DNN training.
    Phase 2. Runtime phase: use the read-only collected stats.
        The remaining iterations of DNN training.

    It has a sampling counter which is reset after each DNN training iteration.
    """

    def __init__(self, memstats: Optional[MemStats] = None) -> None:
        self._mem_monitor = SyncCudaMemoryMonitor()
        self._sampling_time = []

        self._start_flag = False
        self._step_idx = 0
        self._step_total = 0
        if memstats is not None:
            self.use_outside_memstats = True
            self._memstats = memstats
        else:
            self.use_outside_memstats = False
            self._memstats = MemStats()

    def next_period_non_model_data_usage(self, device_type: str) -> int:
        """Maximum non model data memory usage during the next Op run.

        Args:
            device_type (str): device type, can be 'cpu' or 'cuda'.

        Returns:
            int: max non model data memory usage of the current sampling period
        """
        assert not self._start_flag, 'Cannot get mem stats info during collection phase.'
        assert self._step_total > 0, 'Cannot get mem stats info before collection phase.'
        assert len(self._memstats.non_model_data_list(device_type)) > self._step_idx, \
            f"{len(self._memstats.non_model_data_list(device_type))} should be > than step idx {self._step_idx}, " \
            f"step total {self._step_total}"
        next_non_model_data = self._memstats.non_model_data_list(device_type)[self._step_idx]
        self._step_idx = (self._step_idx + 1) % self._step_total
        return next_non_model_data

    @property
    def sampling_time(self):
        return [t - self._sampling_time[0] for t in self._sampling_time]

    def start_collection(self):
        print('start collection')
        self._start_flag = True
        self._mem_monitor.start()

    def finish_collection(self):
        self.sample_overall_data()
        # self._step_total = len(self._sampling_time)
        self._step_total = len(self._memstats.non_model_data_list('cuda'))
        self._start_flag = False
        self._mem_monitor.finish()
        print(f'finish_collection {self._step_total}')

    # deprecated
    def record_model_data_volume(self) -> None:
        """
        Sample model data statistics.
        """
        if self._start_flag and not self.use_outside_memstats:
            # The following code works for ZeroInitContext, which is deprecated in v0.1.12
            cuda_mem = StatefulTensor.GST_MGR.total_mem['cuda']
            cpu_mem = StatefulTensor.GST_MGR.total_mem['cpu']
            self._memstats.append_model_data('cuda', cuda_mem)
            self._memstats.append_model_data('cpu', cpu_mem)

    def sample_overall_data(self) -> None:
        """
        Sample overall and non model data cuda memory statistics.
        """
        if self._start_flag and not self.use_outside_memstats:
            cuda_overall = self._mem_monitor.finish()
            self._memstats.record_max_cuda_overall_data(cuda_overall)
            self._memstats.calc_max_cuda_non_model_data()

            self._mem_monitor.start()

        if self._start_flag:
            self._sampling_time.append(time.time())

    def clear(self) -> None:
        self._memstats.clear()
        self._start_flag = False
        self._step_idx = 0
        self._step_total = 0
```
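A hedged sketch of the two-phase protocol (requires CUDA). In real training the Gemini hooks record model data between operators; here one step is faked by hand via the internal `_memstats` object, which is only for illustration:

```python
from colossalai.gemini.memory_tracer import MemStatsCollector

collector = MemStatsCollector()
collector.start_collection()    # phase 1: the first training iteration
collector._memstats.record_max_cuda_model_data(0)    # what a hook would do
collector.sample_overall_data()    # sample between operators
collector.finish_collection()    # phase 2: stats become read-only
usage = collector.next_period_non_model_data_usage('cuda')
```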
colossalai/gemini/memory_tracer/param_runtime_order.py (new file, mode 100644)

```python
from abc import ABC

import torch


class ParamGenerator(ABC):

    def append(self, param: torch.nn.Parameter):
        pass

    def generate(self):
        pass

    def clear(self):
        pass


class OrderedParamGenerator(ParamGenerator):
    """OrderedParamGenerator

    Contains the order of parameters visited during runtime.
    """

    def __init__(self) -> None:
        self.param_visited_order = []

    def append(self, param: torch.nn.Parameter):
        self.param_visited_order.append(param)

    def generate(self):
        visited_set = set()
        for p in self.param_visited_order:
            if p not in visited_set:
                yield p
            visited_set.add(p)
        del visited_set

    def is_empty(self):
        return len(self.param_visited_order) == 0

    def clear(self):
        self.param_visited_order = []
```
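`generate()` yields parameters in first-visit order and drops repeat visits, which matters when the same parameter is used by several operators in one iteration. A runnable demonstration:

```python
import torch

from colossalai.gemini.memory_tracer.param_runtime_order import OrderedParamGenerator

a = torch.nn.Parameter(torch.zeros(1))
b = torch.nn.Parameter(torch.zeros(1))

gen = OrderedParamGenerator()
for p in (a, b, a):    # `a` is visited twice at runtime
    gen.append(p)

assert list(gen.generate()) == [a, b]    # first-visit order, deduplicated
```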
colossalai/gemini/memory_tracer/runtime_mem_tracer.py
0 → 100644
View file @
08f2920e
import torch.nn

from colossalai.gemini.memory_tracer import MemStats
from colossalai.gemini.ophooks.runtime_mem_tracer_hook import GradMemStats, GradMemTracerHook, ParamMemTracerHook
from colossalai.nn.parallel.data_parallel import _cast_float
from colossalai.tensor.param_op_hook import ColoParamOpHookManager

__all__ = ['RuntimeMemTracer']


class RuntimeMemTracer():
    """RuntimeMemTracer for a module trained with ColoParameter.

    Traces non-model memory usage during the fwd+bwd process.
    Statistics are obtained by feeding inputs with the same shapes as in real
    training and running a single fwd+bwd pass.

    NOTE:
    1. The premise for using this tracer is that the target DNN executes the same operations at each iteration.
    2. Module buffers are viewed as non-model data.
    """

    def __init__(self, module: torch.nn.Module, dtype: torch.dtype = torch.half):
        super().__init__()
        self.module = module
        self.dtype = dtype
        self._gradstat = GradMemStats()
        self._memstats = MemStats()
        self.param_op_hook = ParamMemTracerHook(self._memstats, self._gradstat)
        self.grad_hook = GradMemTracerHook(self._gradstat)
        self.cpu_param_data_dict = {}

        for p in module.parameters():
            p.data = p.data.to(dtype)

        self._cast_buffers_to_cuda_dtype()

    def parameters_in_runtime_order(self):
        return self._memstats._param_runtime_order.generate()

    def memstats(self):
        return self._memstats

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)

    def _backup_params(self):
        """
        Called before forward. Backs up model params on cpu.
        """
        for p in self.module.parameters():
            self.cpu_param_data_dict[p] = torch.empty(p.data.shape, dtype=self.dtype, device="cpu")
            self.cpu_param_data_dict[p].copy_(p.data)

    def _restore_params(self):
        """
        Called after backward. Restores model params.
        """
        for p in self.module.parameters():
            p.data = torch.empty(p.data.shape, dtype=self.dtype, device="cpu", requires_grad=p.data.requires_grad)
            p.data.copy_(self.cpu_param_data_dict[p])
        self.cpu_param_data_dict.clear()

    def _pre_forward(self):
        self._clear_cuda_mem_info()
        self._backup_params()
        self.grad_hook.register_grad_hook(self.module)
        self.param_op_hook.mem_monitor.start()

    def forward(self, *args, **kwargs):
        args, kwargs = _cast_float(args, self.dtype), _cast_float(kwargs, self.dtype)
        self.module.zero_grad(set_to_none=True)
        self._pre_forward()
        with ColoParamOpHookManager.use_hooks(self.param_op_hook):
            outputs = self.module(*args, **kwargs)
        return outputs

    def backward(self, loss):
        with self.param_op_hook.switch_to_backward(), ColoParamOpHookManager.use_hooks(self.param_op_hook):
            loss.backward()
        self._post_backward()

    def _post_backward(self):
        cuda_volume = self.param_op_hook.mem_monitor.finish()
        self._memstats.record_max_cuda_overall_data(cuda_volume)
        # calc the non-model data of the last Op
        self._memstats.calc_max_cuda_non_model_data()
        self.grad_hook.remove_grad_hook()
        self._restore_params()

    def _clear_cuda_mem_info(self):
        self._memstats.clear()
        self._gradstat.clear()

    def _cast_buffers_to_cuda_dtype(self):
        for buffer in self.module.buffers():
            buffer.data = buffer.cuda()
            if torch.is_floating_point(buffer):
                buffer.data = buffer.data.to(self.dtype)
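A hedged usage sketch of the call sequence. In the real library the module's parameters should be ColoParameters so that the op hooks actually fire; the plain `nn.Sequential` here only illustrates the API shape:

import torch

model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU()).cuda()
tracer = RuntimeMemTracer(model, dtype=torch.half)

x = torch.randn(4, 8, device='cuda')
loss = tracer(x).sum()       # forward under ParamMemTracerHook
tracer.backward(loss)        # backward, then params restored from CPU backups
stats = tracer.memstats()    # MemStats with per-Op CUDA usage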
colossalai/gemini/memory_tracer/static_memstats_collector.py
0 → 100644
View file @ 08f2920e
from typing import Optional

import torch
import torch.nn as nn
from torch.fx import symbolic_trace

from colossalai.fx.passes.meta_info_prop import MetaInfoProp
from colossalai.fx.profiler import calculate_fwd_out, calculate_fwd_tmp, is_compatible_with_meta
from colossalai.gemini.chunk import ChunkManager

if is_compatible_with_meta():
    from colossalai.fx.profiler import MetaTensor

from .chunk_memstats_collector import ChunkMemStatsCollector


class ModuleInfos:

    def __init__(self, module: torch.nn.Module, module_name: str, module_full_name: str,
                 parent_module: torch.nn.Module):
        self.module = module
        self.module_name = module_name
        self.module_full_name = module_full_name
        self.parent_module = parent_module


class StaticMemStatsCollector(ChunkMemStatsCollector):
    """
    A static memory statistics collector.
    """

    def __init__(self, module: nn.Module, chunk_manager: ChunkManager) -> None:
        super().__init__(chunk_manager)
        self.module = module
        self.module_info_list = []

    def init_mem_stats(self, *inputs):

        self.register_opnodes_recursively(self.module)
        self.refactor_module()

        self.module = self.module.cpu()
        self.module.train()

        data = [MetaTensor(torch.rand(inp.shape, device='meta'), fake_device='cpu') for inp in inputs]
        gm = symbolic_trace(self.module)
        interp = MetaInfoProp(gm)
        interp.propagate(*data)

        total_mem = 0
        for inp in inputs:
            total_mem += inp.numel() * inp.element_size()
        last_node = None
        module_name_list = [mInfo.module_full_name for mInfo in self.module_info_list]
        for node in gm.graph.nodes:
            total_mem = total_mem + calculate_fwd_tmp(node) + calculate_fwd_out(node)
            if node.op == "call_module":
                if node.name.endswith("_0") and node.name[:-2] in module_name_list:
                    self._non_model_data_cuda_list.append(total_mem)
                last_node = node
        self._non_model_data_cuda_list.append(total_mem)
        self._non_model_data_cuda_list = self._non_model_data_cuda_list[1:]

        cur_module_mem_fwd = 0
        cur_module_mem_bwd = 0
        grad_module_out = last_node.meta["fwd_mem_out"]
        for node in gm.graph.nodes.__reversed__():
            cur_module_mem_fwd = cur_module_mem_fwd + calculate_fwd_tmp(node) + calculate_fwd_out(node)
            cur_module_mem_bwd = cur_module_mem_bwd + node.meta["bwd_mem_tmp"] + node.meta["bwd_mem_out"]
            if node.op == "call_module":
                if node.name.endswith("_0") and node.name[:-2] in module_name_list:
                    self._non_model_data_cuda_list.append(total_mem + grad_module_out + cur_module_mem_bwd)
                    total_mem = total_mem - cur_module_mem_fwd
                    cur_module_mem_fwd = 0
                    cur_module_mem_bwd = 0
                    grad_module_out = node.meta["bwd_mem_out"]

        self._step_total = len(self._non_model_data_cuda_list)
        self.recover_module()

    def refactor_module(self):
        # Wrap each registered module in nn.Sequential(nn.ReLU(), module) so a
        # `<name>_0` marker node (the ReLU) delimits it in the traced graph.
        for modInfo in self.module_info_list:
            temp_node = nn.Sequential(nn.ReLU(), modInfo.module)
            modInfo.parent_module.__setattr__(modInfo.module_name, temp_node)

    def recover_module(self):
        for modInfo in self.module_info_list:
            modInfo.parent_module.__setattr__(modInfo.module_name, modInfo.module)

    def register_opnodes_recursively(self,
                                     module: torch.nn.Module,
                                     name: str = "",
                                     full_name: str = "",
                                     parent_module: Optional[torch.nn.Module] = None):

        assert isinstance(module, torch.nn.Module)

        for child_name, child in module.named_children():
            self.register_opnodes_recursively(child, child_name, full_name + "_" + child_name, module)

        # Early return on modules with no parameters.
        if len(list(module.parameters(recurse=False))) == 0:
            return

        self.module_info_list.append(ModuleInfos(module, name, full_name[1:], parent_module))
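A sketch of how this collector is presumably driven; the `model` and `chunk_manager` objects are assumed to be constructed elsewhere and the input shape is illustrative. Because propagation runs on meta tensors, no real forward pass or GPU memory is needed:

import torch

collector = StaticMemStatsCollector(model, chunk_manager)
# One meta-traced pass derives the per-submodule memory schedule statically.
collector.init_mem_stats(torch.empty(4, 8))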
colossalai/gemini/memory_tracer/utils.py
0 → 100644
View file @ 08f2920e
from typing import Optional, Tuple

import torch


def colo_model_optimizer_usage(optim) -> Tuple[int, int]:
    """Trace the optimizer memory usage.

    Args:
        optim (ShardedOptimV2): an instance of ShardedOptimV2

    Returns:
        Tuple[int, int]: cuda/cpu memory usage in Byte
    """
    if optim is None:
        return 0, 0
    assert hasattr(optim, 'get_memory_usage'), f"{type(optim)} has no attr get_memory_usage()"
    return optim.get_memory_usage()


def colo_model_mem_usage(model: torch.nn.Module) -> Tuple[int, int]:
    """
    Trace the model memory usage.

    Args:
        model (torch.nn.Module): a torch model

    Returns:
        Tuple[int, int]: cuda memory usage in Byte, cpu memory usage in Byte
    """
    if model is None:
        return 0, 0

    def _get_tensor_mem_use(t: Optional[torch.Tensor]):
        if t is None:
            return 0, 0
        assert isinstance(t, torch.Tensor)
        _cpu_mem_usage, _cuda_mem_usage = 0, 0
        if t.device.type == 'cpu':
            _cpu_mem_usage += t.numel() * t.element_size()
        elif t.device.type == 'cuda':
            _cuda_mem_usage += t.numel() * t.element_size()
        return _cuda_mem_usage, _cpu_mem_usage

    cuda_mem_usage = 0
    cpu_mem_usage = 0
    for param in model.parameters():
        if hasattr(param, 'colo_attr'):
            t_cuda, t_cpu = param.colo_attr.get_memory_usage()
            cuda_mem_usage += t_cuda
            cpu_mem_usage += t_cpu
        else:
            t_cuda, t_cpu = _get_tensor_mem_use(param.data)
            cuda_mem_usage += t_cuda
            cpu_mem_usage += t_cpu

            t_cuda, t_cpu = _get_tensor_mem_use(param.grad)
            cuda_mem_usage += t_cuda
            cpu_mem_usage += t_cpu

    return cuda_mem_usage, cpu_mem_usage
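A small runnable check of colo_model_mem_usage on a plain CPU module; the numbers follow directly from the code above (fp32 parameters, no grads yet):

import torch

model = torch.nn.Linear(1024, 1024)    # weight 1024*1024 + bias 1024, fp32
cuda_bytes, cpu_bytes = colo_model_mem_usage(model)
assert cuda_bytes == 0
assert cpu_bytes == (1024 * 1024 + 1024) * 4    # 4 bytes per fp32 element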
colossalai/gemini/ophooks/__init__.py
0 → 100644
View file @ 08f2920e
from .utils import BaseOpHook, register_ophooks_recursively

__all__ = ["BaseOpHook", "register_ophooks_recursively"]
colossalai/gemini/ophooks/_shard_grad_ophook.py
0 → 100644
View file @ 08f2920e
import torch

from colossalai.registry import OPHOOKS

from . import BaseOpHook


@OPHOOKS.register_module
class ShardGradMemTracerHook(BaseOpHook):
    """
    A hook to process sharded params before and after FWD and BWD operators execute.
    """

    def __init__(self):
        super().__init__()

    def pre_fwd_exec(self, module: torch.nn.Module, *args):
        pass

    def post_fwd_exec(self, module: torch.nn.Module, *args):
        pass

    def pre_bwd_exec(self, module: torch.nn.Module, input, output):
        for param in module.parameters():
            assert hasattr(param, '_sharded_grad')
            param._sharded_grad.setup()

    def post_bwd_exec(self, module: torch.nn.Module, input):
        pass

    def post_iter(self):
        pass
colossalai/gemini/ophooks/_shard_param_ophook.py
0 → 100644
View file @ 08f2920e
import torch

from colossalai.registry import OPHOOKS

from . import BaseOpHook


@OPHOOKS.register_module
class ShardParamHook(BaseOpHook):
    """
    A hook to process sharded params before and after FWD and BWD operators execute.
    """

    def __init__(self):
        super().__init__()

    def niter(self):
        return self._niter

    def pre_fwd_exec(self, module: torch.nn.Module, *args):
        for param in module.parameters():
            assert hasattr(param, 'ca_attr')
            param.ca_attr.gather()
            param.data = param.ca_attr.payload()

    def post_fwd_exec(self, module: torch.nn.Module, *args):
        for param in module.parameters():
            assert hasattr(param, 'ca_attr')
            param.ca_attr.shard()
            param.data = param.ca_attr.payload()

    def pre_bwd_exec(self, module: torch.nn.Module, input, output):
        for param in module.parameters():
            assert hasattr(param, 'ca_attr')
            param.ca_attr.gather()
            param.data = param.ca_attr.payload()

    def post_bwd_exec(self, module: torch.nn.Module, input):
        for param in module.parameters():
            assert hasattr(param, 'ca_attr')
            param.ca_attr.shard()
            param.data = param.ca_attr.payload()

    def pre_iter(self):
        pass

    def post_iter(self):
        pass
colossalai/gemini/ophooks/runtime_mem_tracer_hook.py
0 → 100644
View file @ 08f2920e
from contextlib import contextmanager
from enum import Enum
from functools import partial
from typing import List

import torch

from colossalai.gemini.memory_tracer import MemStats, SyncCudaMemoryMonitor
from colossalai.gemini.tensor_utils import alloc_storage, free_storage
from colossalai.tensor.param_op_hook import ColoParamOpHook


class TrainingPhase(Enum):
    FORWARD = 0
    BACKWARD = 1


class GradMemStats():

    def __init__(self) -> None:
        self.unreleased_grad_flag = {}
        self.unreleased_grad_volume = 0

    def clear(self):
        self.unreleased_grad_flag.clear()
        self.unreleased_grad_volume = 0


class GradMemTracerHook():

    def __init__(self, grad_stats: GradMemStats):
        self.grad_hook_list = []
        self._grad_stats = grad_stats

    def grad_handle(self, p, grad):
        assert self._grad_stats.unreleased_grad_flag[p]
        free_storage(grad)
        self._grad_stats.unreleased_grad_volume -= grad.numel() * grad.element_size()
        self._grad_stats.unreleased_grad_flag[p] = False

    def register_grad_hook(self, module: torch.nn.Module):
        for p in module.parameters():
            if p.requires_grad:
                self.grad_hook_list.append(p.register_hook(partial(self.grad_handle, p)))
                self._grad_stats.unreleased_grad_flag[p] = False

    def remove_grad_hook(self):
        for hook in self.grad_hook_list:
            hook.remove()


class ParamMemTracerHook(ColoParamOpHook):

    def __init__(self, memstats: MemStats, gradstats: GradMemStats) -> None:
        super().__init__()
        self._training_phase = TrainingPhase.FORWARD
        self._memstats = memstats
        self._grad_stats = gradstats
        self.mem_monitor = SyncCudaMemoryMonitor()

    def _free_cuda_params(self, params):
        for p in params:
            if p.data.device.type == "cpu":
                raise NotImplementedError("Only free cuda memory")
            free_storage(p.data)

    def _allocate_params_on_cuda(self, params: List[torch.nn.Parameter]):
        """Move params to cuda.

        Args:
            params (List[torch.nn.Parameter]): target params

        Raises:
            NotImplementedError: raised when a param has a cpu grad
        """
        for p in params:
            cur_dev = p.data.device.type
            if cur_dev == "cpu":
                if p.grad is not None and p.grad.device.type == "cpu":
                    raise NotImplementedError("Only run in forward propagation")
                p.data = torch.empty(p.data.shape,
                                     device="cuda",
                                     dtype=p.data.dtype,
                                     requires_grad=p.data.requires_grad)
            elif cur_dev == "cuda":
                alloc_storage(p.data)

    def record_model_data_volume(self, params):
        """
        Get the cuda model data used by params.
        """
        data_volume = self._grad_stats.unreleased_grad_volume
        for p in params:
            cur_model_data_volume = p.data.numel() * p.data.element_size()
            data_volume += cur_model_data_volume
            if self._training_phase == TrainingPhase.BACKWARD and p.requires_grad:
                # add param.grad; param.grad is actually still None at this point
                data_volume += cur_model_data_volume
                if not self._grad_stats.unreleased_grad_flag[p]:
                    self._grad_stats.unreleased_grad_volume += cur_model_data_volume
                    self._grad_stats.unreleased_grad_flag[p] = True
        # record max model data used for this Op
        self._memstats.record_max_cuda_model_data(data_volume)

    def pre_op(self, params):
        max_cuda_used_pre_op = self.mem_monitor.finish()
        # record max cuda overall data for the previous Op
        self._memstats.record_max_cuda_overall_data(max_cuda_used_pre_op)
        # record max cuda non-model data for the previous Op
        self._memstats.calc_max_cuda_non_model_data()
        self._allocate_params_on_cuda(params)
        # record max cuda model data for the current Op
        self.record_model_data_volume(params)
        self.mem_monitor.start()
        self._memstats.increase_preop_step(params)

    def post_op(self, params):
        self._free_cuda_params(params)

    def pre_forward(self, params: List[torch.Tensor]) -> None:
        self.pre_op(params)

    def post_forward(self, params: List[torch.Tensor]) -> None:
        self.post_op(params)

    def pre_backward(self, params: List[torch.Tensor]) -> None:
        self.pre_op(params)

    def post_backward(self, params: List[torch.Tensor]) -> None:
        self.post_op(params)

    @contextmanager
    def switch_training_phase(self, training_phase: TrainingPhase = TrainingPhase.BACKWARD):
        old_training_phase = self._training_phase
        try:
            self._training_phase = training_phase
            yield
        finally:
            self._training_phase = old_training_phase

    switch_to_backward = switch_training_phase
    switch_to_forward = partial(switch_to_backward, training_phase=TrainingPhase.FORWARD)
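For reference, a minimal sketch of the phase switch as RuntimeMemTracer.backward uses it; the hook construction here is illustrative:

hook = ParamMemTracerHook(MemStats(), GradMemStats())
with hook.switch_to_backward():
    # loss.backward() would run here; requires_grad params are counted twice
    # to account for gradients, per record_model_data_volume above.
    pass
# on exit, the previous training phase is restored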
colossalai/gemini/ophooks/utils.py
0 → 100644
View file @ 08f2920e
from abc import ABC, abstractmethod
from typing import Callable, List, Optional

import torch


class BaseOpHook(ABC):
    """This class allows users to add customized operations
    before and after the execution of a PyTorch submodule."""

    def __init__(self):
        pass

    @abstractmethod
    def pre_fwd_exec(self, module: torch.nn.Module, *args):
        pass

    @abstractmethod
    def post_fwd_exec(self, module: torch.nn.Module, *args):
        pass

    @abstractmethod
    def pre_bwd_exec(self, module: torch.nn.Module, input, output):
        pass

    @abstractmethod
    def post_bwd_exec(self, module: torch.nn.Module, input):
        pass

    @abstractmethod
    def post_iter(self):
        pass


# apply a torch.autograd.Function that calls a backward_function to tensors in outputs
def _apply_to_tensors_only(module, functional, backward_function, outputs):
    if type(outputs) is tuple:
        touched_outputs = []
        for output in outputs:
            touched_output = _apply_to_tensors_only(module, functional, backward_function, output)
            touched_outputs.append(touched_output)
        return tuple(touched_outputs)
    elif type(outputs) is torch.Tensor:
        return functional.apply(module, backward_function, outputs)
    else:
        return outputs


class PreBackwardFunction(torch.autograd.Function):

    @staticmethod
    def forward(ctx, module, pre_backward_function, outputs):
        ctx.module = module
        ctx.pre_backward_function = pre_backward_function
        module.applied_pre_backward = False
        outputs = outputs.detach()
        return outputs

    @staticmethod
    def backward(ctx, *args):
        ctx.pre_backward_function(ctx.module)
        return (None, None) + args


class PostBackwardFunction(torch.autograd.Function):

    @staticmethod
    def forward(ctx, module, pre_backward_function, output):
        ctx.module = module
        output = output.detach()
        ctx.pre_backward_function = pre_backward_function
        return output

    @staticmethod
    def backward(ctx, *args):
        """
        Args:
            *args: the activation grads of the next layer.

        Returns:
            grads of the input activations.
        """
        ctx.pre_backward_function(ctx.module)
        return (None, None) + args


def register_ophooks_recursively(module: torch.nn.Module,
                                 ophook_list: List[BaseOpHook],
                                 name: str = "",
                                 filter_fn: Optional[Callable] = None):
    r"""Recursively register pre/post hooks for all submodules in the module in FWD and BWD."""
    assert isinstance(module, torch.nn.Module)
    assert isinstance(ophook_list, (list, tuple))
    assert len(ophook_list) > 0, 'expected at least 1 hook in the argument ophook_list but found 0'
    for hook in ophook_list:
        assert (isinstance(hook, BaseOpHook))

    # Add hooks for submodules
    for child_name, child in module.named_children():
        register_ophooks_recursively(child, ophook_list, name + child_name, filter_fn)

    # Early return on modules with no parameters.
    if len(list(module.parameters(recurse=False))) == 0:
        return

    # return from filtered module
    if filter_fn is not None and filter_fn(module):
        return

    def _pre_forward_module_hook(submodule, *args):
        for hook in ophook_list:
            assert isinstance(submodule, torch.nn.Module)
            hook.pre_fwd_exec(submodule, *args)

    def _post_forward_module_hook(submodule, *args):
        for hook in ophook_list:
            assert isinstance(submodule, torch.nn.Module)
            hook.post_fwd_exec(submodule, *args)

    def _pre_backward_module_hook(submodule, inputs, output):

        def _run_before_backward_function(submodule):
            for hook in ophook_list:
                assert isinstance(submodule, torch.nn.Module)
                hook.pre_bwd_exec(submodule, inputs, output)

        return _apply_to_tensors_only(submodule, PreBackwardFunction, _run_before_backward_function, output)

    def _post_backward_module_hook(submodule, inputs):

        def _run_after_backward_function(submodule):
            for hook in ophook_list:
                assert isinstance(submodule, torch.nn.Module)
                hook.post_bwd_exec(submodule, inputs)

        return _apply_to_tensors_only(submodule, PostBackwardFunction, _run_after_backward_function, inputs)

    module.register_forward_pre_hook(_pre_forward_module_hook)
    module.register_forward_hook(_post_forward_module_hook)
    module.register_forward_hook(_pre_backward_module_hook)
    module.register_forward_pre_hook(_post_backward_module_hook)
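A self-contained sketch wiring a trivial hook through register_ophooks_recursively; the PrintOpHook class is illustrative, only the BaseOpHook interface comes from this file:

import torch


class PrintOpHook(BaseOpHook):

    def pre_fwd_exec(self, module: torch.nn.Module, *args):
        print(f'pre fwd: {module.__class__.__name__}')

    def post_fwd_exec(self, module: torch.nn.Module, *args):
        pass

    def pre_bwd_exec(self, module: torch.nn.Module, input, output):
        print(f'pre bwd: {module.__class__.__name__}')

    def post_bwd_exec(self, module: torch.nn.Module, input):
        pass

    def post_iter(self):
        pass


model = torch.nn.Sequential(torch.nn.Linear(4, 4))
register_ophooks_recursively(model, [PrintOpHook()])
out = model(torch.randn(2, 4))
out.sum().backward()    # PreBackwardFunction triggers pre_bwd_exec here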