Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
5be118f4
Unverified
Commit
5be118f4
authored
Sep 24, 2022
by
HELSON
Committed by
GitHub
Sep 24, 2022
Browse files
[feature] new zero implementation (#1623)
parent
f9217336
Changes
27
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
456 additions
and
1313 deletions
+456
-1313
colossalai/gemini/__init__.py
colossalai/gemini/__init__.py
+2
-6
colossalai/gemini/chunk.py
colossalai/gemini/chunk.py
+0
-316
colossalai/gemini/chunk/__init__.py
colossalai/gemini/chunk/__init__.py
+3
-0
colossalai/gemini/chunk/chunk.py
colossalai/gemini/chunk/chunk.py
+156
-57
colossalai/gemini/chunk/manager.py
colossalai/gemini/chunk/manager.py
+38
-22
colossalai/gemini/chunk/search_utils.py
colossalai/gemini/chunk/search_utils.py
+14
-9
colossalai/gemini/chunk_mgr.py
colossalai/gemini/chunk_mgr.py
+0
-344
colossalai/gemini/gemini_mgr.py
colossalai/gemini/gemini_mgr.py
+18
-11
colossalai/gemini/memory_tracer/memstats_collector.py
colossalai/gemini/memory_tracer/memstats_collector.py
+1
-1
colossalai/gemini/placement_policy.py
colossalai/gemini/placement_policy.py
+17
-16
colossalai/gemini/stateful_tensor_container.py
colossalai/gemini/stateful_tensor_container.py
+0
-131
colossalai/nn/parallel/data_parallel.py
colossalai/nn/parallel/data_parallel.py
+77
-53
colossalai/nn/parallel/utils.py
colossalai/nn/parallel/utils.py
+20
-0
colossalai/zero/utils/zero_hook_v2.py
colossalai/zero/utils/zero_hook_v2.py
+1
-1
colossalai/zero/zero_optimizer.py
colossalai/zero/zero_optimizer.py
+81
-171
tests/test_ddp/test_ddp_ignore_params.py
tests/test_ddp/test_ddp_ignore_params.py
+14
-10
tests/test_ddp/test_ddp_state_dict.py
tests/test_ddp/test_ddp_state_dict.py
+1
-70
tests/test_gemini/test_stateful_tensor_container.py
tests/test_gemini/test_stateful_tensor_container.py
+0
-74
tests/test_gemini/update/test_chunk_mgrv2.py
tests/test_gemini/update/test_chunk_mgrv2.py
+6
-12
tests/test_gemini/update/test_chunkv2.py
tests/test_gemini/update/test_chunkv2.py
+7
-9
No files found.
colossalai/gemini/__init__.py
View file @
5be118f4
from
.chunk
import
TensorInfo
,
Chunk
,
TensorState
from
.chunk
import
TensorInfo
,
TensorState
from
.chunk_mgr
import
ChunkManager
from
.stateful_tensor_mgr
import
StatefulTensorMgr
from
.stateful_tensor_mgr
import
StatefulTensorMgr
from
.tensor_placement_policy
import
TensorPlacementPolicyFactory
from
.tensor_placement_policy
import
TensorPlacementPolicyFactory
from
.gemini_mgr
import
GeminiManager
from
.gemini_mgr
import
GeminiManager
__all__
=
[
__all__
=
[
'StatefulTensorMgr'
,
'TensorPlacementPolicyFactory'
,
'GeminiManager'
,
'TensorInfo'
,
'TensorState'
]
'StatefulTensorMgr'
,
'TensorPlacementPolicyFactory'
,
'GeminiManager'
,
'ChunkManager'
,
'TensorInfo'
,
'Chunk'
,
'TensorState'
]
colossalai/gemini/chunk.py
deleted
100644 → 0
View file @
f9217336
import
torch
import
torch.distributed
as
dist
from
dataclasses
import
dataclass
from
enum
import
Enum
from
typing
import
Optional
,
Dict
,
List
from
colossalai.utils
import
get_current_device
from
colossalai.tensor
import
ProcessGroup
as
ColoProcessGroup
class
TensorState
(
Enum
):
FREE
=
0
COMPUTE
=
1
HOLD
=
2
HOLD_AFTER_BWD
=
3
READY_FOR_REDUCE
=
4
STATE_TRANS
=
((
TensorState
.
FREE
,
TensorState
.
HOLD
),
(
TensorState
.
FREE
,
TensorState
.
COMPUTE
),
(
TensorState
.
HOLD
,
TensorState
.
FREE
),
(
TensorState
.
HOLD
,
TensorState
.
COMPUTE
),
(
TensorState
.
COMPUTE
,
TensorState
.
HOLD
),
(
TensorState
.
COMPUTE
,
TensorState
.
HOLD_AFTER_BWD
),
(
TensorState
.
COMPUTE
,
TensorState
.
READY_FOR_REDUCE
),
(
TensorState
.
HOLD_AFTER_BWD
,
TensorState
.
COMPUTE
),
(
TensorState
.
HOLD_AFTER_BWD
,
TensorState
.
READY_FOR_REDUCE
),
(
TensorState
.
READY_FOR_REDUCE
,
TensorState
.
HOLD
))
@
dataclass
class
TensorInfo
:
state
:
TensorState
offset
:
int
end
:
int
class
ChunkFullError
(
Exception
):
pass
def
is_storage_empty
(
tensor
:
torch
.
Tensor
)
->
bool
:
return
tensor
.
storage
().
size
()
==
0
def
free_storage
(
tensor
:
torch
.
Tensor
)
->
None
:
if
not
is_storage_empty
(
tensor
):
tensor
.
storage
().
resize_
(
0
)
def
alloc_storage
(
tensor
:
torch
.
Tensor
)
->
None
:
if
is_storage_empty
(
tensor
):
tensor
.
storage
().
resize_
(
tensor
.
numel
())
class
Chunk
:
"""
A chunk is a contiguous memory space which contains multiple tensors.
Args:
chunk_size (int): the number of elements in a chunk
src_rank (int): the process which owns the chunk
dtype (torch.dtype): the data type of the chunk
init_device (torch.device): optional, the device where the tensor is initialized. The default value is None, which is the current GPU.
force_data_on_cuda (bool): optional, if True, chunk.data is always on cuda. Defaults to False.
"""
def
__init__
(
self
,
chunk_size
:
int
,
src_rank
:
int
,
process_group
:
ColoProcessGroup
,
dtype
:
torch
.
dtype
,
init_device
:
Optional
[
torch
.
device
]
=
None
,
force_data_on_cuda
:
bool
=
False
)
->
None
:
self
.
size
=
chunk_size
self
.
utilized_size
=
0
self
.
src_rank
=
src_rank
self
.
process_group
=
process_group
self
.
is_src_rank
=
process_group
.
dp_local_rank
()
==
src_rank
self
.
global_src_rank
=
process_group
.
get_ranks_in_dp
()[
src_rank
]
self
.
dtype
=
dtype
device
=
init_device
or
get_current_device
()
if
force_data_on_cuda
:
self
.
data
=
torch
.
empty
(
chunk_size
,
dtype
=
dtype
,
device
=
get_current_device
())
self
.
_cpu_data
=
torch
.
empty
(
chunk_size
,
dtype
=
dtype
)
if
device
.
type
==
'cuda'
:
free_storage
(
self
.
_cpu_data
)
else
:
free_storage
(
self
.
data
)
else
:
self
.
data
=
torch
.
empty
(
chunk_size
,
dtype
=
dtype
,
device
=
device
)
self
.
_cpu_data
=
None
# we only keep the chunk in full in the process by which the tensor is owned
if
not
self
.
is_src_rank
:
free_storage
(
self
.
_payload
)
# each tensor is associated with a TensorInfo to track meta info
self
.
tensors_info
:
Dict
[
torch
.
Tensor
,
TensorInfo
]
=
{}
self
.
mem
=
self
.
size
*
self
.
data
.
element_size
()
def
append
(
self
,
tensor
:
torch
.
Tensor
)
->
None
:
"""
Add a tensor to the chunk.
Args:
tensor (torch.Tensor): a tensor to be added to the chunk
"""
assert
tensor
.
dtype
==
self
.
dtype
new_utilized_size
=
self
.
utilized_size
+
tensor
.
numel
()
# raise exception when the chunk size is exceeded
if
new_utilized_size
>
self
.
size
:
raise
ChunkFullError
# set tensor state
tensor_state
=
TensorState
.
FREE
# if the process owns the rank, then copy the tensor to its chunk buffer
# otherwise set its storage size to 0 to reduce memory consumption
if
self
.
is_src_rank
:
self
.
_payload
[
self
.
utilized_size
:
new_utilized_size
].
copy_
(
tensor
.
flatten
())
tensor_state
=
TensorState
.
HOLD
assert
type
(
self
.
_payload
)
==
torch
.
Tensor
,
"copy_tensor_to_chunk_slice must use a torch tensor"
tensor
.
data
=
self
.
_payload
[
self
.
utilized_size
:
new_utilized_size
].
view
(
tensor
.
shape
)
else
:
tensor
.
storage
().
resize_
(
0
)
self
.
tensors_info
[
tensor
]
=
TensorInfo
(
tensor_state
,
self
.
utilized_size
,
new_utilized_size
)
self
.
utilized_size
=
new_utilized_size
def
release
(
self
)
->
None
:
"""
Release the memory space on processes which do not own the chunk.
"""
if
not
self
.
is_src_rank
:
free_storage
(
self
.
_payload
)
self
.
_update_tensors_state
(
TensorState
.
FREE
)
def
_update_tensors_ptr
(
self
)
->
None
:
assert
type
(
self
.
_payload
)
==
torch
.
Tensor
for
tensor
,
tensor_info
in
self
.
tensors_info
.
items
():
tensor
.
data
=
self
.
_payload
[
tensor_info
.
offset
:
tensor_info
.
end
].
view
(
tensor
.
shape
)
def
_update_tensors_state
(
self
,
next_state
:
TensorState
,
prev_state
:
Optional
[
TensorState
]
=
None
):
for
tensor_info
in
self
.
tensors_info
.
values
():
if
prev_state
is
None
or
tensor_info
.
state
==
prev_state
:
tensor_info
.
state
=
next_state
def
access
(
self
)
->
None
:
"""
Broadcast the chunk to synchronize the tensors across data parallel processes.
"""
# recover the chunk on non-owner processes
# and broadcast the chunk from the source to all processes
if
not
self
.
is_src_rank
:
alloc_storage
(
self
.
_payload
)
self
.
move_device
(
get_current_device
(),
update_ptr
=
False
)
dist
.
broadcast
(
self
.
data
,
self
.
global_src_rank
,
group
=
self
.
process_group
.
dp_process_group
())
# update tensor meta info
self
.
_update_tensors_ptr
()
if
not
self
.
is_src_rank
:
self
.
_update_tensors_state
(
TensorState
.
HOLD
,
prev_state
=
TensorState
.
FREE
)
def
move_device
(
self
,
device
:
torch
.
device
,
update_ptr
:
bool
=
True
)
->
None
:
"""
Move the chunk to a target device.
Args:
device (torch.device): the target device for data movement.
"""
if
self
.
_payload
.
device
==
device
:
return
if
self
.
_cpu_data
is
None
:
self
.
data
.
data
=
self
.
data
.
to
(
device
)
else
:
if
device
.
type
==
'cuda'
:
# cpu -> cuda
src
=
self
.
_cpu_data
dest
=
self
.
data
else
:
# cuda -> cpu
src
=
self
.
data
dest
=
self
.
_cpu_data
alloc_storage
(
dest
)
dest
.
copy_
(
src
)
free_storage
(
src
)
if
update_ptr
:
self
.
_update_tensors_ptr
()
def
reduce
(
self
,
is_all_reduce
:
bool
=
False
)
->
None
:
"""
Reduce or all-reduce the chunk.
Args:
is_all_reduce (bool): optional, whether to all-reduce the chunk. The default is false.
"""
self
.
move_device
(
get_current_device
(),
update_ptr
=
False
)
if
is_all_reduce
:
dist
.
all_reduce
(
self
.
data
,
group
=
self
.
process_group
.
dp_process_group
())
else
:
dist
.
reduce
(
self
.
data
,
self
.
global_src_rank
,
group
=
self
.
process_group
.
dp_process_group
())
self
.
_update_tensors_ptr
()
self
.
_update_tensors_state
(
TensorState
.
HOLD
)
def
tensor_trans_state
(
self
,
tensor
:
torch
.
Tensor
,
tensor_state
:
TensorState
)
->
None
:
"""
Make a transition of the tensor into the next state.
Args:
tensor (torch.Tensor): a torch Tensor object.
tensor_state (TensorState): the target state for transition.
"""
# As the gradient hook can be triggered either before or after post-backward
# tensor's state can be compute -> hold_after_bwd -> ready_for_reduce
# or compute -> ready_for_reduce -> hold_after_bwd
# the second one is invalid, we just ignore ready_for_reduce -> hold_after_bwd
# this function only apply valid state transformation
# invalid calls will be ignored and nothing changes
if
(
self
.
tensors_info
[
tensor
].
state
,
tensor_state
)
not
in
STATE_TRANS
:
# print(
# f'WARNING: Rank{self.process_group.rank()} apply invalid state trans: {self.tensors_info[tensor].state} to {tensor_state}'
# )
return
self
.
tensors_info
[
tensor
].
state
=
tensor_state
def
copy_tensor_to_chunk_slice
(
self
,
tensor
:
torch
.
Tensor
,
data_slice
:
torch
.
Tensor
)
->
None
:
"""
Copy data slice to the memory space indexed by the input tensor in the chunk.
Args:
tensor (torch.Tensor): the tensor used to retrive meta information
data_slice (torch.Tensor): the tensor to be copied to the chunk
"""
tensor_info
=
self
.
tensors_info
[
tensor
]
self
.
_payload
[
tensor_info
.
offset
:
tensor_info
.
end
].
copy_
(
data_slice
.
flatten
())
tensor
.
data
=
self
.
_payload
[
tensor_info
.
offset
:
tensor_info
.
end
].
view
(
tensor
.
shape
)
@
property
def
can_release
(
self
)
->
bool
:
"""
Check whether the chunk can be released.
"""
for
tensor_info
in
self
.
tensors_info
.
values
():
if
tensor_info
.
state
!=
TensorState
.
HOLD
:
return
False
return
True
@
property
def
can_move_device
(
self
)
->
bool
:
"""
Check whether the chunk can be moved across devices.
"""
for
tensor_info
in
self
.
tensors_info
.
values
():
if
tensor_info
.
state
in
(
TensorState
.
COMPUTE
,
TensorState
.
READY_FOR_REDUCE
):
return
False
return
True
@
property
def
can_reduce
(
self
)
->
bool
:
"""
Check whether the chunk can be reduced.
"""
for
tensor_info
in
self
.
tensors_info
.
values
():
if
tensor_info
.
state
!=
TensorState
.
READY_FOR_REDUCE
:
return
False
return
True
@
property
def
is_empty
(
self
)
->
bool
:
"""
Check whether the chunk is empty.
"""
return
is_storage_empty
(
self
.
_payload
)
def
__repr__
(
self
)
->
str
:
return
f
'Chunk: src rank=
{
self
.
src_rank
}
,size=
{
self
.
size
}
, utilization=
{
self
.
utilized_size
/
self
.
size
*
100
:.
2
f
}
%, freed=
{
self
.
is_empty
}
, tensor states=
{
[
info
.
state
.
name
for
info
in
self
.
tensors_info
.
values
()]
}
'
@
property
def
has_inf_or_nan
(
self
)
->
bool
:
"""
Check if the chunk has inf or nan values.
"""
return
torch
.
isinf
(
self
.
_payload
[:
self
.
utilized_size
]).
any
().
item
()
or
\
torch
.
isnan
(
self
.
_payload
[:
self
.
utilized_size
]).
any
().
item
()
def
copy_
(
self
,
dest_chunk
:
'Chunk'
):
"""
Copy the data of this chunk to a destination chunk.
"""
assert
not
self
.
is_empty
assert
not
dest_chunk
.
is_empty
assert
self
.
size
==
dest_chunk
.
size
assert
self
.
utilized_size
==
dest_chunk
.
utilized_size
self
.
_payload
.
copy_
(
dest_chunk
.
_payload
)
self
.
_update_tensors_ptr
()
@
property
def
device_type
(
self
)
->
str
:
"""
Get the device type of the chunk.
"""
return
self
.
_payload
.
device
.
type
def
__hash__
(
self
)
->
int
:
return
hash
(
id
(
self
))
def
__eq__
(
self
,
__o
:
object
)
->
bool
:
return
self
is
__o
def
get_tensors
(
self
)
->
List
[
torch
.
Tensor
]:
return
list
(
self
.
tensors_info
.
keys
())
@
property
def
_payload
(
self
)
->
torch
.
Tensor
:
if
self
.
_cpu_data
is
None
or
is_storage_empty
(
self
.
_cpu_data
):
return
self
.
data
return
self
.
_cpu_data
colossalai/gemini/
update
/__init__.py
→
colossalai/gemini/
chunk
/__init__.py
View file @
5be118f4
from
.chunk
v2
import
Chunk
V2
from
.chunk
import
TensorState
,
TensorInfo
,
ChunkFullError
,
Chunk
from
.
chunk_mgrv2
import
ChunkManager
V2
from
.
manager
import
ChunkManager
from
.search_utils
import
clasify_params
,
search_chunk_configuration
from
.search_utils
import
clasify_params
,
search_chunk_configuration
colossalai/gemini/
update
/chunk
v2
.py
→
colossalai/gemini/
chunk
/chunk.py
View file @
5be118f4
import
torch
import
torch
import
torch.distributed
as
dist
import
torch.distributed
as
dist
from
dataclasses
import
dataclass
from
enum
import
Enum
from
typing
import
Optional
,
Dict
,
List
from
typing
import
Optional
,
Dict
,
List
from
colossalai.utils
import
get_current_device
from
colossalai.utils
import
get_current_device
from
colossalai.tensor
import
ProcessGroup
as
ColoProcessGroup
from
colossalai.tensor
import
ProcessGroup
as
ColoProcessGroup
from
colossalai.gemini.chunk
import
TensorState
,
STATE_TRANS
,
TensorInfo
,
ChunkFullError
,
\
free_storage
,
alloc_storage
class
ChunkV2
:
class
TensorState
(
Enum
):
FREE
=
0
COMPUTE
=
1
HOLD
=
2
HOLD_AFTER_BWD
=
3
READY_FOR_REDUCE
=
4
STATE_TRANS
=
((
TensorState
.
FREE
,
TensorState
.
HOLD
),
(
TensorState
.
FREE
,
TensorState
.
COMPUTE
),
(
TensorState
.
HOLD
,
TensorState
.
FREE
),
(
TensorState
.
HOLD
,
TensorState
.
COMPUTE
),
(
TensorState
.
COMPUTE
,
TensorState
.
HOLD
),
(
TensorState
.
COMPUTE
,
TensorState
.
HOLD_AFTER_BWD
),
(
TensorState
.
COMPUTE
,
TensorState
.
READY_FOR_REDUCE
),
(
TensorState
.
HOLD_AFTER_BWD
,
TensorState
.
COMPUTE
),
(
TensorState
.
HOLD_AFTER_BWD
,
TensorState
.
READY_FOR_REDUCE
),
(
TensorState
.
READY_FOR_REDUCE
,
TensorState
.
HOLD
))
@
dataclass
class
TensorInfo
:
state
:
TensorState
offset
:
int
end
:
int
class
ChunkFullError
(
Exception
):
pass
def
is_storage_empty
(
tensor
:
torch
.
Tensor
)
->
bool
:
return
tensor
.
storage
().
size
()
==
0
def
free_storage
(
tensor
:
torch
.
Tensor
)
->
None
:
if
not
is_storage_empty
(
tensor
):
tensor
.
storage
().
resize_
(
0
)
def
alloc_storage
(
tensor
:
torch
.
Tensor
)
->
None
:
if
is_storage_empty
(
tensor
):
tensor
.
storage
().
resize_
(
tensor
.
numel
())
class
Chunk
:
def
__init__
(
self
,
def
__init__
(
self
,
chunk_size
:
int
,
chunk_size
:
int
,
...
@@ -19,18 +60,18 @@ class ChunkV2:
...
@@ -19,18 +60,18 @@ class ChunkV2:
pin_memory
:
bool
=
False
)
->
None
:
pin_memory
:
bool
=
False
)
->
None
:
"""
"""
Chunk: A container owning a piece of contiguous memory space for tensors
Chunk: A container owning a piece of contiguous memory space for tensors
AgChunk is a kind of chunk, which
use
s
all-gather operation to gather the whole chunk.
Here we
use all-gather operation to gather the whole chunk.
This kind of c
hunk is exclusively used for DDP and ZeRO DDP.
Currently, C
hunk is exclusively used for DDP and ZeRO DDP
and it doesn't support unused parameters
.
It is designed to make the full use of communication and PCIE bandwidth.
It is designed to make the full use of communication and PCIE bandwidth.
Args:
Args:
chunk_size (int): the number of elements in
a
chunk
chunk_size (int): the number of elements in
the
chunk
process_group (ColoProcessGroup): the process group of this chunk
process_group (ColoProcessGroup): the process group of this chunk
dtype (torch.dtype): the data type of the chunk
dtype (torch.dtype): the data type of the chunk
init_device (torch.device): optional, the device where the tensor is initialized
init_device (torch.device): optional, the device where the tensor is initialized
The default value is None, which is the current GPU
The default value is None, which is the current GPU
keep_gathered (bool): optional, if True, this chunk is always gathered in CUDA memory
keep_gathered (bool): optional, if True, this chunk is always gathered in CUDA memory
pin_memory (bool): optional, if True, this chunk always has a shard cop
y
in pinned CPU memory
pin_memory (bool): optional, if True, this chunk always has a shard cop
ied
in pinned CPU memory
"""
"""
self
.
chunk_size
=
chunk_size
self
.
chunk_size
=
chunk_size
...
@@ -42,7 +83,8 @@ class ChunkV2:
...
@@ -42,7 +83,8 @@ class ChunkV2:
self
.
pg_rank
=
dist
.
get_rank
(
self
.
torch_pg
)
self
.
pg_rank
=
dist
.
get_rank
(
self
.
torch_pg
)
# the chunk size should be able to be divied by the size of GPU
# the chunk size should be able to be divied by the size of GPU
assert
chunk_size
%
self
.
pg_size
==
0
if
not
keep_gathered
:
assert
chunk_size
%
self
.
pg_size
==
0
self
.
shard_size
=
chunk_size
//
self
.
pg_size
self
.
shard_size
=
chunk_size
//
self
.
pg_size
self
.
shard_begin
=
self
.
shard_size
*
self
.
pg_rank
self
.
shard_begin
=
self
.
shard_size
*
self
.
pg_rank
self
.
shard_end
=
self
.
shard_begin
+
self
.
shard_size
self
.
shard_end
=
self
.
shard_begin
+
self
.
shard_size
...
@@ -80,18 +122,15 @@ class ChunkV2:
...
@@ -80,18 +122,15 @@ class ChunkV2:
# we introduce the paired chunk here
# we introduce the paired chunk here
# it refers to another chunk having the same parameters
# it refers to another chunk having the same parameters
# but with different dtype(such as fp16_chunk.
mapping
_chunk -> fp32_chunk
# but with different dtype(such as fp16_chunk.
paired
_chunk -> fp32_chunk
self
.
paired_chunk
=
None
self
.
paired_chunk
=
None
# if the the gradient of this chunk is reduced, the flag is True
# so the flag is False for unused parameters
self
.
grad_reduced_flag
=
False
# if this chunk is synchronized with the optimizer, the flag is True
# if this chunk is synchronized with the optimizer, the flag is True
self
.
optim_sync_flag
=
True
self
.
optim_sync_flag
=
True
# if the cpu_shard has been visited during the training step, the flag is True
# if the cpu_shard has been visited during the training step, the flag is True
self
.
cpu_vis_flag
=
False
self
.
cpu_vis_flag
=
False
@
property
@
property
def
memory_usage
(
self
):
def
memory_usage
(
self
)
->
Dict
[
str
,
int
]
:
cuda_memory
=
0
cuda_memory
=
0
cpu_memory
=
0
cpu_memory
=
0
...
@@ -112,7 +151,7 @@ class ChunkV2:
...
@@ -112,7 +151,7 @@ class ChunkV2:
return
dict
(
cuda
=
cuda_memory
,
cpu
=
cpu_memory
)
return
dict
(
cuda
=
cuda_memory
,
cpu
=
cpu_memory
)
@
property
@
property
def
device_type
(
self
):
def
device_type
(
self
)
->
str
:
if
self
.
chunk_temp
is
not
None
:
if
self
.
chunk_temp
is
not
None
:
return
self
.
chunk_temp
.
device
.
type
return
self
.
chunk_temp
.
device
.
type
else
:
else
:
...
@@ -123,6 +162,56 @@ class ChunkV2:
...
@@ -123,6 +162,56 @@ class ChunkV2:
else
:
else
:
return
'cpu'
return
'cpu'
@
property
def
payload
(
self
)
->
torch
.
Tensor
:
# sanity check
assert
self
.
chunk_temp
is
None
if
self
.
is_gathered
:
return
self
.
chunk_total
elif
self
.
cuda_shard
is
not
None
:
return
self
.
cuda_shard
else
:
return
self
.
cpu_shard
@
property
def
payload_mem
(
self
)
->
int
:
# sanity check
assert
self
.
chunk_temp
is
None
if
self
.
is_gathered
:
return
self
.
chunk_mem
else
:
return
self
.
shard_mem
@
property
def
can_move
(
self
)
->
bool
:
return
not
self
.
is_gathered
@
property
def
can_release
(
self
)
->
bool
:
if
self
.
keep_gathered
:
return
False
else
:
return
self
.
tensors_state_monitor
[
TensorState
.
HOLD
]
+
\
self
.
tensors_state_monitor
[
TensorState
.
HOLD_AFTER_BWD
]
==
self
.
num_tensors
@
property
def
can_reduce
(
self
):
return
self
.
tensors_state_monitor
[
TensorState
.
READY_FOR_REDUCE
]
==
self
.
num_tensors
@
property
def
has_inf_or_nan
(
self
)
->
bool
:
"""Check if the chunk has inf or nan values in CUDA.
"""
if
self
.
is_gathered
:
valid_tensor
=
self
.
chunk_total
[:
self
.
utilized_size
]
else
:
assert
self
.
cuda_shard
is
not
None
# only check in CUDA
valid_tensor
=
self
.
cuda_shard
[:
self
.
valid_end
]
return
torch
.
isinf
(
valid_tensor
).
any
().
item
()
|
torch
.
isnan
(
valid_tensor
).
any
().
item
()
def
append_tensor
(
self
,
tensor
:
torch
.
Tensor
):
def
append_tensor
(
self
,
tensor
:
torch
.
Tensor
):
"""Add a tensor to the chunk.
"""Add a tensor to the chunk.
...
@@ -150,7 +239,10 @@ class ChunkV2:
...
@@ -150,7 +239,10 @@ class ChunkV2:
self
.
utilized_size
=
new_utilized_size
self
.
utilized_size
=
new_utilized_size
def
close_chunk
(
self
,
shard_dev
:
Optional
[
torch
.
device
]
=
None
):
def
close_chunk
(
self
,
shard_dev
:
Optional
[
torch
.
device
]
=
None
):
"""Close the chunk. Any tensor can't be appended to a closed chunk.
"""Close the chunk. Any tensor can't be appended to a closed chunk later.
Args:
shard_dev: the device where the shard locates
"""
"""
# sanity check
# sanity check
assert
self
.
chunk_temp
is
not
None
assert
self
.
chunk_temp
is
not
None
...
@@ -163,6 +255,7 @@ class ChunkV2:
...
@@ -163,6 +255,7 @@ class ChunkV2:
if
self
.
chunk_temp
.
device
.
type
==
'cpu'
:
if
self
.
chunk_temp
.
device
.
type
==
'cpu'
:
self
.
chunk_total
=
self
.
chunk_temp
.
to
(
get_current_device
())
self
.
chunk_total
=
self
.
chunk_temp
.
to
(
get_current_device
())
self
.
__update_tensors_ptr
()
else
:
else
:
self
.
chunk_total
=
self
.
chunk_temp
self
.
chunk_total
=
self
.
chunk_temp
self
.
chunk_temp
=
None
self
.
chunk_temp
=
None
...
@@ -186,6 +279,12 @@ class ChunkV2:
...
@@ -186,6 +279,12 @@ class ChunkV2:
self
.
cuda_shard
=
None
self
.
cuda_shard
=
None
def
shard_move
(
self
,
device
:
torch
.
device
,
force_copy
:
bool
=
False
):
def
shard_move
(
self
,
device
:
torch
.
device
,
force_copy
:
bool
=
False
):
"""Move the shard tensor in the chunk.
Args:
device: the device to which the shard will move
force_copy: if True, copy function is called mandatorily
"""
# sanity check
# sanity check
assert
not
self
.
is_gathered
assert
not
self
.
is_gathered
# when the current chunk is not synchronized with the optimizer
# when the current chunk is not synchronized with the optimizer
...
@@ -223,8 +322,7 @@ class ChunkV2:
...
@@ -223,8 +322,7 @@ class ChunkV2:
raise
NotImplementedError
raise
NotImplementedError
def
access_chunk
(
self
):
def
access_chunk
(
self
):
"""Make the chunk usable for the parameters inside it.
"""Make the chunk usable for the parameters inside it. It's an operation done in CUDA.
It is an operation done in CUDA.
"""
"""
# sanity check
# sanity check
assert
self
.
chunk_temp
is
None
assert
self
.
chunk_temp
is
None
...
@@ -234,8 +332,7 @@ class ChunkV2:
...
@@ -234,8 +332,7 @@ class ChunkV2:
self
.
__update_tensors_ptr
()
self
.
__update_tensors_ptr
()
def
release_chunk
(
self
):
def
release_chunk
(
self
):
"""Release the usable chunk.
"""Release the usable chunk. It's an operation done in CUDA.
It is an operation done in CUDA.
"""
"""
# sanity check
# sanity check
assert
self
.
chunk_temp
is
None
assert
self
.
chunk_temp
is
None
...
@@ -244,8 +341,7 @@ class ChunkV2:
...
@@ -244,8 +341,7 @@ class ChunkV2:
self
.
__scatter
()
self
.
__scatter
()
def
reduce
(
self
):
def
reduce
(
self
):
"""Reduce scatter all the gradients.
"""Reduce scatter all the gradients. It's an operation done in CUDA.
It is an operation done in CUDA.
"""
"""
# sanity check
# sanity check
assert
self
.
is_gathered
assert
self
.
is_gathered
...
@@ -267,7 +363,6 @@ class ChunkV2:
...
@@ -267,7 +363,6 @@ class ChunkV2:
free_storage
(
self
.
chunk_total
)
free_storage
(
self
.
chunk_total
)
self
.
is_gathered
=
False
self
.
is_gathered
=
False
self
.
__update_tensors_state
(
TensorState
.
HOLD
)
self
.
__update_tensors_state
(
TensorState
.
HOLD
)
self
.
grad_reduced_flag
=
True
def
tensor_trans_state
(
self
,
tensor
:
torch
.
Tensor
,
tensor_state
:
TensorState
)
->
None
:
def
tensor_trans_state
(
self
,
tensor
:
torch
.
Tensor
,
tensor_state
:
TensorState
)
->
None
:
"""
"""
...
@@ -285,9 +380,6 @@ class ChunkV2:
...
@@ -285,9 +380,6 @@ class ChunkV2:
# this function only apply valid state transformation
# this function only apply valid state transformation
# invalid calls will be ignored and nothing changes
# invalid calls will be ignored and nothing changes
if
(
self
.
tensors_info
[
tensor
].
state
,
tensor_state
)
not
in
STATE_TRANS
:
if
(
self
.
tensors_info
[
tensor
].
state
,
tensor_state
)
not
in
STATE_TRANS
:
# print(
# f'WARNING: Rank{self.process_group.rank()} apply invalid state trans: {self.tensors_info[tensor].state} to {tensor_state}'
# )
return
return
self
.
__update_one_tensor_info
(
self
.
tensors_info
[
tensor
],
tensor_state
)
self
.
__update_one_tensor_info
(
self
.
tensors_info
[
tensor
],
tensor_state
)
...
@@ -306,46 +398,56 @@ class ChunkV2:
...
@@ -306,46 +398,56 @@ class ChunkV2:
self
.
chunk_total
[
tensor_info
.
offset
:
tensor_info
.
end
].
copy_
(
data_slice
.
data
.
flatten
())
self
.
chunk_total
[
tensor_info
.
offset
:
tensor_info
.
end
].
copy_
(
data_slice
.
data
.
flatten
())
tensor
.
data
=
self
.
chunk_total
[
tensor_info
.
offset
:
tensor_info
.
end
].
view
(
tensor
.
shape
)
tensor
.
data
=
self
.
chunk_total
[
tensor_info
.
offset
:
tensor_info
.
end
].
view
(
tensor
.
shape
)
@
property
def
get_valid_length
(
self
)
->
int
:
def
can_move
(
self
)
->
bool
:
"""Get the valid length of the chunk's payload.
return
not
self
.
is_gathered
"""
@
property
def
can_release
(
self
)
->
bool
:
if
self
.
keep_gathered
:
if
self
.
keep_gathered
:
return
Fals
e
return
self
.
utilized_siz
e
else
:
else
:
return
self
.
tensors_state_monitor
[
TensorState
.
HOLD
]
+
\
return
self
.
valid_end
self
.
tensors_state_monitor
[
TensorState
.
HOLD_AFTER_BWD
]
==
self
.
num_tensors
@
property
def
init_pair
(
self
,
friend_chunk
:
'Chunk'
)
->
None
:
def
can_reduce
(
self
):
"""Initialize the paired chunk.
return
self
.
tensors_state_monitor
[
TensorState
.
READY_FOR_REDUCE
]
==
self
.
num_tensors
@
property
def
has_inf_or_nan
(
self
)
->
bool
:
"""
"""
Check if the chunk has inf or nan values in CUDA.
if
self
.
paired_chunk
is
None
and
friend_chunk
.
paired_chunk
is
None
:
self
.
paired_chunk
=
friend_chunk
friend_chunk
.
paired_chunk
=
self
else
:
assert
self
.
paired_chunk
is
friend_chunk
assert
friend_chunk
.
paired_chunk
is
self
def
optim_update
(
self
)
->
None
:
"""Update the fp16 chunks via their fp32 chunks. It's used by the optimizer.
"""
"""
if
self
.
is_gathered
:
# sanity check
valid_tensor
=
self
.
chunk_total
[:
self
.
utilized_size
]
assert
self
.
paired_chunk
is
not
None
friend_chunk
=
self
.
paired_chunk
if
self
.
is_gathered
is
True
:
assert
friend_chunk
.
is_gathered
is
True
self
.
chunk_total
.
copy_
(
friend_chunk
.
chunk_total
)
self
.
optim_sync_flag
=
True
elif
friend_chunk
.
device_type
==
'cuda'
and
self
.
device_type
==
'cuda'
:
self
.
cuda_shard
.
copy_
(
friend_chunk
.
cuda_shard
)
self
.
optim_sync_flag
=
True
self
.
cpu_vis_flag
=
False
else
:
else
:
assert
self
.
cuda_shard
is
not
None
# only check in CUDA
assert
friend_chunk
.
device_type
==
'cpu'
valid_tensor
=
self
.
cuda_shard
[:
self
.
valid_end
]
assert
self
.
device_type
==
'cpu'
self
.
optim_sync_flag
=
False
self
.
cpu_vis_flag
=
False
return
torch
.
isinf
(
valid_tensor
).
any
().
item
()
|
torch
.
isnan
(
valid_tensor
).
any
().
item
()
def
get_tensors
(
self
)
->
List
[
torch
.
Tensor
]:
return
list
(
self
.
tensors_info
.
keys
())
def
__gather
(
self
):
def
__gather
(
self
):
if
not
self
.
is_gathered
:
if
not
self
.
is_gathered
:
# sanity check
# sanity check
assert
self
.
cuda_shard
is
not
None
assert
self
.
cuda_shard
is
not
None
if
self
.
pg_size
==
1
:
alloc_storage
(
self
.
chunk_total
)
self
.
chunk_total
=
self
.
cuda_shard
gather_list
=
list
(
torch
.
chunk
(
input
=
self
.
chunk_total
,
chunks
=
self
.
pg_size
,
dim
=
0
))
else
:
dist
.
all_gather
(
gather_list
,
self
.
cuda_shard
,
self
.
torch_pg
)
alloc_storage
(
self
.
chunk_total
)
gather_list
=
list
(
torch
.
chunk
(
input
=
self
.
chunk_total
,
chunks
=
self
.
pg_size
,
dim
=
0
))
dist
.
all_gather
(
gather_list
,
self
.
cuda_shard
,
self
.
torch_pg
)
self
.
cuda_shard
=
None
self
.
cuda_shard
=
None
self
.
is_gathered
=
True
self
.
is_gathered
=
True
...
@@ -404,9 +506,9 @@ class ChunkV2:
...
@@ -404,9 +506,9 @@ class ChunkV2:
def
__eq__
(
self
,
__o
:
object
)
->
bool
:
def
__eq__
(
self
,
__o
:
object
)
->
bool
:
return
self
is
__o
return
self
is
__o
def
__repr__
(
self
,
detailed
:
bool
=
Fals
e
):
def
__repr__
(
self
,
detailed
:
bool
=
Tru
e
):
output
=
[
output
=
[
"
Ag
Chunk Information:
\n
"
,
"Chunk Information:
\n
"
,
"
\t
chunk size: {}, chunk dtype: {}, process group size: {}
\n
"
.
format
(
self
.
chunk_size
,
self
.
dtype
,
"
\t
chunk size: {}, chunk dtype: {}, process group size: {}
\n
"
.
format
(
self
.
chunk_size
,
self
.
dtype
,
self
.
pg_size
),
self
.
pg_size
),
"
\t
# of tensors: {}, utilized size: {}, utilized percentage: {:.2f}
\n
"
.
format
(
"
\t
# of tensors: {}, utilized size: {}, utilized percentage: {:.2f}
\n
"
.
format
(
...
@@ -442,6 +544,3 @@ class ChunkV2:
...
@@ -442,6 +544,3 @@ class ChunkV2:
output
.
append
(
"
\t\t
# of {}: {}
\n
"
.
format
(
st
,
self
.
tensors_state_monitor
[
st
]))
output
.
append
(
"
\t\t
# of {}: {}
\n
"
.
format
(
st
,
self
.
tensors_state_monitor
[
st
]))
return
''
.
join
(
output
)
return
''
.
join
(
output
)
def
get_tensors
(
self
)
->
List
[
torch
.
Tensor
]:
return
list
(
self
.
tensors_info
.
keys
())
colossalai/gemini/
update/chunk_mgrv2
.py
→
colossalai/gemini/
chunk/manager
.py
View file @
5be118f4
...
@@ -4,23 +4,19 @@ from collections import deque
...
@@ -4,23 +4,19 @@ from collections import deque
from
colossalai.utils
import
get_current_device
from
colossalai.utils
import
get_current_device
from
colossalai.tensor
import
ColoTensor
from
colossalai.tensor
import
ColoTensor
from
colossalai.gemini.chunk
import
ChunkFullError
,
TensorState
from
colossalai.gemini.chunk
import
ChunkFullError
,
TensorState
,
Chunk
from
colossalai.gemini.update
import
ChunkV2
as
Chunk
class
ChunkManager
V2
:
class
ChunkManager
:
"""
"""
A manager class to manipulate the tensors in chunks.
A manager class to manipulate the tensors in chunks.
Args:
Args:
chunk_configuration (Dict[int, Dict]): the configuration dictionary of this chunk manager.
chunk_configuration (Dict[int, Dict]): the configuration dictionary of this chunk manager.
init_device (torch.device): optional, the device on which the chunk is initialized. The default is None.
init_device (torch.device): optional, the device on which the chunk is initialized. The default is None.
pin_memory (bool): if ture, all chunks have a piece of pinned memory in CPU.
"""
"""
def
__init__
(
self
,
chunk_configuration
:
Dict
[
int
,
Dict
],
def
__init__
(
self
,
chunk_configuration
:
Dict
[
int
,
Dict
],
init_device
:
Optional
[
torch
.
device
]
=
None
)
->
None
:
init_device
:
Optional
[
torch
.
device
]
=
None
,
pin_memory
:
bool
=
False
)
->
None
:
self
.
device
=
init_device
or
get_current_device
()
self
.
device
=
init_device
or
get_current_device
()
self
.
size_config
:
Dict
[
int
,
int
]
=
dict
()
self
.
size_config
:
Dict
[
int
,
int
]
=
dict
()
...
@@ -28,7 +24,6 @@ class ChunkManagerV2:
...
@@ -28,7 +24,6 @@ class ChunkManagerV2:
for
k
,
v
in
self
.
kwargs_config
.
items
():
for
k
,
v
in
self
.
kwargs_config
.
items
():
self
.
size_config
[
k
]
=
v
.
pop
(
'chunk_size'
)
self
.
size_config
[
k
]
=
v
.
pop
(
'chunk_size'
)
v
[
'init_device'
]
=
self
.
device
v
[
'init_device'
]
=
self
.
device
v
[
'pin_memory'
]
=
pin_memory
self
.
chunk_groups
:
Dict
[
str
,
Deque
]
=
dict
()
self
.
chunk_groups
:
Dict
[
str
,
Deque
]
=
dict
()
self
.
tensor_chunk_map
:
Dict
[
torch
.
Tensor
,
Chunk
]
=
dict
()
self
.
tensor_chunk_map
:
Dict
[
torch
.
Tensor
,
Chunk
]
=
dict
()
...
@@ -36,8 +31,14 @@ class ChunkManagerV2:
...
@@ -36,8 +31,14 @@ class ChunkManagerV2:
self
.
lazy_release_tensors
:
List
[
torch
.
Tensor
]
=
list
()
self
.
lazy_release_tensors
:
List
[
torch
.
Tensor
]
=
list
()
self
.
total_mem
:
Dict
[
str
,
int
]
=
{
'cpu'
:
0
,
'cuda'
:
0
}
self
.
total_mem
:
Dict
[
str
,
int
]
=
{
'cpu'
:
0
,
'cuda'
:
0
}
def
append_tensor
(
self
,
tensor
:
ColoTensor
,
group_type
:
str
,
config_key
:
int
)
->
None
:
def
append_tensor
(
self
,
tensor
:
ColoTensor
,
group_type
:
str
,
config_key
:
int
,
pin_memory
:
bool
=
False
)
->
None
:
"""Append a tensor to a chunk.
"""Append a tensor to a chunk.
Args:
tensor: the tensor appended to the chunk
group_type: the data type of the group
config_key: the key of the group's name, usually the size of the dp world
pin_memory: whether the chunk is pinned in the cpu memory
"""
"""
assert
tensor
not
in
self
.
tensor_chunk_map
assert
tensor
not
in
self
.
tensor_chunk_map
assert
isinstance
(
tensor
,
ColoTensor
),
"Please feed ColoTensor to this ChunkManager"
assert
isinstance
(
tensor
,
ColoTensor
),
"Please feed ColoTensor to this ChunkManager"
...
@@ -66,7 +67,8 @@ class ChunkManagerV2:
...
@@ -66,7 +67,8 @@ class ChunkManagerV2:
chunk_size
=
chunk_size
,
chunk_size
=
chunk_size
,
process_group
=
tensor
.
process_group
,
process_group
=
tensor
.
process_group
,
dtype
=
tensor
.
dtype
,
dtype
=
tensor
.
dtype
,
**
chunk_kwargs
pin_memory
=
pin_memory
,
**
chunk_kwargs
,
)
)
chunk_group
.
append
(
chunk
)
chunk_group
.
append
(
chunk
)
...
@@ -87,6 +89,8 @@ class ChunkManagerV2:
...
@@ -87,6 +89,8 @@ class ChunkManagerV2:
if
chunk
in
self
.
accessed_chunks
:
if
chunk
in
self
.
accessed_chunks
:
return
return
self
.
__sub_memroy_usage
(
chunk
.
memory_usage
)
self
.
__sub_memroy_usage
(
chunk
.
memory_usage
)
if
chunk
.
device_type
==
'cpu'
:
chunk
.
shard_move
(
get_current_device
())
chunk
.
access_chunk
()
chunk
.
access_chunk
()
self
.
__add_memory_usage
(
chunk
.
memory_usage
)
self
.
__add_memory_usage
(
chunk
.
memory_usage
)
self
.
accessed_chunks
.
add
(
chunk
)
self
.
accessed_chunks
.
add
(
chunk
)
...
@@ -102,13 +106,13 @@ class ChunkManagerV2:
...
@@ -102,13 +106,13 @@ class ChunkManagerV2:
self
.
__add_memory_usage
(
chunk
.
memory_usage
)
self
.
__add_memory_usage
(
chunk
.
memory_usage
)
self
.
accessed_chunks
.
remove
(
chunk
)
self
.
accessed_chunks
.
remove
(
chunk
)
def
move_chunk
(
self
,
chunk
:
Chunk
,
device
:
torch
.
device
)
->
None
:
def
move_chunk
(
self
,
chunk
:
Chunk
,
device
:
torch
.
device
,
force_copy
:
bool
=
False
)
->
None
:
"""Move the shard of the chunk to the target device.
"""Move the shard of the chunk to the target device.
"""
"""
if
not
chunk
.
can_move
or
chunk
.
device_type
==
device
.
type
:
if
not
chunk
.
can_move
or
chunk
.
device_type
==
device
.
type
:
return
return
self
.
__sub_memroy_usage
(
chunk
.
memory_usage
)
self
.
__sub_memroy_usage
(
chunk
.
memory_usage
)
chunk
.
shard_move
(
device
)
chunk
.
shard_move
(
device
,
force_copy
)
self
.
__add_memory_usage
(
chunk
.
memory_usage
)
self
.
__add_memory_usage
(
chunk
.
memory_usage
)
def
trans_tensor_state
(
self
,
tensor
:
torch
.
Tensor
,
state
:
TensorState
)
->
None
:
def
trans_tensor_state
(
self
,
tensor
:
torch
.
Tensor
,
state
:
TensorState
)
->
None
:
...
@@ -123,7 +127,7 @@ class ChunkManagerV2:
...
@@ -123,7 +127,7 @@ class ChunkManagerV2:
if
not
chunk
.
can_reduce
:
if
not
chunk
.
can_reduce
:
return
False
return
False
self
.
__sub_memroy_usage
(
chunk
.
memory_usage
)
self
.
__sub_memroy_usage
(
chunk
.
memory_usage
)
chunk
.
re
lease_chunk
()
chunk
.
re
duce
()
self
.
__add_memory_usage
(
chunk
.
memory_usage
)
self
.
__add_memory_usage
(
chunk
.
memory_usage
)
return
True
return
True
...
@@ -165,14 +169,14 @@ class ChunkManagerV2:
...
@@ -165,14 +169,14 @@ class ChunkManagerV2:
self
.
release_chunk
(
chunk
)
self
.
release_chunk
(
chunk
)
self
.
lazy_release_tensors
.
clear
()
self
.
lazy_release_tensors
.
clear
()
def
__repr__
(
self
)
->
str
:
def
get_cuda_movable_chunks
(
self
,
group_type
:
str
)
->
List
[
Chunk
]
:
msg
=
[
'Chunk Manager Information:
\n
'
,
chunk_list
=
[]
'Total memory: '
+
', '
.
join
([
f
'
{
k
}
=
{
v
}
B'
for
k
,
v
in
self
.
total_mem
.
items
()])
+
'
\n
'
]
for
group_name
in
self
.
chunk_groups
:
for
group_
name
,
group
in
self
.
chunk_groups
.
items
()
:
if
group_
type
in
group_name
:
msg
.
append
(
f
'G
roup
{
group_name
}
:
\n
'
)
for
chunk
in
self
.
chunk_g
roup
s
[
group_name
]:
for
i
,
chunk
in
enumerate
(
group
)
:
i
f
chunk
.
device_type
==
'cuda'
and
chunk
.
can_move
:
msg
.
append
(
f
'[
{
i
}
]
{
chunk
}
\n
'
)
chunk_list
.
append
(
chunk
)
return
''
.
join
(
msg
)
return
chunk_list
def
get_chunks
(
self
,
tensors
:
Iterable
[
torch
.
Tensor
])
->
Tuple
[
Chunk
,
...]:
def
get_chunks
(
self
,
tensors
:
Iterable
[
torch
.
Tensor
])
->
Tuple
[
Chunk
,
...]:
"""
"""
...
@@ -200,6 +204,17 @@ class ChunkManagerV2:
...
@@ -200,6 +204,17 @@ class ChunkManagerV2:
assert
tensor
not
in
self
.
tensor_chunk_map
assert
tensor
not
in
self
.
tensor_chunk_map
self
.
total_mem
[
tensor
.
device
.
type
]
+=
tensor
.
numel
()
*
tensor
.
element_size
()
self
.
total_mem
[
tensor
.
device
.
type
]
+=
tensor
.
numel
()
*
tensor
.
element_size
()
def
__repr__
(
self
)
->
str
:
msg
=
[
'Chunk Manager Information:
\n
'
,
'Total memory: '
+
', '
.
join
([
f
'
{
k
}
=
{
v
}
B'
for
k
,
v
in
self
.
total_mem
.
items
()])
+
'
\n
'
]
for
group_name
,
group
in
self
.
chunk_groups
.
items
():
msg
.
append
(
f
'Group
{
group_name
}
:
\n
'
)
for
i
,
chunk
in
enumerate
(
group
):
msg
.
append
(
f
'[
{
i
}
]
{
chunk
}
\n
'
)
return
''
.
join
(
msg
)
def
__get_chunk_group
(
self
,
group_name
:
str
)
->
Deque
:
def
__get_chunk_group
(
self
,
group_name
:
str
)
->
Deque
:
"""Register a chunk group.
"""Register a chunk group.
"""
"""
...
@@ -208,8 +223,9 @@ class ChunkManagerV2:
...
@@ -208,8 +223,9 @@ class ChunkManagerV2:
return
self
.
chunk_groups
[
group_name
]
return
self
.
chunk_groups
[
group_name
]
def
__close_one_chunk
(
self
,
chunk
:
Chunk
):
def
__close_one_chunk
(
self
,
chunk
:
Chunk
):
device
=
get_current_device
()
if
chunk
.
keep_gathered
else
self
.
device
# keep gathered chunk in cuda
self
.
__sub_memroy_usage
(
chunk
.
memory_usage
)
self
.
__sub_memroy_usage
(
chunk
.
memory_usage
)
chunk
.
close_chunk
(
self
.
device
)
chunk
.
close_chunk
(
device
)
self
.
__add_memory_usage
(
chunk
.
memory_usage
)
self
.
__add_memory_usage
(
chunk
.
memory_usage
)
def
__sub_memroy_usage
(
self
,
usage
:
Dict
[
str
,
int
]):
def
__sub_memroy_usage
(
self
,
usage
:
Dict
[
str
,
int
]):
...
...
colossalai/gemini/
update
/search_utils.py
→
colossalai/gemini/
chunk
/search_utils.py
View file @
5be118f4
import
math
from
typing
import
Dict
,
List
from
typing
import
Dict
,
List
import
numpy
as
np
import
numpy
as
np
import
torch.nn
as
nn
import
torch.nn
as
nn
...
@@ -7,7 +8,7 @@ from colossalai.tensor import ColoParameter
...
@@ -7,7 +8,7 @@ from colossalai.tensor import ColoParameter
def
_filter_exlarge_params
(
model
:
nn
.
Module
,
size_dict
:
Dict
[
int
,
List
[
int
]])
->
None
:
def
_filter_exlarge_params
(
model
:
nn
.
Module
,
size_dict
:
Dict
[
int
,
List
[
int
]])
->
None
:
"""Filter those parameters whose size is too large from others.
"""Filter those parameters whose size is too large from others.
"""
"""
params_size
=
[
p
.
numel
()
for
p
in
model
.
parameters
()]
params_size
=
[
p
.
numel
()
for
p
in
model
.
parameters
()
if
not
getattr
(
p
,
'_ddp_to_ignore'
,
False
)
]
params_size_arr
=
np
.
array
(
params_size
)
params_size_arr
=
np
.
array
(
params_size
)
std
=
np
.
std
(
params_size_arr
)
std
=
np
.
std
(
params_size_arr
)
...
@@ -36,6 +37,9 @@ def clasify_params(model: nn.Module) -> Dict[int, List[ColoParameter]]:
...
@@ -36,6 +37,9 @@ def clasify_params(model: nn.Module) -> Dict[int, List[ColoParameter]]:
params_dict
:
Dict
[
int
,
List
[
ColoParameter
]]
=
dict
()
params_dict
:
Dict
[
int
,
List
[
ColoParameter
]]
=
dict
()
for
param
in
model
.
parameters
():
for
param
in
model
.
parameters
():
assert
isinstance
(
param
,
ColoParameter
),
"please init model in the ColoInitContext"
assert
isinstance
(
param
,
ColoParameter
),
"please init model in the ColoInitContext"
if
getattr
(
param
,
'_ddp_to_ignore'
,
False
):
continue
param_key
=
param
.
process_group
.
dp_world_size
()
param_key
=
param
.
process_group
.
dp_world_size
()
if
param_key
not
in
params_dict
:
if
param_key
not
in
params_dict
:
...
@@ -47,13 +51,13 @@ def clasify_params(model: nn.Module) -> Dict[int, List[ColoParameter]]:
...
@@ -47,13 +51,13 @@ def clasify_params(model: nn.Module) -> Dict[int, List[ColoParameter]]:
def
search_chunk_configuration
(
def
search_chunk_configuration
(
model
:
nn
.
Module
,
model
:
nn
.
Module
,
search_range_mb
:
in
t
,
search_range_mb
:
floa
t
,
search_interval_byte
:
int
,
# hidden size is the best value for the interval
search_interval_byte
:
int
,
# hidden size is the best value for the interval
min_chunk_size_mb
:
in
t
=
32
,
min_chunk_size_mb
:
floa
t
=
32
,
filter_exlarge_params
:
bool
=
True
):
filter_exlarge_params
:
bool
=
True
)
->
Dict
:
search_range_byte
=
search_range_mb
*
1024
**
2
search_range_byte
=
round
(
search_range_mb
*
1024
**
2
)
min_chunk_size_byte
=
min_chunk_size_mb
*
1024
**
2
min_chunk_size_byte
=
round
(
min_chunk_size_mb
*
1024
**
2
)
assert
search_range_byte
%
search_interval_byte
=
=
0
assert
search_range_byte
>
=
0
params_dict
=
clasify_params
(
model
)
params_dict
=
clasify_params
(
model
)
config_dict
:
Dict
[
int
,
Dict
]
=
dict
()
config_dict
:
Dict
[
int
,
Dict
]
=
dict
()
...
@@ -75,11 +79,12 @@ def search_chunk_configuration(
...
@@ -75,11 +79,12 @@ def search_chunk_configuration(
max_size
=
min_chunk_size_byte
max_size
=
min_chunk_size_byte
for
key
in
size_dict
:
for
key
in
size_dict
:
max_size
=
max
(
max_size
,
max
(
size_dict
[
key
]))
max_size
=
max
(
max_size
,
max
(
size_dict
[
key
]))
start_size
=
int
(
math
.
ceil
(
max_size
/
search_interval_byte
)
*
search_interval_byte
)
min_chunk_waste
=
float
(
'+inf'
)
min_chunk_waste
=
float
(
'+inf'
)
best_chunk_size
=
max
_size
best_chunk_size
=
start
_size
for
chunk_size
in
range
(
max
_size
,
max
_size
+
search_range_byte
+
1
,
search_interval_byte
):
for
chunk_size
in
range
(
start
_size
,
start
_size
+
search_range_byte
+
1
,
search_interval_byte
):
temp_waste
=
0
temp_waste
=
0
for
key
in
size_dict
:
for
key
in
size_dict
:
temp_waste
+=
_get_unused_byte
(
size_dict
[
key
],
chunk_size
)
temp_waste
+=
_get_unused_byte
(
size_dict
[
key
],
chunk_size
)
...
...
colossalai/gemini/chunk_mgr.py
deleted
100644 → 0
View file @
f9217336
import
torch
import
numpy
as
np
from
typing
import
Optional
,
Dict
,
Deque
,
Set
,
List
,
Tuple
,
Iterable
from
collections
import
deque
from
colossalai.utils
import
get_current_device
from
colossalai.tensor
import
ProcessGroup
as
ColoProcessGroup
,
ColoTensor
from
.chunk
import
Chunk
,
ChunkFullError
,
TensorState
class
ChunkManager
:
"""
A manager class to manipulate the tensors in chunks.
Args:
chunk_size (int): the size of a chunk.
process_group (ColoProcessGroup): process group of the chunk.
enable_distributed_storage (bool): optional, allow for distributed storage of a chunk. The default is false.
init_device (torch.device): optional, the device on which the chunk is initialized. The default is None.
"""
def
__init__
(
self
,
chunk_size
:
Optional
[
int
],
process_group
:
ColoProcessGroup
,
enable_distributed_storage
:
bool
=
False
,
init_device
:
Optional
[
torch
.
device
]
=
None
)
->
None
:
assert
chunk_size
is
None
or
chunk_size
>
0
assert
isinstance
(
process_group
,
ColoProcessGroup
)
self
.
chunk_size
=
chunk_size
self
.
process_group
=
process_group
self
.
enable_distributed_storage
=
enable_distributed_storage
self
.
device
=
init_device
or
get_current_device
()
self
.
chunk_groups
:
Dict
[
str
,
Deque
[
Chunk
]]
=
{}
self
.
groups_force_data_on_cuda
:
Dict
[
str
,
bool
]
=
{}
self
.
tensor_chunk_map
:
Dict
[
torch
.
Tensor
,
Chunk
]
=
{}
self
.
accessed_chunks
:
Set
[
Chunk
]
=
set
()
self
.
lazy_release_tensors
:
List
[
torch
.
Tensor
]
=
[]
if
enable_distributed_storage
and
chunk_size
is
None
:
self
.
rank_load
:
Dict
[
str
,
torch
.
Tensor
]
=
{}
self
.
total_mem
:
Dict
[
str
,
int
]
=
{
'cpu'
:
0
,
'cuda'
:
0
}
def
create_group
(
self
,
group_name
:
str
,
force_data_on_cuda
:
bool
=
False
)
->
None
:
"""Create a chunk group.
Args:
group_name (str): group name
force_data_on_cuda (bool, optional): If True, the data of chunks in this group is always on cuda.. Defaults to False.
"""
assert
group_name
not
in
self
.
chunk_groups
self
.
chunk_groups
[
group_name
]
=
deque
()
self
.
groups_force_data_on_cuda
[
group_name
]
=
force_data_on_cuda
def
append_tensor
(
self
,
tensor
:
torch
.
Tensor
,
group_name
:
str
)
->
None
:
"""
Append a tensor to a chunk.
Args:
tensor (torch.Tensor): a tensor to append to the chunk.
group_name (str): the name of the chunk group.
"""
assert
tensor
not
in
self
.
tensor_chunk_map
if
isinstance
(
tensor
,
ColoTensor
):
assert
tensor
.
get_process_group
().
dp_process_group
()
==
self
.
process_group
.
dp_process_group
(
),
f
"Chunk Manager can only manage ColoTensor with the same DP process group"
try
:
# append the tensor to the last chunk
self
.
chunk_groups
[
group_name
][
-
1
].
append
(
tensor
)
except
(
IndexError
,
ChunkFullError
):
# the except statement will be triggered when there is no chunk or
# the last chunk in the chunk group is full
# this will create a new chunk and allocate this chunk to its corresponding process
if
self
.
chunk_size
is
not
None
and
tensor
.
numel
()
>
self
.
chunk_size
:
chunk_size
=
tensor
.
numel
()
else
:
chunk_size
=
self
.
chunk_size
or
tensor
.
numel
()
src_rank
=
self
.
_get_next_src_rank
(
group_name
)
chunk
=
Chunk
(
chunk_size
,
src_rank
,
self
.
process_group
,
tensor
.
dtype
,
self
.
device
,
force_data_on_cuda
=
self
.
groups_force_data_on_cuda
[
group_name
])
if
self
.
enable_distributed_storage
and
self
.
chunk_size
is
None
:
self
.
rank_load
[
group_name
][
src_rank
]
+=
chunk_size
self
.
chunk_groups
[
group_name
].
append
(
chunk
)
chunk
.
append
(
tensor
)
if
not
chunk
.
is_empty
:
self
.
total_mem
[
chunk
.
device_type
]
+=
chunk
.
mem
self
.
tensor_chunk_map
[
tensor
]
=
self
.
chunk_groups
[
group_name
][
-
1
]
if
not
self
.
enable_distributed_storage
:
# as distributed storage is not enabled, there is no need to broadcast
# chunks, thus we set these chunks as accessed
self
.
accessed_chunks
.
add
(
self
.
chunk_groups
[
group_name
][
-
1
])
def
_get_next_src_rank
(
self
,
group_name
:
str
)
->
int
:
if
not
self
.
enable_distributed_storage
:
# the chunk is owned by the current rank if no distributed storage is enabled
return
self
.
process_group
.
dp_local_rank
()
if
self
.
chunk_size
is
None
:
if
group_name
not
in
self
.
rank_load
:
self
.
rank_load
[
group_name
]
=
torch
.
zeros
(
self
.
process_group
.
dp_world_size
(),
dtype
=
torch
.
int64
)
# the process owning the tensor will be the process with the smallest number of elements
src_rank
=
torch
.
argmin
(
self
.
rank_load
[
group_name
]).
item
()
else
:
# chunk is owned by processes in a round-robin fashion
chunk_idx
=
len
(
self
.
chunk_groups
[
group_name
])
src_rank
=
chunk_idx
%
self
.
process_group
.
dp_world_size
()
return
src_rank
def
access_chunk
(
self
,
chunk
:
Chunk
)
->
None
:
"""
Synchronize the chunks via broadcast.
Args:
chunk (Chunk): the chunk to synchronize.
"""
if
chunk
in
self
.
accessed_chunks
:
if
chunk
.
device_type
!=
'cuda'
:
self
.
total_mem
[
chunk
.
device_type
]
-=
chunk
.
mem
chunk
.
move_device
(
get_current_device
())
self
.
total_mem
[
chunk
.
device_type
]
+=
chunk
.
mem
return
if
not
chunk
.
is_empty
:
# as tensor is moved to the target device
# the memory consumption of the original device is reduced
self
.
total_mem
[
chunk
.
device_type
]
-=
chunk
.
mem
chunk
.
access
()
self
.
accessed_chunks
.
add
(
chunk
)
self
.
total_mem
[
chunk
.
device_type
]
+=
chunk
.
mem
def
release_chunk
(
self
,
chunk
:
Chunk
)
->
None
:
"""
Release the memory space of a chunk.
Args:
chunk (Chunk): the chunk to release memory space
"""
if
not
self
.
enable_distributed_storage
:
return
if
chunk
not
in
self
.
accessed_chunks
:
return
if
chunk
.
can_release
:
chunk
.
release
()
self
.
accessed_chunks
.
remove
(
chunk
)
if
chunk
.
is_empty
:
# update the memory consumption after releasing
self
.
total_mem
[
chunk
.
device_type
]
-=
chunk
.
mem
def
move_chunk
(
self
,
chunk
:
Chunk
,
device
:
torch
.
device
,
update_ptr
:
bool
=
True
)
->
None
:
"""
Move the chunk to the target device.
Args:
chunk (Chunk): the chunk to move to target device
device (torch.device): target device
"""
if
chunk
.
device_type
==
device
.
type
:
return
if
chunk
.
can_move_device
and
not
chunk
.
is_empty
:
self
.
total_mem
[
chunk
.
device_type
]
-=
chunk
.
mem
chunk
.
move_device
(
device
,
update_ptr
=
update_ptr
)
self
.
total_mem
[
chunk
.
device_type
]
+=
chunk
.
mem
def
trans_tensor_state
(
self
,
tensor
:
torch
.
Tensor
,
state
:
TensorState
)
->
None
:
"""
Transit tensor state according to pre-defined state machine.
Args:
tensor (torch.Tensor): the tensor for state transititon
state (TensorState): next tensor state for transtition
"""
chunk
=
self
.
tensor_chunk_map
[
tensor
]
chunk
.
tensor_trans_state
(
tensor
,
state
)
def
reduce_chunk
(
self
,
chunk
:
Chunk
)
->
bool
:
"""
Reduce or all reduce the chunk. If enable_distributed_storage is true, all-reduce is used.
Otherwise, this method uses reduce.
Args:
chunk (Chunk): the chunk for reduction.
"""
if
not
chunk
.
can_reduce
:
return
False
self
.
total_mem
[
chunk
.
device_type
]
-=
chunk
.
mem
chunk
.
reduce
(
is_all_reduce
=
not
self
.
enable_distributed_storage
)
self
.
total_mem
[
chunk
.
device_type
]
+=
chunk
.
mem
return
True
def
copy_tensor_to_chunk_slice
(
self
,
tensor
:
torch
.
Tensor
,
data
:
torch
.
Tensor
)
->
None
:
"""
Copy data to the chunk.
Args:
tensor (torch.Tensor): the tensor used to retrive meta information
data (torch.Tensor): the tensor to be copied to the chunk
"""
chunk
=
self
.
tensor_chunk_map
[
tensor
]
chunk
.
copy_tensor_to_chunk_slice
(
tensor
,
data
)
def
get_chunk
(
self
,
tensor
:
torch
.
Tensor
)
->
Chunk
:
"""
Return the chunk owning the tensor.
Args:
tensor (torch.Tensor): a torch tensor object
"""
return
self
.
tensor_chunk_map
[
tensor
]
def
add_lazy_release_tensors
(
self
,
tensors
:
List
[
torch
.
Tensor
])
->
None
:
"""
Add tensors to the buffer for lazy release.
Args:
tensors (List[torch.Tensor]): the tensors to be released lazily
"""
self
.
lazy_release_tensors
.
extend
(
tensors
)
def
exec_lazy_release
(
self
)
->
None
:
"""
Execute release for tensors added to the lazy release buffer.
"""
for
chunk
in
self
.
get_chunks
(
self
.
lazy_release_tensors
):
self
.
release_chunk
(
chunk
)
self
.
lazy_release_tensors
.
clear
()
def
__repr__
(
self
)
->
str
:
msg
=
f
'Rank
{
self
.
process_group
.
dp_local_rank
()
}
:
\n
'
msg
+=
'Total memory: '
+
', '
.
join
([
f
'
{
k
}
=
{
v
}
B'
for
k
,
v
in
self
.
total_mem
.
items
()])
+
'
\n
'
for
group_name
,
group
in
self
.
chunk_groups
.
items
():
msg
+=
f
'Group
{
group_name
}
:
\n
'
for
i
,
chunk
in
enumerate
(
group
):
msg
+=
f
'[
{
i
}
]
{
chunk
}
\n
'
return
msg
@
staticmethod
def
get_chunk_util
(
chunk_size
:
int
,
params_numel
:
List
[
int
])
->
float
:
"""
Calculate the utilization rate of a chunk.
Args:
chunk_size (int): the size of a chunk
params_numel (List[int]): the list of integers representing the number of elements of parameters
"""
assert
len
(
params_numel
)
>
0
total_size
=
0
total_utilized_size
=
0
cur_chunk_utilized_size
=
0
for
size
in
params_numel
:
assert
chunk_size
>=
size
total_utilized_size
+=
size
if
total_size
==
0
or
cur_chunk_utilized_size
+
size
>
chunk_size
:
total_size
+=
chunk_size
cur_chunk_utilized_size
=
0
cur_chunk_utilized_size
+=
size
return
total_utilized_size
/
total_size
@
staticmethod
def
search_chunk_size
(
module
:
torch
.
nn
.
Module
,
search_range
:
int
,
n_grids
:
int
,
min_chunk_size
:
Optional
[
int
]
=
None
,
filter_exlarge_params
:
bool
=
True
)
->
int
:
"""
Search for the chunk size for optimal chunk utilization.
Args:
module (torch.nn.Module): a torch module object
search_range (int): the range of chunk size to search. The actual search range will be from
max(min_chunk_size, max_param_size) to max(min_chunk_size, max_param_size) + search_range.
n_grids (int): the number of intervals in the search range
min_chunk_size (int): optional, the minimum size for a chunk. The default is None.
"""
assert
search_range
%
n_grids
==
0
# TODO(ver217): sort params and filter unused ones
params_numel
=
[
p
.
numel
()
for
p
in
module
.
parameters
()]
if
filter_exlarge_params
:
params_numel
=
_filter_exlarge_params
(
params_numel
)
max_param_numel
=
max
(
params_numel
)
if
min_chunk_size
is
not
None
:
assert
min_chunk_size
>=
max_param_numel
else
:
min_chunk_size
=
max_param_numel
step_size
=
search_range
//
n_grids
max_chunk_util
=
-
1
best_chunk_size
=
-
1
for
chunk_size
in
range
(
min_chunk_size
,
min_chunk_size
+
search_range
+
1
,
step_size
):
chunk_util
=
ChunkManager
.
get_chunk_util
(
chunk_size
,
params_numel
)
if
chunk_util
>
max_chunk_util
:
max_chunk_util
=
chunk_util
best_chunk_size
=
chunk_size
return
best_chunk_size
def
copy_chunk_group
(
self
,
dest_group_name
:
str
,
src_group_name
:
str
):
"""
Copy chunk data from one group to another group.
Args:
dest_group_name (str): the destination group which receives the copied data
src_group_name (str): the source group which provides the data to copy
"""
for
dest_chunk
,
src_chunk
in
zip
(
self
.
chunk_groups
[
dest_group_name
],
self
.
chunk_groups
[
src_group_name
]):
if
not
dest_chunk
.
is_empty
:
dest_chunk
.
copy_
(
src_chunk
)
def
get_chunks
(
self
,
tensors
:
Iterable
[
torch
.
Tensor
])
->
Tuple
[
Chunk
,
...]:
"""
Get all chunks owning the input tensors.
Args:
tensors (Iterable[torch.Tensor]): the tensors used to look for chunks
"""
chunks
=
[]
for
tensor
in
tensors
:
chunk
=
self
.
get_chunk
(
tensor
)
if
chunk
not
in
chunks
:
chunks
.
append
(
chunk
)
return
tuple
(
chunks
)
def
add_extern_static_tensor
(
self
,
tensor
:
torch
.
Tensor
)
->
None
:
"""Add extern static tensor to chunk manager.
Those tensors won't be managed by chunk manager, but we want to monitor memory usage of them.
They are "static", which means their shape, dtype, device never change.
Thus, their memory usage never changes.
Args:
tensor (torch.Tensor): An extern static tensor. E.g. optimizer state.
"""
assert
tensor
not
in
self
.
tensor_chunk_map
self
.
total_mem
[
tensor
.
device
.
type
]
+=
tensor
.
numel
()
*
tensor
.
element_size
()
def
_filter_exlarge_params
(
params_numel
:
List
[
int
])
->
List
[
int
]:
params_numel_arr
=
np
.
array
(
params_numel
)
std
=
np
.
std
(
params_numel_arr
)
mean
=
np
.
mean
(
params_numel_arr
)
upper_limit
=
mean
+
3
*
std
return
list
(
filter
(
lambda
x
:
x
<=
upper_limit
,
params_numel
))
colossalai/gemini/gemini_mgr.py
View file @
5be118f4
...
@@ -3,7 +3,7 @@ import functools
...
@@ -3,7 +3,7 @@ import functools
from
.memory_tracer.memstats_collector
import
MemStatsCollectorV2
from
.memory_tracer.memstats_collector
import
MemStatsCollectorV2
from
typing
import
List
,
Optional
,
Tuple
from
typing
import
List
,
Optional
,
Tuple
from
time
import
time
from
time
import
time
from
colossalai.gemini
import
Chunk
,
ChunkManager
from
colossalai.gemini
.chunk
import
Chunk
,
ChunkManager
from
.placement_policy
import
PlacementPolicyFactory
from
.placement_policy
import
PlacementPolicyFactory
...
@@ -56,37 +56,44 @@ class GeminiManager:
...
@@ -56,37 +56,44 @@ class GeminiManager:
self
.
_evict_time
=
0
self
.
_evict_time
=
0
self
.
_comp_cuda_demand_time
=
0
self
.
_comp_cuda_demand_time
=
0
def
adjust_layout
(
self
,
chunks
:
Tuple
[
Chunk
,
...],
group_
nam
e
:
str
)
->
None
:
def
adjust_layout
(
self
,
chunks
:
Tuple
[
Chunk
,
...],
group_
typ
e
:
str
)
->
None
:
""" Adjust the layout of statefuil tensor according to the information provided
""" Adjust the layout of statefuil tensor according to the information provided
by mem_stats_collector, which should belongs to a Sharded Model.
by mem_stats_collector, which should belongs to a Sharded Model.
"""
"""
# find stateful tensor in state COMPUTE
# find stateful tensor in state COMPUTE
start
=
time
()
start
=
time
()
self
.
_record_chunks_order
(
chunks
)
self
.
_record_chunks_order
(
chunks
)
cuda_demand
,
hold_cuda_tensor_list
=
self
.
_get_layout_info
(
self
.
_compute_idx
,
self
.
_warmup
,
chunks
,
group_
nam
e
)
cuda_demand
,
hold_cuda_tensor_list
=
self
.
_get_layout_info
(
self
.
_compute_idx
,
self
.
_warmup
,
chunks
,
group_
typ
e
)
self
.
_layout_time
+=
time
()
-
start
self
.
_layout_time
+=
time
()
-
start
vol
,
evict_time
=
self
.
_placement_policy
.
evict_tensors
(
hold_cuda_tensor_list
,
vol
,
evict_time
=
self
.
_placement_policy
.
evict_tensors
(
can_evict_chunks
=
hold_cuda_tensor_list
,
cuda_demand
=
cuda_demand
,
cuda_demand
=
cuda_demand
,
warmup
=
self
.
_warmup
,
warmup
=
self
.
_warmup
,
compute_list
=
self
.
_compute_list
,
compute_list
=
self
.
_compute_list
,
compute_idx
=
self
.
_compute_idx
)
compute_idx
=
self
.
_compute_idx
)
self
.
_d2h_volume
+=
vol
self
.
_d2h_volume
+=
vol
self
.
_evict_time
+=
evict_time
self
.
_evict_time
+=
evict_time
# move COMPUTE tensors to CUDA
# move COMPUTE tensors to CUDA
self
.
_h2d_volume
+=
cuda_demand
self
.
_h2d_volume
+=
cuda_demand
@
functools
.
lru_cache
(
maxsize
=
None
)
@
functools
.
lru_cache
(
maxsize
=
None
)
def
_get_layout_info
(
self
,
compute_idx
:
int
,
warmup
:
bool
,
chunks
:
Tuple
[
Chunk
,
...],
group_
nam
e
:
str
):
def
_get_layout_info
(
self
,
compute_idx
:
int
,
warmup
:
bool
,
chunks
:
Tuple
[
Chunk
,
...],
group_
typ
e
:
str
):
start
=
time
()
start
=
time
()
cuda_demand
=
0
cuda_demand
=
0
for
chunk
in
chunks
:
for
chunk
in
chunks
:
if
chunk
.
device_type
==
'cpu'
or
chunk
.
is_empty
:
if
chunk
.
device_type
==
'cuda'
:
cuda_demand
+=
chunk
.
mem
if
chunk
.
is_gathered
:
pass
else
:
cuda_demand
+=
chunk
.
chunk_mem
-
chunk
.
shard_mem
elif
chunk
.
device_type
==
'cpu'
:
cuda_demand
+=
chunk
.
chunk_mem
else
:
raise
RuntimeError
self
.
_comp_cuda_demand_time
+=
time
()
-
start
self
.
_comp_cuda_demand_time
+=
time
()
-
start
can_evict_chunks
=
[]
for
chunk
in
self
.
_chunk_manager
.
chunk_groups
[
group_name
]:
can_evict_chunks
=
self
.
_chunk_manager
.
get_cuda_movable_chunks
(
group_type
)
if
not
chunk
.
is_empty
and
chunk
.
device_type
==
'cuda'
and
chunk
.
can_move_device
:
can_evict_chunks
.
append
(
chunk
)
return
cuda_demand
,
can_evict_chunks
return
cuda_demand
,
can_evict_chunks
def
_record_chunks_order
(
self
,
chunks
:
Tuple
[
Chunk
,
...])
->
None
:
def
_record_chunks_order
(
self
,
chunks
:
Tuple
[
Chunk
,
...])
->
None
:
...
...
colossalai/gemini/memory_tracer/memstats_collector.py
View file @
5be118f4
...
@@ -2,7 +2,7 @@ from colossalai.gemini.memory_tracer import SyncCudaMemoryMonitor
...
@@ -2,7 +2,7 @@ from colossalai.gemini.memory_tracer import SyncCudaMemoryMonitor
from
colossalai.utils.memory
import
colo_device_memory_used
,
colo_device_memory_capacity
from
colossalai.utils.memory
import
colo_device_memory_used
,
colo_device_memory_capacity
from
colossalai.utils
import
get_current_device
from
colossalai.utils
import
get_current_device
from
colossalai.gemini.stateful_tensor
import
StatefulTensor
from
colossalai.gemini.stateful_tensor
import
StatefulTensor
from
colossalai.gemini
import
ChunkManager
from
colossalai.gemini
.chunk
import
ChunkManager
import
torch
import
torch
import
time
import
time
...
...
colossalai/gemini/placement_policy.py
View file @
5be118f4
...
@@ -8,7 +8,7 @@ from colossalai.utils.memory import colo_device_memory_capacity
...
@@ -8,7 +8,7 @@ from colossalai.utils.memory import colo_device_memory_capacity
from
colossalai.gemini.memory_tracer.memstats_collector
import
MemStatsCollectorV2
from
colossalai.gemini.memory_tracer.memstats_collector
import
MemStatsCollectorV2
from
typing
import
Type
from
typing
import
Type
import
functools
import
functools
from
colossalai.gemini
import
Chunk
,
ChunkManager
from
colossalai.gemini
.chunk
import
Chunk
,
ChunkManager
class
PlacementPolicy
(
ABC
):
class
PlacementPolicy
(
ABC
):
...
@@ -19,7 +19,7 @@ class PlacementPolicy(ABC):
...
@@ -19,7 +19,7 @@ class PlacementPolicy(ABC):
self
.
mem_stats_collector
:
Optional
[
MemStatsCollectorV2
]
=
mem_stats_collector
self
.
mem_stats_collector
:
Optional
[
MemStatsCollectorV2
]
=
mem_stats_collector
@
abstractmethod
@
abstractmethod
def
evict_tensors
(
self
,
can_evict_chunks
:
List
[
Chunk
],
**
kwargs
)
->
None
:
def
evict_tensors
(
self
,
can_evict_chunks
:
List
[
Chunk
],
**
kwargs
)
->
Tuple
[
int
,
float
]
:
raise
NotImplementedError
raise
NotImplementedError
@
staticmethod
@
staticmethod
...
@@ -32,12 +32,12 @@ class CPUPlacementPolicy(PlacementPolicy):
...
@@ -32,12 +32,12 @@ class CPUPlacementPolicy(PlacementPolicy):
def
__init__
(
self
,
chunk_manager
:
ChunkManager
,
mem_stats_collector
:
Optional
[
MemStatsCollectorV2
]
=
None
)
->
None
:
def
__init__
(
self
,
chunk_manager
:
ChunkManager
,
mem_stats_collector
:
Optional
[
MemStatsCollectorV2
]
=
None
)
->
None
:
super
().
__init__
(
chunk_manager
,
mem_stats_collector
=
mem_stats_collector
)
super
().
__init__
(
chunk_manager
,
mem_stats_collector
=
mem_stats_collector
)
def
evict_tensors
(
self
,
can_evict_chunks
:
List
[
Chunk
],
**
kwargs
)
->
int
:
def
evict_tensors
(
self
,
can_evict_chunks
:
List
[
Chunk
],
**
kwargs
)
->
Tuple
[
int
,
float
]
:
volume
=
0
volume
=
0
start
=
time
()
start
=
time
()
for
chunk
in
can_evict_chunks
:
for
chunk
in
can_evict_chunks
:
self
.
chunk_manager
.
move_chunk
(
chunk
,
torch
.
device
(
'cpu'
)
,
update_ptr
=
False
)
self
.
chunk_manager
.
move_chunk
(
chunk
,
torch
.
device
(
'cpu'
))
volume
+=
chunk
.
mem
volume
+=
chunk
.
shard_
mem
return
volume
,
time
()
-
start
return
volume
,
time
()
-
start
...
@@ -47,7 +47,7 @@ class CUDAPlacementPolicy(PlacementPolicy):
...
@@ -47,7 +47,7 @@ class CUDAPlacementPolicy(PlacementPolicy):
assert
torch
.
cuda
.
is_available
(),
'Cannot use CUDATensorPlacementPolicy when CUDA is not available'
assert
torch
.
cuda
.
is_available
(),
'Cannot use CUDATensorPlacementPolicy when CUDA is not available'
super
().
__init__
(
chunk_manager
,
mem_stats_collector
=
mem_stats_collector
)
super
().
__init__
(
chunk_manager
,
mem_stats_collector
=
mem_stats_collector
)
def
evict_tensors
(
self
,
can_evict_chunks
:
List
[
Chunk
],
**
kwargs
)
->
int
:
def
evict_tensors
(
self
,
can_evict_chunks
:
List
[
Chunk
],
**
kwargs
)
->
Tuple
[
int
,
float
]
:
return
0
,
0
return
0
,
0
@
staticmethod
@
staticmethod
...
@@ -59,7 +59,8 @@ class AutoPlacementPolicy(PlacementPolicy):
...
@@ -59,7 +59,8 @@ class AutoPlacementPolicy(PlacementPolicy):
need_mem_stats
:
bool
=
True
need_mem_stats
:
bool
=
True
# model data will use 1-_warmup_non_model_data_ratio CUDA memory in warmup phase
# model data will use 1-_warmup_non_model_data_ratio CUDA memory in warmup phase
# you can set them by AutoPlacementPolicy.set_warmup_non_model_data_ratio() and AutoPlacementPolicy.set_steady_cuda_cap_ratio()
# you can set them by AutoPlacementPolicy.set_warmup_non_model_data_ratio()
# and AutoPlacementPolicy.set_steady_cuda_cap_ratio()
_warmup_non_model_data_ratio
:
float
=
0.8
_warmup_non_model_data_ratio
:
float
=
0.8
_steady_cuda_cap_ratio
:
float
=
0.9
_steady_cuda_cap_ratio
:
float
=
0.9
...
@@ -70,14 +71,14 @@ class AutoPlacementPolicy(PlacementPolicy):
...
@@ -70,14 +71,14 @@ class AutoPlacementPolicy(PlacementPolicy):
can_evict_chunks
:
List
[
Chunk
],
can_evict_chunks
:
List
[
Chunk
],
cuda_demand
:
int
=
0
,
cuda_demand
:
int
=
0
,
warmup
:
bool
=
True
,
warmup
:
bool
=
True
,
compute_list
:
List
[
Tuple
[
Chunk
,
...]]
=
[]
,
compute_list
:
Optional
[
List
[
Tuple
[
Chunk
,
...]]
]
=
None
,
compute_idx
:
int
=
0
,
compute_idx
:
int
=
0
,
**
kwargs
)
->
int
:
**
kwargs
)
->
Tuple
[
int
,
float
]
:
"""
"""
Evict tensors from CUDA device.
Evict tensors from CUDA device.
Args:
Args:
hold_cuda_tensor_list
(List[StatefulTensor]): the list of tensor
in state of HOLD-like
can_evict_chunks
(List[StatefulTensor]): the list of tensor
s that can be evicted.
cuda_demand (int, optional): the volume of data needed on cuda device. Defaults to 0.
cuda_demand (int, optional): the volume of data needed on cuda device. Defaults to 0.
warmup (bool, optional): a flag indicates whether in the phase of warmup. Defaults to True.
warmup (bool, optional): a flag indicates whether in the phase of warmup. Defaults to True.
compute_list (List[StatefulTensor], optional): TODO. Defaults to [].
compute_list (List[StatefulTensor], optional): TODO. Defaults to [].
...
@@ -114,12 +115,12 @@ class AutoPlacementPolicy(PlacementPolicy):
...
@@ -114,12 +115,12 @@ class AutoPlacementPolicy(PlacementPolicy):
for
chunk
in
to_free_chunks
:
for
chunk
in
to_free_chunks
:
if
freed_cuda_model_data
>=
to_free_cuda_model_data
:
if
freed_cuda_model_data
>=
to_free_cuda_model_data
:
break
break
freed_cuda_model_data
+=
chunk
.
mem
self
.
chunk_manager
.
move_chunk
(
chunk
,
torch
.
device
(
'cpu'
),
update_ptr
=
False
)
self
.
chunk_manager
.
move_chunk
(
chunk
,
torch
.
device
(
'cpu'
))
freed_cuda_model_data
+=
chunk
.
shard_mem
if
freed_cuda_model_data
<
to_free_cuda_model_data
:
if
freed_cuda_model_data
<
to_free_cuda_model_data
:
raise
RuntimeError
(
raise
RuntimeError
(
f
"Adjust layout failed! No enough CUDA memory! "
f
"Adjust layout failed! No enough CUDA memory! Need
{
to_free_cuda_model_data
}
, freed
{
freed_cuda_model_data
}
"
f
"Need
{
to_free_cuda_model_data
}
, freed
{
freed_cuda_model_data
}
"
)
)
return
freed_cuda_model_data
,
time
()
-
start
return
freed_cuda_model_data
,
time
()
-
start
@
staticmethod
@
staticmethod
...
@@ -147,7 +148,7 @@ class AutoPlacementPolicy(PlacementPolicy):
...
@@ -147,7 +148,7 @@ class AutoPlacementPolicy(PlacementPolicy):
class
PlacementPolicyFactory
:
class
PlacementPolicyFactory
:
policies
:
Dict
[
str
,
PlacementPolicy
]
=
{
policies
:
Dict
[
str
,
Type
[
PlacementPolicy
]
]
=
{
'cpu'
:
CPUPlacementPolicy
,
'cpu'
:
CPUPlacementPolicy
,
'cuda'
:
CUDAPlacementPolicy
,
'cuda'
:
CUDAPlacementPolicy
,
'auto'
:
AutoPlacementPolicy
'auto'
:
AutoPlacementPolicy
...
...
colossalai/gemini/stateful_tensor_container.py
deleted
100644 → 0
View file @
f9217336
import
queue
import
heapq
from
abc
import
ABC
,
abstractmethod
from
typing
import
Optional
,
List
,
Dict
from
colossalai.gemini.stateful_tensor
import
StatefulTensor
,
TensorState
def evict_check(st: StatefulTensor) -> bool:
    """Tell whether a stateful tensor is currently allowed to be evicted.

    Args:
        st (StatefulTensor): the tensor to inspect.

    Returns:
        bool: True when the tensor's state is not ``TensorState.COMPUTE`` and
        its device type is ``'cuda'``; False otherwise.
    """
    # a tensor in COMPUTE state is never evictable, whatever device it is on
    if st.state is TensorState.COMPUTE:
        return False
    # only tensors that still live on CUDA are worth evicting
    return st.device.type == 'cuda'
# Here ST means Stateful Tensor
class BaseSTContainer(ABC):
    """Abstract container of stateful tensors that are candidates for eviction from CUDA.

    A candidate has to satisfy two conditions: it has not been evicted yet
    (its device type is CUDA), and it is not pinned in CUDA memory (its state
    is not COMPUTE).

    The container should receive a stateful tensor when that tensor turns from
    COMPUTE into a HOLD-like state, and tensors are popped from it by the
    function `evict_tensors`.

    To obtain an optimal eviction policy, users may provide the computation
    step indices of every stateful tensor. A heap can then keep track of all
    evictable stateful tensors, so that popping yields the tensor whose next
    use is furthest away from the current computation step.

    Args:
        compute_step_dict (Dict[StatefulTensor, List[int]]): for each stateful
            tensor, the list of computation step indices stored for it.
        total_step (int): the total number of computation steps.
    """

    def __init__(self, compute_step_dict: Dict[StatefulTensor, List[int]], total_step: int):
        self.total_step = total_step
        self.compute_step_dict = compute_step_dict

    @abstractmethod
    def empty(self) -> bool:
        """Return whether the container currently holds no stateful tensor."""

    @abstractmethod
    def create(self, stateful_tensor_list: List[StatefulTensor]) -> None:
        """Initialize the container with the given stateful tensors."""

    @abstractmethod
    def push(self, stateful_tensor: StatefulTensor, cur_step: int) -> None:
        """Add one stateful tensor; `cur_step` is the current computation step."""

    @abstractmethod
    def pop(self) -> Optional[StatefulTensor]:
        """Return the next stateful tensor to evict, or None if there is none."""
class QueueSTContainer(BaseSTContainer):
    """FIFO container of stateful tensors, used by the 'cpu' tensor placement policy.

    Potentially evictable stateful tensors are handed out in the order in
    which they were inserted.
    """

    def __init__(self, compute_step_dict: Dict[StatefulTensor, List[int]], total_step: int):
        super().__init__(compute_step_dict, total_step)
        # the underlying queue is only allocated by `create`
        self.container = None

    def empty(self) -> bool:
        """Return whether the queue holds no tensor; `create` must have been called before."""
        assert self.container is not None
        return self.container.empty()

    def create(self, stateful_tensor_list: List[StatefulTensor]) -> None:
        """Replace the queue by a fresh one filled with `stateful_tensor_list`, in order."""
        fifo = queue.SimpleQueue()
        for st in stateful_tensor_list:
            fifo.put(st)
        self.container = fifo

    def push(self, stateful_tensor: StatefulTensor, cur_step: int) -> None:
        """Append `stateful_tensor` to the queue; `cur_step` is not used by this container."""
        self.container.put(stateful_tensor)

    def pop(self) -> Optional[StatefulTensor]:
        """Return the oldest queued tensor that passes `evict_check`.

        Tensors that fail the check are removed from the queue and dropped.
        Returns None once the queue is exhausted without finding a candidate.
        """
        while not self.empty():
            candidate = self.container.get()
            if evict_check(candidate):
                return candidate
        return None
class HeapSTContainer(BaseSTContainer):
    """Heap based stateful tensor container, used by the 'auto' tensor placement policy.

    Potentially evictable stateful tensors are popped in the order of the
    distance between the current step and their next compute step: the tensor
    whose next use is furthest away is popped first.

    Heap entries are ``(weight, insertion_index, stateful_tensor)``. The
    insertion index breaks ties between equal weights, so `heapq` never has to
    compare two stateful tensors with each other (which would raise a
    ``TypeError`` for objects that define no ordering).
    """

    def __init__(self, compute_step_dict: Dict[StatefulTensor, List[int]], total_step: int):
        super().__init__(compute_step_dict, total_step)
        # the heap list is only allocated by `create`
        self.container = None
        # number of entries created so far; used as tie-breaker inside heap entries
        self._entry_count = 0

    def empty(self) -> bool:
        """Return whether the heap holds no entry; `create` must have been called before."""
        assert self.container is not None
        return len(self.container) == 0

    def create(self, stateful_tensor_list: List[StatefulTensor]) -> None:
        """Rebuild the heap from `stateful_tensor_list`.

        The weights are computed with a current step of -1, i.e. every recorded
        compute step greater than -1 counts as a future step.
        """
        self._entry_count = 0
        self.container = [self.__make_entry(st, -1) for st in stateful_tensor_list]
        heapq.heapify(self.container)

    def push(self, stateful_tensor: StatefulTensor, cur_step: int) -> None:
        """Insert `stateful_tensor`, weighted by its next compute step after `cur_step`."""
        heapq.heappush(self.container, self.__make_entry(stateful_tensor, cur_step))

    def pop(self) -> Optional[StatefulTensor]:
        """Return the tensor with the greatest next compute step that passes `evict_check`.

        Entries whose tensor fails the check are removed from the heap and
        dropped. Returns None once the heap is exhausted without a candidate.
        """
        while not self.empty():
            out_tensor = heapq.heappop(self.container)[-1]
            if evict_check(out_tensor):
                return out_tensor
        return None

    def __make_entry(self, stateful_tensor: StatefulTensor, cur_step: int):
        """Build the heap entry for `stateful_tensor` as seen from `cur_step`."""
        # we want to pop the tensor which has the greatest next_step,
        # so the weight is next_step multiplied by -1 (heapq is a min-heap)
        weight = -self.__get_next_compute_step(stateful_tensor, cur_step)
        entry = (weight, self._entry_count, stateful_tensor)
        self._entry_count += 1
        return entry

    def __get_next_compute_step(self, stateful_tensor: StatefulTensor, cur_step: int):
        """Return the first recorded step of `stateful_tensor` that is greater than `cur_step`.

        If the tensor is not used in the future, `total_step` (the maximum) is
        returned instead.

        Raises:
            KeyError: if `stateful_tensor` is missing from `compute_step_dict`.
        """
        step_list = self.compute_step_dict[stateful_tensor]
        return next((step for step in step_list if step > cur_step), self.total_step)
colossalai/nn/parallel/data_parallel.py
View file @
5be118f4
...
@@ -3,16 +3,18 @@ import itertools
...
@@ -3,16 +3,18 @@ import itertools
import
torch.distributed
as
dist
import
torch.distributed
as
dist
from
functools
import
partial
from
functools
import
partial
from
colossalai.zero.utils.zero_hook_v2
import
ZeROHookV2
from
colossalai.zero.utils.zero_hook_v2
import
ZeROHookV2
from
colossalai.gemini.chunk
import
TensorState
,
Chunk
from
colossalai.tensor.param_op_hook
import
ParamOpHookManager
from
colossalai.tensor.param_op_hook
import
ParamOpHookManager
from
colossalai.gemini.gemini_mgr
import
GeminiManager
from
colossalai.gemini.gemini_mgr
import
GeminiManager
from
typing
import
Dict
,
Iterable
,
List
,
Optional
,
Set
from
typing
import
Dict
,
Iterable
,
List
,
Optional
,
Set
from
colossalai.logging
import
get_dist_logger
from
colossalai.logging
import
get_dist_logger
from
collections
import
OrderedDict
from
collections
import
OrderedDict
from
colossalai.tensor.colo_parameter
import
ColoParameter
from
colossalai.tensor.colo_parameter
import
ColoParameter
,
ColoTensor
,
ColoTensorSpec
from
colossalai.tensor
import
ProcessGroup
as
ColoProcessGroup
from
colossalai.tensor
import
ProcessGroup
as
ColoProcessGroup
from
.reducer
import
Reducer
from
.reducer
import
Reducer
from
colossalai.gemini.chunk
import
TensorState
,
Chunk
,
ChunkManager
from
colossalai.nn.parallel.utils
import
get_temp_total_chunk_on_cuda
try
:
try
:
from
torch.nn.modules.module
import
_EXTRA_STATE_KEY_SUFFIX
,
_IncompatibleKeys
from
torch.nn.modules.module
import
_EXTRA_STATE_KEY_SUFFIX
,
_IncompatibleKeys
except
ImportError
:
except
ImportError
:
...
@@ -208,28 +210,34 @@ class ZeroDDP(ColoDDP):
...
@@ -208,28 +210,34 @@ class ZeroDDP(ColoDDP):
def
__init__
(
self
,
def
__init__
(
self
,
module
:
torch
.
nn
.
Module
,
module
:
torch
.
nn
.
Module
,
gemini_manager
:
GeminiManager
,
gemini_manager
:
GeminiManager
,
pin_memory
:
bool
=
False
,
force_outputs_fp32
:
bool
=
False
)
->
None
:
force_outputs_fp32
:
bool
=
False
)
->
None
:
super
().
__init__
(
module
,
process_group
=
gemini_manager
.
chunk_manager
.
p
rocess
_g
roup
)
super
().
__init__
(
module
,
process_group
=
ColoP
rocess
G
roup
()
)
self
.
gemini_manager
=
gemini_manager
self
.
gemini_manager
=
gemini_manager
self
.
chunk_manager
=
gemini_manager
.
chunk_manager
self
.
chunk_manager
:
ChunkManager
=
gemini_manager
.
chunk_manager
self
.
force_outputs_fp32
=
force_outputs_fp32
self
.
force_outputs_fp32
=
force_outputs_fp32
self
.
param_op_hook
=
ZeROHookV2
(
gemini_manager
)
self
.
param_op_hook
=
ZeROHookV2
(
gemini_manager
)
self
.
fp32_params
:
List
[
Colo
Paramete
r
]
=
[]
self
.
fp32_params
:
List
[
Colo
Tenso
r
]
=
[]
self
.
overflow_counter
=
0
self
.
overflow_counter
=
0
self
.
grads_device
:
Dict
[
torch
.
Tensor
,
torch
.
device
]
=
{}
self
.
grads_device
:
Dict
[
torch
.
Tensor
,
torch
.
device
]
=
{}
self
.
chunk_manager
.
create_group
(
'fp16_param'
,
force_data_on_cuda
=
True
)
self
.
chunk_manager
.
create_group
(
'fp32_param'
)
# TODO: get param order and filter unused params
# TODO: get param order and filter unused params
for
p
in
module
.
parameters
():
for
p
in
module
.
parameters
():
assert
isinstance
(
p
,
ColoParameter
)
if
getattr
(
p
,
'_ddp_to_ignore'
,
False
):
if
getattr
(
p
,
'_ddp_to_ignore'
,
False
):
p
.
data
=
p
.
half
()
p
.
data
=
p
.
half
()
continue
continue
fp32_p
=
p
.
float
().
detach
()
dp_world_size
=
p
.
process_group
.
dp_world_size
()
fp32_data
=
p
.
float
().
data
p
.
data
=
p
.
half
()
p
.
data
=
p
.
half
()
self
.
chunk_manager
.
append_tensor
(
p
,
'fp16_param'
)
fp32_p
=
ColoTensor
(
fp32_data
,
spec
=
ColoTensorSpec
(
p
.
process_group
))
self
.
chunk_manager
.
append_tensor
(
fp32_p
,
'fp32_param'
)
self
.
chunk_manager
.
append_tensor
(
p
,
'fp16_param'
,
dp_world_size
,
pin_memory
)
self
.
chunk_manager
.
append_tensor
(
fp32_p
,
'fp32_param'
,
dp_world_size
,
pin_memory
)
self
.
fp32_params
.
append
(
fp32_p
)
self
.
fp32_params
.
append
(
fp32_p
)
self
.
grads_device
[
p
]
=
self
.
gemini_manager
.
default_device
self
.
grads_device
[
p
]
=
self
.
gemini_manager
.
default_device
self
.
chunk_manager
.
close_all_groups
()
self
.
_cast_buffers
()
self
.
_cast_buffers
()
self
.
_logger
=
get_dist_logger
()
self
.
_logger
=
get_dist_logger
()
...
@@ -248,10 +256,7 @@ class ZeroDDP(ColoDDP):
...
@@ -248,10 +256,7 @@ class ZeroDDP(ColoDDP):
for
p
in
self
.
module
.
parameters
():
for
p
in
self
.
module
.
parameters
():
if
getattr
(
p
,
'_ddp_to_ignore'
,
False
):
if
getattr
(
p
,
'_ddp_to_ignore'
,
False
):
continue
continue
if
self
.
chunk_manager
.
get_chunk
(
p
).
is_empty
or
not
p
.
requires_grad
:
p
.
grad
=
None
p
.
grad
=
None
else
:
p
.
grad
=
p
.
data
def
_post_backward
(
self
):
def
_post_backward
(
self
):
self
.
chunk_manager
.
exec_lazy_release
()
self
.
chunk_manager
.
exec_lazy_release
()
...
@@ -276,21 +281,22 @@ class ZeroDDP(ColoDDP):
...
@@ -276,21 +281,22 @@ class ZeroDDP(ColoDDP):
free_storage
(
empty_grad
)
free_storage
(
empty_grad
)
with
torch
.
_C
.
DisableTorchFunction
():
with
torch
.
_C
.
DisableTorchFunction
():
self
.
chunk_manager
.
trans_tensor_state
(
p
,
TensorState
.
READY_FOR_REDUCE
)
self
.
chunk_manager
.
trans_tensor_state
(
p
,
TensorState
.
READY_FOR_REDUCE
)
if
self
.
dp_world_size
>
1
:
grad
=
grad
/
self
.
dp_world_size
self
.
chunk_manager
.
copy_tensor_to_chunk_slice
(
p
,
grad
)
chunk
=
self
.
chunk_manager
.
get_chunk
(
p
)
chunk
=
self
.
chunk_manager
.
get_chunk
(
p
)
chunk
.
copy_tensor_to_chunk_slice
(
p
,
grad
)
reduced
=
self
.
chunk_manager
.
reduce_chunk
(
chunk
)
reduced
=
self
.
chunk_manager
.
reduce_chunk
(
chunk
)
self
.
chunk_manager
.
release_chunk
(
chunk
)
if
reduced
:
if
reduced
and
not
chunk
.
is_empty
:
if
chunk
.
is_gathered
:
chunk
.
chunk_total
.
div_
(
chunk
.
pg_size
)
else
:
chunk
.
cuda_shard
.
div_
(
chunk
.
pg_size
)
self
.
overflow_counter
+=
chunk
.
has_inf_or_nan
self
.
overflow_counter
+=
chunk
.
has_inf_or_nan
self
.
chunk_manager
.
move_chunk
(
chunk
,
self
.
grads_device
[
p
])
self
.
chunk_manager
.
move_chunk
(
chunk
,
self
.
grads_device
[
p
]
,
force_copy
=
True
)
return
empty_grad
return
empty_grad
def
zero_grad
(
self
,
set_to_none
:
bool
=
False
)
->
None
:
def
zero_grad
(
self
,
set_to_none
:
bool
=
False
)
->
None
:
self
.
module
.
zero_grad
(
set_to_none
=
True
)
self
.
module
.
zero_grad
(
set_to_none
=
True
)
def
_
set_chunk_grad_device
(
self
,
chunk
:
Chunk
,
device
:
torch
.
device
)
->
None
:
def
set_chunk_grad_device
(
self
,
chunk
:
Chunk
,
device
:
torch
.
device
)
->
None
:
for
tensor
in
chunk
.
get_tensors
():
for
tensor
in
chunk
.
get_tensors
():
self
.
grads_device
[
tensor
]
=
device
self
.
grads_device
[
tensor
]
=
device
...
@@ -311,14 +317,11 @@ class ZeroDDP(ColoDDP):
...
@@ -311,14 +317,11 @@ class ZeroDDP(ColoDDP):
['bias', 'weight']
['bias', 'weight']
"""
"""
is_rank_0
=
self
.
chunk_manager
.
process_group
.
dp_local_rank
()
==
0
record_flag
=
(
not
only_rank_0
)
or
is_rank_0
if
destination
is
None
:
if
destination
is
None
:
destination
=
OrderedDict
()
destination
=
OrderedDict
()
destination
.
_metadata
=
OrderedDict
()
destination
.
_metadata
=
OrderedDict
()
destination
.
_metadata
[
prefix
[:
-
1
]]
=
local_metadata
=
dict
(
version
=
self
.
_version
)
destination
.
_metadata
[
prefix
[:
-
1
]]
=
local_metadata
=
dict
(
version
=
self
.
_version
)
self
.
_save_to_state_dict
(
destination
,
prefix
,
keep_vars
,
record_flag
)
self
.
_save_to_state_dict
(
destination
,
prefix
,
keep_vars
,
only_rank_0
)
for
hook
in
self
.
_state_dict_hooks
.
values
():
for
hook
in
self
.
_state_dict_hooks
.
values
():
hook_result
=
hook
(
self
,
destination
,
prefix
,
local_metadata
)
hook_result
=
hook
(
self
,
destination
,
prefix
,
local_metadata
)
...
@@ -326,7 +329,7 @@ class ZeroDDP(ColoDDP):
...
@@ -326,7 +329,7 @@ class ZeroDDP(ColoDDP):
destination
=
hook_result
destination
=
hook_result
return
destination
return
destination
def
_save_to_state_dict
(
self
,
destination
,
prefix
,
keep_vars
,
record_flag
:
bool
=
True
):
def
_save_to_state_dict
(
self
,
destination
,
prefix
,
keep_vars
,
only_rank_0
=
True
):
r
"""Saves module state to `destination` dictionary, containing a state
r
"""Saves module state to `destination` dictionary, containing a state
of the module, but not its descendants. This is called on every
of the module, but not its descendants. This is called on every
submodule in :meth:`~torch.nn.Module.state_dict`.
submodule in :meth:`~torch.nn.Module.state_dict`.
...
@@ -339,30 +342,30 @@ class ZeroDDP(ColoDDP):
...
@@ -339,30 +342,30 @@ class ZeroDDP(ColoDDP):
prefix (str): the prefix for parameters and buffers used in this
prefix (str): the prefix for parameters and buffers used in this
module
module
"""
"""
assert
keep_vars
is
False
,
"`state_dict` with parameter, `keep_vars=True`, is not supported now."
# save parameters
# save parameters
param_to_save_data
=
dict
()
param_to_save_data
=
dict
()
chunk_list
=
self
.
chunk_manager
.
get_chunks
(
self
.
fp32_params
)
chunk_list
=
self
.
chunk_manager
.
get_chunks
(
self
.
fp32_params
)
for
chunk
in
chunk_list
:
for
chunk
in
chunk_list
:
# record the original device of the chunk
temp_chunk
=
get_temp_total_chunk_on_cuda
(
chunk
)
org_chunk_dev_typ
=
chunk
.
device_type
self
.
chunk_manager
.
access_chunk
(
chunk
)
for
tensor
in
chunk
.
get_tensors
():
for
tensor
,
tensor_info
in
chunk
.
tensors_info
.
items
():
rec_p
=
torch
.
empty
([
0
])
record_tensor
=
torch
.
empty
([
0
])
record_flag
=
(
not
only_rank_0
)
|
(
dist
.
get_rank
(
chunk
.
torch_pg
)
==
0
)
if
record_flag
:
if
record_flag
:
rec_p
=
tensor
.
cpu
()
# move the whole tensor to CPU mem
record_tensor
=
temp_chunk
[
tensor_info
.
offset
:
tensor_info
.
end
].
view
(
tensor
.
shape
).
cpu
()
assert
tensor
not
in
param_to_save_data
assert
tensor
not
in
param_to_save_data
param_to_save_data
[
tensor
]
=
rec_p
param_to_save_data
[
tensor
]
=
record_tensor
# release the actual memory of the chunk
self
.
chunk_manager
.
release_chunk
(
chunk
)
del
temp_chunk
if
not
chunk
.
is_empty
and
org_chunk_dev_typ
==
'cpu'
:
self
.
chunk_manager
.
move_chunk
(
chunk
,
torch
.
device
(
'cpu'
))
for
(
name
,
p
),
fp32_p
in
zip
(
self
.
named_parameters
(),
self
.
fp32_params
):
for
(
name
,
p
),
fp32_p
in
zip
(
self
.
named_parameters
(),
self
.
fp32_params
):
if
p
is
not
None
:
if
p
is
not
None
:
assert
fp32_p
in
param_to_save_data
,
"Parameter '{}' is neglected in the chunk list"
.
format
(
name
)
assert
fp32_p
in
param_to_save_data
,
"Parameter '{}' is neglected in the chunk list"
.
format
(
name
)
rec
_p
=
param_to_save_data
[
fp32_p
]
rec
ord_parameter
=
param_to_save_data
[
fp32_p
]
destination
[
prefix
+
name
]
=
rec
_p
if
keep_vars
else
rec_p
.
detach
()
destination
[
prefix
+
name
]
=
rec
ord_parameter
# save all buffers
# save all buffers
for
name
,
buf
in
self
.
named_buffers
():
for
name
,
buf
in
self
.
named_buffers
():
...
@@ -466,40 +469,61 @@ class ZeroDDP(ColoDDP):
...
@@ -466,40 +469,61 @@ class ZeroDDP(ColoDDP):
local_name_params
=
itertools
.
chain
(
self
.
named_parameters
(),
persistent_buffers
.
items
())
local_name_params
=
itertools
.
chain
(
self
.
named_parameters
(),
persistent_buffers
.
items
())
local_state
=
{
k
:
v
for
k
,
v
in
local_name_params
if
v
is
not
None
}
local_state
=
{
k
:
v
for
k
,
v
in
local_name_params
if
v
is
not
None
}
def
load
(
name
,
dest_tensor
,
copy_func
):
def
load
(
param_
name
,
dest_tensor
,
copy_func
):
key
=
prefix
+
name
state_
key
=
prefix
+
param_
name
if
key
in
state_dict
:
if
state_
key
in
state_dict
:
input_param
=
state_dict
[
key
]
input_param
=
state_dict
[
state_
key
]
# Backward compatibility: loading 1-dim tensor from 0.3.* to version 0.4+
# Backward compatibility: loading 1-dim tensor from 0.3.* to version 0.4+
if
len
(
dest_tensor
.
shape
)
==
0
and
len
(
input_param
.
shape
)
==
1
:
if
len
(
dest_tensor
.
shape
)
==
0
and
len
(
input_param
.
shape
)
==
1
:
input_param
=
input_param
[
0
]
input_param
=
input_param
[
0
]
if
input_param
.
shape
!=
dest_tensor
.
shape
:
if
input_param
.
shape
!=
dest_tensor
.
shape
:
# local shape should match the one in checkpoint
# local shape should match the one in checkpoint
error_msgs
.
append
(
'size mismatch for {}: copying a param with shape {} from checkpoint, '
error_msgs
.
append
(
'size mismatch for {}: copying a param with shape {} from checkpoint, '
'the shape in current model is {}.'
.
format
(
key
,
input_param
.
shape
,
'the shape in current model is {}.'
.
format
(
state_
key
,
input_param
.
shape
,
dest_tensor
.
shape
))
dest_tensor
.
shape
))
return
return
try
:
try
:
with
torch
.
no_grad
():
with
torch
.
no_grad
():
# self.chunk_manager.copy_tensor_to_chunk_slice(fp32_p, input_param)
copy_func
(
input_param
)
copy_func
(
input_param
)
except
Exception
as
ex
:
except
Exception
as
ex
:
error_msgs
.
append
(
'While copying the parameter named "{}", '
error_msgs
.
append
(
'While copying the parameter named "{}", '
'whose dimensions in the model are {} and '
'whose dimensions in the model are {} and '
'whose dimensions in the checkpoint are {}, '
'whose dimensions in the checkpoint are {}, '
'an exception occurred : {}.'
.
format
(
key
,
dest_tensor
.
size
(),
input_param
.
size
(),
'an exception occurred : {}.'
.
format
(
state_
key
,
dest_tensor
.
size
(),
ex
.
args
))
input_param
.
size
(),
ex
.
args
))
elif
strict
:
elif
strict
:
missing_keys
.
append
(
key
)
missing_keys
.
append
(
state_
key
)
def
load_fp32_p
(
fp32_p
,
data
):
def
load_fp32_parameter
(
chunk_slice
,
data
):
if
fp32_p
.
storage
().
size
()
>
0
:
chunk_slice
.
copy_
(
data
.
flatten
())
self
.
chunk_manager
.
copy_tensor_to_chunk_slice
(
fp32_p
,
data
)
fp32_to_name
=
dict
()
for
(
name
,
p
),
fp32_p
in
zip
(
self
.
named_parameters
(),
self
.
fp32_params
):
for
(
name
,
p
),
fp32_p
in
zip
(
self
.
named_parameters
(),
self
.
fp32_params
):
if
p
is
not
None
:
if
p
is
not
None
:
load
(
name
,
fp32_p
,
partial
(
load_fp32_p
,
fp32_p
))
fp32_to_name
[
fp32_p
]
=
name
self
.
chunk_manager
.
copy_chunk_group
(
'fp16_param'
,
'fp32_param'
)
chunk_list
=
self
.
chunk_manager
.
get_chunks
(
self
.
fp32_params
)
for
chunk
in
chunk_list
:
temp_chunk
=
get_temp_total_chunk_on_cuda
(
chunk
)
for
tensor
,
tensor_info
in
chunk
.
tensors_info
.
items
():
parameter_name
=
fp32_to_name
[
tensor
]
parameter_slice
=
temp_chunk
[
tensor_info
.
offset
:
tensor_info
.
end
]
load
(
parameter_name
,
tensor
,
partial
(
load_fp32_parameter
,
parameter_slice
))
if
chunk
.
is_gathered
:
chunk
.
chunk_total
.
copy_
(
temp_chunk
)
elif
chunk
.
cuda_shard
is
not
None
:
chunk
.
cuda_shard
.
copy_
(
temp_chunk
[
chunk
.
shard_begin
:
chunk
.
shard_end
])
else
:
chunk
.
cpu_shard
.
copy_
(
temp_chunk
[
chunk
.
shard_begin
:
chunk
.
shard_end
])
del
temp_chunk
for
chunk_32
in
chunk_list
:
chunk_16
=
chunk_32
.
paired_chunk
assert
chunk_16
is
not
None
chunk_16
.
optim_update
()
for
name
,
buf
in
persistent_buffers
.
items
():
for
name
,
buf
in
persistent_buffers
.
items
():
if
buf
is
not
None
:
if
buf
is
not
None
:
...
...
colossalai/nn/parallel/utils.py
0 → 100644
View file @
5be118f4
import
torch
import
torch.distributed
as
dist
from
colossalai.gemini.chunk
import
Chunk
from
colossalai.utils
import
get_current_device
def get_temp_total_chunk_on_cuda(chunk: Chunk):
    """Return a tensor holding the whole content of `chunk`.

    When the chunk is already gathered, its own `chunk_total` tensor is
    returned as is (no copy is made). Otherwise a new buffer of `chunk_size`
    elements is allocated on the device reported by `get_current_device()`
    and filled by all-gathering the local shard over `chunk.torch_pg`.

    Args:
        chunk (Chunk): the chunk whose full content is requested.

    Returns:
        The full chunk tensor: `chunk.chunk_total` or the newly gathered buffer.
    """
    if chunk.is_gathered:
        return chunk.chunk_total

    # take the shard kept on CUDA if there is one, otherwise move the CPU shard over
    local_shard = chunk.cuda_shard
    if local_shard is None:
        local_shard = chunk.cpu_shard.to(get_current_device())

    full_buffer = torch.zeros(chunk.chunk_size, dtype=chunk.dtype, device=get_current_device())
    # torch.chunk yields views of full_buffer, one per process in the group,
    # so all_gather writes every shard directly into full_buffer
    shard_slots = list(torch.chunk(input=full_buffer, chunks=chunk.pg_size, dim=0))
    dist.all_gather(tensor_list=shard_slots, tensor=local_shard, group=chunk.torch_pg)
    return full_buffer
colossalai/zero/utils/zero_hook_v2.py
View file @
5be118f4
...
@@ -54,8 +54,8 @@ class ZeROHookV2(ParamOpHook):
...
@@ -54,8 +54,8 @@ class ZeROHookV2(ParamOpHook):
@
contextmanager
@
contextmanager
def
switch_training_phase
(
self
,
training_phase
:
TrainingPhase
=
TrainingPhase
.
BACKWARD
):
def
switch_training_phase
(
self
,
training_phase
:
TrainingPhase
=
TrainingPhase
.
BACKWARD
):
old_training_phase
=
self
.
_training_phase
try
:
try
:
old_training_phase
=
self
.
_training_phase
self
.
_training_phase
=
training_phase
self
.
_training_phase
=
training_phase
yield
yield
finally
:
finally
:
...
...
colossalai/zero/zero_optimizer.py
View file @
5be118f4
...
@@ -2,17 +2,14 @@ import torch
...
@@ -2,17 +2,14 @@ import torch
import
torch.distributed
as
dist
import
torch.distributed
as
dist
from
enum
import
Enum
from
enum
import
Enum
from
torch.optim
import
Optimizer
from
torch.optim
import
Optimizer
from
torch.nn
import
Parameter
from
colossalai.nn.parallel.data_parallel
import
ZeroDDP
from
colossalai.nn.parallel.data_parallel
import
ZeroDDP
from
typing
import
Dict
from
typing
import
Dict
,
Tuple
,
Set
from
colossalai.amp.naive_amp.grad_scaler
import
DynamicGradScaler
from
colossalai.amp.naive_amp.grad_scaler
import
DynamicGradScaler
from
colossalai.logging
import
get_dist_logger
from
colossalai.logging
import
get_dist_logger
from
colossalai.nn.optimizer
import
ColossalaiOptimizer
from
colossalai.nn.optimizer
import
ColossalaiOptimizer
from
colossalai.utils
import
get_current_device
,
disposable
from
colossalai.utils
import
get_current_device
,
disposable
from
colossalai.utils.common
import
_compute_grad_lp
,
compute_grad_norm
,
_clip_grad_norm
from
colossalai.gemini.chunk
import
Chunk
,
ChunkManager
from
collections
import
defaultdict
,
abc
as
container_abcs
from
copy
import
deepcopy
from
itertools
import
chain
from
torch._six
import
inf
class
OptimState
(
Enum
):
class
OptimState
(
Enum
):
...
@@ -33,8 +30,8 @@ class ZeroOptimizer(ColossalaiOptimizer):
...
@@ -33,8 +30,8 @@ class ZeroOptimizer(ColossalaiOptimizer):
Args:
Args:
optim (Optimizer): An Optimizer instance.
optim (Optimizer): An Optimizer instance.
module (ZeroDDP): A ``ZeroDDP`` instance.
module (ZeroDDP): A ``ZeroDDP`` instance.
gpu_margin_mem_ratio (float, optional): The ratio of GPU remaining memory (after the first forward-backward)
gpu_margin_mem_ratio (float, optional): The ratio of GPU remaining memory (after the first forward-backward)
which will be used when using hybrid CPU optimizer.
which will be used when using hybrid CPU optimizer.
This argument is meaningless when `placement_policy` of `GeminiManager` is not "auto".
This argument is meaningless when `placement_policy` of `GeminiManager` is not "auto".
Defaults to 0.0.
Defaults to 0.0.
initial_scale (float, optional): Initial scale used by DynamicGradScaler. Defaults to 2**32.
initial_scale (float, optional): Initial scale used by DynamicGradScaler. Defaults to 2**32.
...
@@ -61,11 +58,20 @@ class ZeroOptimizer(ColossalaiOptimizer):
...
@@ -61,11 +58,20 @@ class ZeroOptimizer(ColossalaiOptimizer):
assert
isinstance
(
module
,
ZeroDDP
)
assert
isinstance
(
module
,
ZeroDDP
)
self
.
module
=
module
self
.
module
=
module
self
.
gemini_manager
=
module
.
gemini_manager
self
.
gemini_manager
=
module
.
gemini_manager
self
.
chunk_manager
=
self
.
gemini_manager
.
chunk_manager
self
.
chunk_manager
:
ChunkManager
=
self
.
gemini_manager
.
chunk_manager
self
.
optim_state
=
OptimState
.
UNSCALED
self
.
optim_state
=
OptimState
.
UNSCALED
self
.
fp16_param_to_fp32_param
:
Dict
[
torch
.
Tensor
,
torch
.
Tensor
]
=
{}
self
.
param_to_range
:
Dict
[
Parameter
,
Tuple
[
int
,
int
]]
=
dict
()
self
.
param_to_chunk32
:
Dict
[
Parameter
,
Chunk
]
=
dict
()
self
.
chunk16_set
:
Set
[
Chunk
]
=
set
()
for
p
,
fp32_p
in
zip
(
module
.
parameters
(),
module
.
fp32_params
):
for
p
,
fp32_p
in
zip
(
module
.
parameters
(),
module
.
fp32_params
):
self
.
fp16_param_to_fp32_param
[
p
]
=
fp32_p
chunk_16
=
self
.
chunk_manager
.
get_chunk
(
p
)
chunk_32
=
self
.
chunk_manager
.
get_chunk
(
fp32_p
)
chunk_32
.
init_pair
(
chunk_16
)
if
chunk_16
not
in
self
.
chunk16_set
:
self
.
chunk16_set
.
add
(
chunk_16
)
self
.
__init__optimizer
()
# Grad scaler
# Grad scaler
self
.
grad_scaler
=
DynamicGradScaler
(
initial_scale
=
initial_scale
,
self
.
grad_scaler
=
DynamicGradScaler
(
initial_scale
=
initial_scale
,
...
@@ -75,7 +81,7 @@ class ZeroOptimizer(ColossalaiOptimizer):
...
@@ -75,7 +81,7 @@ class ZeroOptimizer(ColossalaiOptimizer):
growth_interval
=
growth_interval
,
growth_interval
=
growth_interval
,
hysteresis
=
hysteresis
,
hysteresis
=
hysteresis
,
max_scale
=
max_scale
)
max_scale
=
max_scale
)
self
.
_found_overflow
:
torch
.
Tensor
=
torch
.
zeros
(
1
,
dtype
=
torch
.
int64
,
device
=
torch
.
cuda
.
current_device
())
self
.
_found_overflow
:
torch
.
Tensor
=
torch
.
zeros
(
1
,
dtype
=
torch
.
int64
,
device
=
get_
current_device
())
self
.
_logger
=
get_dist_logger
()
self
.
_logger
=
get_dist_logger
()
self
.
gpu_margin_mem_ratio
:
float
=
float
(
gpu_margin_mem_ratio
)
self
.
gpu_margin_mem_ratio
:
float
=
float
(
gpu_margin_mem_ratio
)
...
@@ -90,16 +96,26 @@ class ZeroOptimizer(ColossalaiOptimizer):
...
@@ -90,16 +96,26 @@ class ZeroOptimizer(ColossalaiOptimizer):
self
.
_register_states
=
disposable
(
self
.
_register_states_
)
self
.
_register_states
=
disposable
(
self
.
_register_states_
)
def
_update_params_ptr
(
self
):
def
_set_grad_ptr
(
self
):
for
group
in
self
.
optim
.
param_groups
:
for
group
in
self
.
param_groups
:
for
p
in
group
[
'params'
]:
for
fake_param
in
group
[
'params'
]:
if
not
self
.
module
.
chunk_manager
.
get_chunk
(
p
).
is_empty
:
chunk32
=
self
.
param_to_chunk32
[
fake_param
]
p
.
data
=
self
.
fp16_param_to_fp32_param
[
p
]
begin
,
end
=
self
.
param_to_range
[
fake_param
]
else
:
chunk16
=
chunk32
.
paired_chunk
assert
p
.
grad
is
None
fake_param
.
data
=
chunk16
.
payload
[
begin
:
end
]
fake_param
.
grad
=
fake_param
.
data
fake_param
.
data
=
chunk32
.
payload
[
begin
:
end
]
def
_update_fp16_params
(
self
):
def
_update_fp16_params
(
self
):
self
.
module
.
chunk_manager
.
copy_chunk_group
(
'fp16_param'
,
'fp32_param'
)
none_tensor
=
torch
.
empty
([
0
])
for
group
in
self
.
param_groups
:
for
fake_param
in
group
[
'params'
]:
assert
fake_param
.
grad
is
None
fake_param
.
data
=
none_tensor
for
chunk16
in
self
.
chunk16_set
:
chunk16
.
optim_update
()
def
_check_overflow
(
self
):
def
_check_overflow
(
self
):
# clear previous overflow record
# clear previous overflow record
...
@@ -128,6 +144,7 @@ class ZeroOptimizer(ColossalaiOptimizer):
...
@@ -128,6 +144,7 @@ class ZeroOptimizer(ColossalaiOptimizer):
def
step
(
self
,
*
args
,
**
kwargs
):
def
step
(
self
,
*
args
,
**
kwargs
):
self
.
_maybe_move_fp32_params
()
self
.
_maybe_move_fp32_params
()
self
.
_set_grad_ptr
()
# unscale grads if scaled
# unscale grads if scaled
if
self
.
optim_state
==
OptimState
.
SCALED
:
if
self
.
optim_state
==
OptimState
.
SCALED
:
self
.
_unscale_grads
()
self
.
_unscale_grads
()
...
@@ -138,45 +155,14 @@ class ZeroOptimizer(ColossalaiOptimizer):
...
@@ -138,45 +155,14 @@ class ZeroOptimizer(ColossalaiOptimizer):
self
.
zero_grad
()
self
.
zero_grad
()
self
.
_update_fp16_params
()
self
.
_update_fp16_params
()
return
return
self
.
_update_params_ptr
()
ret
=
self
.
optim
.
step
(
*
args
,
**
kwargs
)
ret
=
self
.
optim
.
step
(
*
args
,
**
kwargs
)
self
.
_register_states
()
self
.
_register_states
()
self
.
zero_grad
()
self
.
zero_grad
()
self
.
_update_fp16_params
()
self
.
_update_fp16_params
()
return
ret
return
ret
def
compute_grad_norm
(
self
,
norm_type
:
float
=
2.0
)
->
float
:
norm_type
=
float
(
norm_type
)
if
not
self
.
chunk_manager
.
enable_distributed_storage
:
return
compute_grad_norm
(
self
.
module
.
parameters
(),
norm_type
)
non_distributed_params
=
[]
distributed_params
=
[]
for
p
in
self
.
module
.
parameters
():
if
getattr
(
p
,
'_ddp_to_ignore'
,
False
):
non_distributed_params
.
append
(
p
)
else
:
distributed_params
.
append
(
p
)
non_distributed_norm
=
_compute_grad_lp
(
non_distributed_params
,
norm_type
)
distributed_norm_tensor
=
torch
.
tensor
([
_compute_grad_lp
(
distributed_params
,
norm_type
)],
device
=
get_current_device
())
if
norm_type
==
inf
:
dist
.
all_reduce
(
distributed_norm_tensor
,
op
=
dist
.
ReduceOp
.
MAX
,
group
=
self
.
chunk_manager
.
process_group
.
dp_process_group
())
total_norm
=
max
(
non_distributed_norm
,
distributed_norm_tensor
.
item
())
else
:
dist
.
all_reduce
(
distributed_norm_tensor
,
group
=
self
.
chunk_manager
.
process_group
.
dp_process_group
())
total_norm
=
non_distributed_norm
+
distributed_norm_tensor
.
item
()
total_norm
=
total_norm
**
(
1
/
norm_type
)
return
total_norm
def
clip_grad_norm
(
self
,
model
:
torch
.
nn
.
Module
,
max_norm
:
float
,
norm_type
:
float
=
2.0
):
def
clip_grad_norm
(
self
,
model
:
torch
.
nn
.
Module
,
max_norm
:
float
,
norm_type
:
float
=
2.0
):
if
self
.
optim_state
==
OptimState
.
SCALED
:
raise
NotImplementedError
self
.
_unscale_grads
()
total_norm
=
self
.
compute_grad_norm
(
norm_type
)
_clip_grad_norm
(
self
.
module
.
parameters
(),
max_norm
,
total_norm
)
return
total_norm
def
backward
(
self
,
loss
:
torch
.
Tensor
):
def
backward
(
self
,
loss
:
torch
.
Tensor
):
loss
=
self
.
loss_scale
*
loss
loss
=
self
.
loss_scale
*
loss
...
@@ -197,24 +183,31 @@ class ZeroOptimizer(ColossalaiOptimizer):
...
@@ -197,24 +183,31 @@ class ZeroOptimizer(ColossalaiOptimizer):
available_cuda_margin_mem
=
self
.
gemini_manager
.
cuda_margin_mem
*
self
.
gpu_margin_mem_ratio
available_cuda_margin_mem
=
self
.
gemini_manager
.
cuda_margin_mem
*
self
.
gpu_margin_mem_ratio
fp32_params_available_cuda_margin_mem
=
available_cuda_margin_mem
/
self
.
optim
.
num_fp32_shards_per_param
fp32_params_available_cuda_margin_mem
=
available_cuda_margin_mem
/
self
.
optim
.
num_fp32_shards_per_param
fp32_params_used_cuda_margin_mem
=
0
fp32_params_used_cuda_margin_mem
=
0
for
fp16_param_chunk
,
fp32_param_chunk
in
zip
(
self
.
chunk_manager
.
chunk_groups
[
'fp16_param'
],
self
.
chunk_manager
.
chunk_groups
[
'fp32_param'
]):
for
group
in
self
.
param_groups
:
if
fp32_param_chunk
.
is_empty
:
for
fake_param
in
group
[
'params'
]:
continue
chunk32
=
self
.
param_to_chunk32
[
fake_param
]
if
fp32_params_used_cuda_margin_mem
+
fp32_param_chunk
.
mem
<
fp32_params_available_cuda_margin_mem
:
chunk16
=
chunk32
.
paired_chunk
self
.
chunk_manager
.
move_chunk
(
fp32_param_chunk
,
get_current_device
())
# stores grad now
if
chunk32
.
device_type
==
'cuda'
:
self
.
chunk_manager
.
move_chunk
(
fp16_param_chunk
,
get_current_device
())
continue
self
.
module
.
_set_chunk_grad_device
(
fp16_param_chunk
,
get_current_device
())
fp32_params_used_cuda_margin_mem
+=
fp32_param_chunk
.
mem
if
fp32_params_used_cuda_margin_mem
+
chunk32
.
payload_mem
<
fp32_params_available_cuda_margin_mem
:
for
p
in
fp16_param_chunk
.
get_tensors
():
self
.
chunk_manager
.
move_chunk
(
chunk32
,
get_current_device
())
state
=
self
.
optim
.
state
[
p
]
# stores grad now
self
.
chunk_manager
.
move_chunk
(
chunk16
,
get_current_device
())
self
.
module
.
set_chunk_grad_device
(
chunk16
,
get_current_device
())
fp32_params_used_cuda_margin_mem
+=
chunk32
.
payload_mem
for
group
in
self
.
param_groups
:
for
fake_param
in
group
[
'params'
]:
chunk32
=
self
.
param_to_chunk32
[
fake_param
]
if
chunk32
.
device_type
==
'cuda'
:
state
=
self
.
optim
.
state
[
fake_param
]
for
k
,
v
in
state
.
items
():
for
k
,
v
in
state
.
items
():
if
isinstance
(
v
,
torch
.
Tensor
):
if
isinstance
(
v
,
torch
.
Tensor
):
state
[
k
]
=
v
.
to
(
get_current_device
())
state
[
k
]
=
v
.
to
(
get_current_device
())
self
.
module
.
_setup_grads_ptr
()
def
_register_states_
(
self
):
def
_register_states_
(
self
):
for
group
in
self
.
optim
.
param_groups
:
for
group
in
self
.
optim
.
param_groups
:
for
p
in
group
[
'params'
]:
for
p
in
group
[
'params'
]:
...
@@ -223,110 +216,27 @@ class ZeroOptimizer(ColossalaiOptimizer):
...
@@ -223,110 +216,27 @@ class ZeroOptimizer(ColossalaiOptimizer):
if
isinstance
(
val
,
torch
.
Tensor
):
if
isinstance
(
val
,
torch
.
Tensor
):
self
.
chunk_manager
.
add_extern_static_tensor
(
val
)
self
.
chunk_manager
.
add_extern_static_tensor
(
val
)
def
state_dict
(
self
,
only_rank_0
:
bool
=
True
):
def
__init__optimizer
(
self
):
r
"""Returns the state of the optimizer as a :class:`dict`. If only_rank_0 is True, for DP rank != 0, this function returns None.
This saves memory usage.
It contains two entries:
def
get_range_pair
(
local_chunk
:
Chunk
,
local_param
:
Parameter
):
param_info
=
local_chunk
.
tensors_info
[
local_param
]
begin
=
max
(
0
,
param_info
.
offset
-
local_chunk
.
shard_begin
)
end
=
min
(
local_chunk
.
shard_size
,
param_info
.
end
-
local_chunk
.
shard_begin
)
return
begin
,
end
* state - a dict holding current optimization state. Its content
for
group
in
self
.
optim
.
param_groups
:
differs between optimizer classes.
fake_params_list
=
list
()
* param_groups - a list containing all parameter groups where each
parameter group is a dict
for
param
in
group
[
'params'
]:
"""
chunk16
=
self
.
chunk_manager
.
get_chunk
(
param
)
is_rank_0
=
self
.
chunk_manager
.
process_group
.
dp_local_rank
()
==
0
range_pair
=
get_range_pair
(
chunk16
,
param
)
if
not
self
.
chunk_manager
.
enable_distributed_storage
and
only_rank_0
and
not
is_rank_0
:
if
range_pair
[
0
]
>=
range_pair
[
1
]:
return
continue
optim_state_dict
=
super
().
state_dict
()
scaler_state_dict
=
self
.
grad_scaler
.
state_dict
()
fake_param
=
torch
.
nn
.
Parameter
(
torch
.
empty
([
0
]))
optim_state_dict
[
'scaler'
]
=
scaler_state_dict
self
.
param_to_chunk32
[
fake_param
]
=
chunk16
.
paired_chunk
if
not
self
.
chunk_manager
.
enable_distributed_storage
:
self
.
param_to_range
[
fake_param
]
=
range_pair
return
optim_state_dict
local_state
=
{
k
:
convert_state_dict_to_cpu
(
v
)
for
k
,
v
in
optim_state_dict
[
'state'
].
items
()
if
len
(
v
)
>
0
}
fake_params_list
.
append
(
fake_param
)
if
not
self
.
chunk_manager
.
process_group
.
has_cpu_groups
:
self
.
chunk_manager
.
process_group
.
set_cpu_groups
()
group
[
'params'
]
=
fake_params_list
output
=
[
None
for
_
in
range
(
self
.
chunk_manager
.
process_group
.
dp_world_size
())]
if
only_rank_0
:
dst_rank
=
self
.
chunk_manager
.
process_group
.
dp_rank_list
()[
0
]
dist
.
gather_object
(
local_state
,
output
if
self
.
chunk_manager
.
process_group
.
dp_local_rank
()
==
0
else
None
,
dst
=
dst_rank
,
group
=
self
.
chunk_manager
.
process_group
.
cpu_dp_process_group
())
if
not
is_rank_0
:
return
else
:
dist
.
all_gather_object
(
output
,
local_state
,
group
=
self
.
chunk_manager
.
process_group
.
cpu_dp_process_group
())
for
state
in
output
:
optim_state_dict
[
'state'
].
update
(
state
)
return
optim_state_dict
def
load_state_dict
(
self
,
state_dict
):
r
"""Loads the optimizer state.
Args:
state_dict (dict): optimizer state. Should be an object returned
from a call to :meth:`state_dict`.
"""
if
'scaler'
not
in
state_dict
:
self
.
_logger
.
warning
(
'Missing scaler when loading optimizer state dict'
,
ranks
=
[
0
])
else
:
self
.
grad_scaler
.
load_state_dict
(
deepcopy
(
state_dict
[
'scaler'
]))
# Validate the state_dict
groups
=
self
.
param_groups
saved_groups
=
deepcopy
(
state_dict
[
'param_groups'
])
if
len
(
groups
)
!=
len
(
saved_groups
):
raise
ValueError
(
"loaded state dict has a different number of "
"parameter groups"
)
param_lens
=
(
len
(
g
[
'params'
])
for
g
in
groups
)
saved_lens
=
(
len
(
g
[
'params'
])
for
g
in
saved_groups
)
if
any
(
p_len
!=
s_len
for
p_len
,
s_len
in
zip
(
param_lens
,
saved_lens
)):
raise
ValueError
(
"loaded state dict contains a parameter group "
"that doesn't match the size of optimizer's group"
)
# Update the state
id_map
=
{
old_id
:
p
for
old_id
,
p
in
zip
(
chain
.
from_iterable
((
g
[
'params'
]
for
g
in
saved_groups
)),
chain
.
from_iterable
((
g
[
'params'
]
for
g
in
groups
)))
}
def
cast
(
param
,
value
):
r
"""Make a deep copy of value, casting all tensors to device of param."""
if
isinstance
(
value
,
torch
.
Tensor
):
# Floating-point types are a bit special here. They are the only ones
# that are assumed to always match the type of params.
if
param
.
is_floating_point
():
value
=
value
.
to
(
param
.
dtype
)
value
=
value
.
to
(
param
.
device
)
return
value
elif
isinstance
(
value
,
dict
):
return
{
k
:
cast
(
param
,
v
)
for
k
,
v
in
value
.
items
()}
elif
isinstance
(
value
,
container_abcs
.
Iterable
):
return
type
(
value
)(
cast
(
param
,
v
)
for
v
in
value
)
else
:
return
value
# Copy state assigned to params (and cast tensors to appropriate types).
# State that is not assigned to params is copied as is (needed for
# backward compatibility).
state
=
defaultdict
(
dict
)
for
k
,
v
in
state_dict
[
'state'
].
items
():
if
k
in
id_map
:
param
=
self
.
fp16_param_to_fp32_param
[
id_map
[
k
]]
if
param
.
storage
().
size
()
>
0
:
state
[
param
]
=
cast
(
param
,
deepcopy
(
v
))
else
:
state
[
k
]
=
deepcopy
(
v
)
# Update parameter groups, setting their 'params' value
def
update_group
(
group
,
new_group
):
new_group
[
'params'
]
=
group
[
'params'
]
return
new_group
param_groups
=
[
update_group
(
g
,
ng
)
for
g
,
ng
in
zip
(
groups
,
saved_groups
)]
self
.
__setstate__
({
'state'
:
state
,
'param_groups'
:
param_groups
})
def
convert_state_dict_to_cpu
(
state
:
Dict
[
str
,
torch
.
Tensor
]):
return
{
k
:
v
.
cpu
()
if
isinstance
(
v
,
torch
.
Tensor
)
else
v
for
k
,
v
in
state
.
items
()}
tests/test_ddp/test_ddp_ignore_params.py
View file @
5be118f4
...
@@ -6,11 +6,11 @@ from colossalai.testing import rerun_if_address_is_in_use
...
@@ -6,11 +6,11 @@ from colossalai.testing import rerun_if_address_is_in_use
from
colossalai.utils.cuda
import
get_current_device
from
colossalai.utils.cuda
import
get_current_device
from
colossalai.utils
import
free_port
from
colossalai.utils
import
free_port
from
colossalai.utils.model.colo_init_context
import
ColoInitContext
from
colossalai.utils.model.colo_init_context
import
ColoInitContext
from
colossalai.gemini
import
ChunkManager
from
colossalai.gemini
.chunk
import
ChunkManager
,
search_chunk_configuration
from
functools
import
partial
from
functools
import
partial
from
colossalai.nn.parallel
import
ColoDDP
,
ZeroDDP
from
colossalai.nn.parallel
import
ColoDDP
,
ZeroDDP
from
colossalai.gemini.gemini_mgr
import
GeminiManager
from
colossalai.gemini.gemini_mgr
import
GeminiManager
from
typing
import
Callable
from
typing
import
Callable
,
Type
import
torch.distributed
as
dist
import
torch.distributed
as
dist
import
os
import
os
import
random
import
random
...
@@ -32,10 +32,9 @@ def init_ddp(module: torch.nn.Module) -> ColoDDP:
...
@@ -32,10 +32,9 @@ def init_ddp(module: torch.nn.Module) -> ColoDDP:
return
ColoDDP
(
module
,
process_group
=
pg
)
return
ColoDDP
(
module
,
process_group
=
pg
)
def
init_ddpv2
(
module
:
torch
.
nn
.
Module
,
use_chunk
:
bool
=
False
)
->
ZeroDDP
:
def
init_ddpv2
(
module
:
torch
.
nn
.
Module
)
->
ZeroDDP
:
pg
=
ProcessGroup
()
chunk_config
=
search_chunk_configuration
(
module
,
4
,
1024
)
chunk_size
=
ChunkManager
.
search_chunk_size
(
module
,
64
,
2
)
if
use_chunk
else
None
chunk_manager
=
ChunkManager
(
chunk_config
)
chunk_manager
=
ChunkManager
(
chunk_size
,
pg
)
gemini_manager
=
GeminiManager
(
'cuda'
,
chunk_manager
)
gemini_manager
=
GeminiManager
(
'cuda'
,
chunk_manager
)
return
ZeroDDP
(
module
,
gemini_manager
)
return
ZeroDDP
(
module
,
gemini_manager
)
...
@@ -51,7 +50,7 @@ class Net(torch.nn.Module):
...
@@ -51,7 +50,7 @@ class Net(torch.nn.Module):
return
self
.
fc2
(
self
.
fc1
(
x
))
return
self
.
fc2
(
self
.
fc1
(
x
))
def
run_fwd_bwd
(
ddp_cls
:
ColoDDP
,
init_ddp_func
:
Callable
[[
torch
.
nn
.
Module
],
ColoDDP
]):
def
run_fwd_bwd
(
ddp_cls
:
Type
[
ColoDDP
]
,
init_ddp_func
:
Callable
[[
torch
.
nn
.
Module
],
ColoDDP
]):
with
ColoInitContext
(
device
=
get_current_device
()):
with
ColoInitContext
(
device
=
get_current_device
()):
model
=
Net
().
cuda
()
model
=
Net
().
cuda
()
w1
=
model
.
fc1
.
weight
w1
=
model
.
fc1
.
weight
...
@@ -62,8 +61,14 @@ def run_fwd_bwd(ddp_cls: ColoDDP, init_ddp_func: Callable[[torch.nn.Module], Col
...
@@ -62,8 +61,14 @@ def run_fwd_bwd(ddp_cls: ColoDDP, init_ddp_func: Callable[[torch.nn.Module], Col
logits
=
model
(
x
)
logits
=
model
(
x
)
loss
=
torch
.
sum
(
logits
)
loss
=
torch
.
sum
(
logits
)
model
.
backward
(
loss
)
model
.
backward
(
loss
)
if
ddp_cls
is
ZeroDDP
:
w1s_grad
=
w1
else
:
w1s_grad
=
w1
.
grad
w1_grads
=
[
torch
.
empty_like
(
w1
)
for
_
in
range
(
dist
.
get_world_size
())]
w1_grads
=
[
torch
.
empty_like
(
w1
)
for
_
in
range
(
dist
.
get_world_size
())]
dist
.
all_gather
(
w1_grads
,
w1
.
grad
)
dist
.
all_gather
(
w1_grads
,
w1
s_
grad
)
assert
torch
.
equal
(
w1_grads
[
0
],
w1_grads
[
1
])
assert
torch
.
equal
(
w1_grads
[
0
],
w1_grads
[
1
])
w2_grads
=
[
torch
.
empty_like
(
w2
)
for
_
in
range
(
dist
.
get_world_size
())]
w2_grads
=
[
torch
.
empty_like
(
w2
)
for
_
in
range
(
dist
.
get_world_size
())]
dist
.
all_gather
(
w2_grads
,
w2
.
grad
)
dist
.
all_gather
(
w2_grads
,
w2
.
grad
)
...
@@ -74,8 +79,7 @@ def run_dist(rank, world_size, port):
...
@@ -74,8 +79,7 @@ def run_dist(rank, world_size, port):
colossalai
.
launch
(
config
=
{},
rank
=
rank
,
world_size
=
world_size
,
host
=
'localhost'
,
port
=
port
,
backend
=
'nccl'
)
colossalai
.
launch
(
config
=
{},
rank
=
rank
,
world_size
=
world_size
,
host
=
'localhost'
,
port
=
port
,
backend
=
'nccl'
)
set_seed
(
dist
.
get_rank
())
set_seed
(
dist
.
get_rank
())
run_fwd_bwd
(
ColoDDP
,
init_ddp
)
run_fwd_bwd
(
ColoDDP
,
init_ddp
)
run_fwd_bwd
(
ZeroDDP
,
partial
(
init_ddpv2
,
use_chunk
=
False
))
run_fwd_bwd
(
ZeroDDP
,
init_ddpv2
)
run_fwd_bwd
(
ZeroDDP
,
partial
(
init_ddpv2
,
use_chunk
=
True
))
@
pytest
.
mark
.
dist
@
pytest
.
mark
.
dist
...
...
tests/test_ddp/test_ddp_state_dict.py
View file @
5be118f4
...
@@ -8,14 +8,11 @@ from colossalai.testing import rerun_if_address_is_in_use
...
@@ -8,14 +8,11 @@ from colossalai.testing import rerun_if_address_is_in_use
from
colossalai.utils.cuda
import
get_current_device
from
colossalai.utils.cuda
import
get_current_device
from
colossalai.utils
import
free_port
from
colossalai.utils
import
free_port
from
colossalai.utils.model.colo_init_context
import
ColoInitContext
from
colossalai.utils.model.colo_init_context
import
ColoInitContext
from
colossalai.gemini
import
ChunkManager
from
functools
import
partial
from
functools
import
partial
from
tests.components_to_test.registry
import
non_distributed_component_funcs
from
tests.components_to_test.registry
import
non_distributed_component_funcs
from
colossalai.nn.parallel
import
ZeroDDP
,
ColoDDP
from
colossalai.nn.parallel
import
ColoDDP
from
colossalai.gemini.gemini_mgr
import
GeminiManager
from
collections
import
OrderedDict
from
collections
import
OrderedDict
from
colossalai.tensor
import
ProcessGroup
,
ColoParameter
from
colossalai.tensor
import
ProcessGroup
,
ColoParameter
from
colossalai.testing
import
parameterize
def
check_state_dict_equal
(
state_dict
:
OrderedDict
,
other_state_dict
:
OrderedDict
):
def
check_state_dict_equal
(
state_dict
:
OrderedDict
,
other_state_dict
:
OrderedDict
):
...
@@ -30,42 +27,11 @@ def check_state_dict_equal(state_dict: OrderedDict, other_state_dict: OrderedDic
...
@@ -30,42 +27,11 @@ def check_state_dict_equal(state_dict: OrderedDict, other_state_dict: OrderedDic
assert
torch
.
equal
(
t1
,
temp_t2
),
"
\t
{}
\n\t
{}"
.
format
(
t1
,
temp_t2
)
assert
torch
.
equal
(
t1
,
temp_t2
),
"
\t
{}
\n\t
{}"
.
format
(
t1
,
temp_t2
)
def
check_model_equal
(
model_a
,
model_b
,
allow_empty
:
bool
=
False
,
same_dtype
:
bool
=
True
):
for
(
na
,
pa
),
(
nb
,
pb
)
in
zip
(
model_a
.
named_parameters
(),
model_b
.
named_parameters
()):
assert
na
==
nb
if
not
allow_empty
:
assert
pa
.
storage
().
size
()
>
0
assert
pb
.
storage
().
size
()
>
0
else
:
if
pa
.
storage
().
size
()
==
0
or
pb
.
storage
().
size
()
==
0
:
continue
if
same_dtype
:
assert
pa
.
dtype
==
pb
.
dtype
temp_pb
=
pb
else
:
temp_pb
=
pb
.
to
(
pa
.
dtype
)
assert
torch
.
equal
(
pa
,
temp_pb
),
"Parameter '{}' is not equal.
\n
{} {}"
.
format
(
na
,
pa
,
pb
)
def
init_ddp
(
module
:
torch
.
nn
.
Module
)
->
ColoDDP
:
def
init_ddp
(
module
:
torch
.
nn
.
Module
)
->
ColoDDP
:
pg
=
ProcessGroup
()
pg
=
ProcessGroup
()
return
ColoDDP
(
module
,
process_group
=
pg
)
return
ColoDDP
(
module
,
process_group
=
pg
)
def
init_ddpv2
(
module
:
torch
.
nn
.
Module
,
use_chunk
:
bool
=
False
,
use_zero
:
bool
=
False
,
placement_policy
:
str
=
'cuda'
)
->
ZeroDDP
:
pg
=
ProcessGroup
()
chunk_size
=
ChunkManager
.
search_chunk_size
(
module
,
64
,
4
)
if
use_chunk
else
None
chunk_manager
=
ChunkManager
(
chunk_size
,
pg
,
enable_distributed_storage
=
use_zero
)
gemini_manager
=
GeminiManager
(
placement_policy
,
chunk_manager
)
return
ZeroDDP
(
module
,
gemini_manager
)
def
run_ddp_state_dict
():
def
run_ddp_state_dict
():
get_components_func
=
non_distributed_component_funcs
.
get_callable
(
'gpt2'
)
get_components_func
=
non_distributed_component_funcs
.
get_callable
(
'gpt2'
)
model_builder
,
train_dataloader
,
test_dataloader
,
optimizer_class
,
criterion
=
get_components_func
()
model_builder
,
train_dataloader
,
test_dataloader
,
optimizer_class
,
criterion
=
get_components_func
()
...
@@ -88,44 +54,9 @@ def run_ddp_state_dict():
...
@@ -88,44 +54,9 @@ def run_ddp_state_dict():
check_state_dict_equal
(
torch_state_dict
,
state_dict
)
check_state_dict_equal
(
torch_state_dict
,
state_dict
)
@
parameterize
(
'use_chunk'
,
[
False
,
True
])
@
parameterize
(
'placement_policy'
,
[
'cuda'
,
'cpu'
])
@
parameterize
(
'use_zero'
,
[
False
,
True
])
@
parameterize
(
'only_rank_0'
,
[
False
,
True
])
def
run_zero_state_dict
(
use_chunk
,
placement_policy
,
use_zero
,
only_rank_0
):
get_components_func
=
non_distributed_component_funcs
.
get_callable
(
'gpt2'
)
model_builder
,
train_dataloader
,
test_dataloader
,
optimizer_class
,
criterion
=
get_components_func
()
torch_model
=
model_builder
().
cuda
()
org_torch_model
=
copy
.
deepcopy
(
torch_model
)
torch_state_dict
=
torch_model
.
state_dict
()
with
ColoInitContext
(
device
=
get_current_device
()):
model
=
model_builder
()
model
=
init_ddpv2
(
model
,
use_chunk
,
use_zero
,
placement_policy
)
for
param
in
model
.
parameters
():
if
isinstance
(
param
,
ColoParameter
):
assert
param
.
get_process_group
()
is
not
None
model
.
load_state_dict
(
torch_state_dict
,
strict
=
False
)
check_model_equal
(
model
,
torch_model
,
allow_empty
=
True
,
same_dtype
=
False
)
for
param
in
model
.
parameters
():
if
isinstance
(
param
,
ColoParameter
):
assert
param
.
get_process_group
()
is
not
None
pg
=
ProcessGroup
()
state_dict
=
model
.
state_dict
(
only_rank_0
=
only_rank_0
)
if
not
only_rank_0
or
pg
.
dp_local_rank
()
==
0
:
torch_model
.
load_state_dict
(
state_dict
,
strict
=
False
)
check_model_equal
(
torch_model
,
org_torch_model
,
allow_empty
=
False
,
same_dtype
=
True
)
def
run_dist
(
rank
,
world_size
,
port
):
def
run_dist
(
rank
,
world_size
,
port
):
colossalai
.
launch
(
config
=
{},
rank
=
rank
,
world_size
=
world_size
,
host
=
'localhost'
,
port
=
port
,
backend
=
'nccl'
)
colossalai
.
launch
(
config
=
{},
rank
=
rank
,
world_size
=
world_size
,
host
=
'localhost'
,
port
=
port
,
backend
=
'nccl'
)
run_ddp_state_dict
()
run_ddp_state_dict
()
run_zero_state_dict
()
@
pytest
.
mark
.
dist
@
pytest
.
mark
.
dist
...
...
tests/test_gemini/test_stateful_tensor_container.py
deleted
100644 → 0
View file @
f9217336
import
pytest
import
torch
from
colossalai.gemini.stateful_tensor
import
TensorState
,
StatefulTensor
from
colossalai.gemini.stateful_tensor_container
import
QueueSTContainer
,
HeapSTContainer
@
pytest
.
mark
.
dist
def
test_stateful_tensor_container
():
st1
=
StatefulTensor
(
torch
.
randn
(
1
,
device
=
'cuda'
))
st2
=
StatefulTensor
(
torch
.
randn
(
2
,
device
=
'cuda'
))
st3
=
StatefulTensor
(
torch
.
randn
(
3
,
device
=
'cuda'
))
stateful_tensor_list
=
[
st1
,
st2
,
st3
]
step_list
=
[
st1
,
st2
,
st3
,
st3
,
st2
,
st1
]
compute_step_dict
=
dict
()
compute_step_dict
[
st1
]
=
[
0
,
5
]
compute_step_dict
[
st2
]
=
[
1
,
4
]
compute_step_dict
[
st3
]
=
[
2
,
3
]
def
run_queue_test
():
# test queue container
queue_container
=
QueueSTContainer
(
compute_step_dict
,
6
)
queue_container
.
create
(
stateful_tensor_list
)
res_list
=
[]
for
i
in
range
(
6
):
stateful_tensor
=
step_list
[
i
]
stateful_tensor
.
trans_state
(
TensorState
.
COMPUTE
)
st_out
=
queue_container
.
pop
()
st_out
.
move_to
(
torch
.
device
(
'cpu'
))
res_list
.
append
(
st_out
.
payload
.
size
(
0
))
stateful_tensor
.
move_to
(
torch
.
device
(
'cuda'
))
queue_container
.
push
(
stateful_tensor
,
i
)
stateful_tensor
.
trans_state
(
TensorState
.
HOLD
)
assert
res_list
==
[
2
,
3
,
1
,
2
,
3
,
2
]
run_queue_test
()
def
run_heap_test
():
# test heap container
st1
.
move_to
(
torch
.
device
(
'cuda'
))
st2
.
move_to
(
torch
.
device
(
'cuda'
))
st3
.
move_to
(
torch
.
device
(
'cuda'
))
heap_container
=
HeapSTContainer
(
compute_step_dict
,
6
)
heap_container
.
create
(
stateful_tensor_list
)
res_list
=
[]
for
i
in
range
(
6
):
stateful_tensor
=
step_list
[
i
]
stateful_tensor
.
trans_state
(
TensorState
.
COMPUTE
)
st_out
=
heap_container
.
pop
()
if
st_out
is
not
None
:
res_list
.
append
(
st_out
.
payload
.
size
(
0
))
st_out
.
move_to
(
torch
.
device
(
'cpu'
))
stateful_tensor
.
move_to
(
torch
.
device
(
'cuda'
))
heap_container
.
push
(
stateful_tensor
,
i
)
stateful_tensor
.
trans_state
(
TensorState
.
HOLD
)
assert
res_list
==
[
3
,
1
,
2
,
3
,
2
]
run_heap_test
()
if
__name__
==
'__main__'
:
test_stateful_tensor_container
()
tests/test_gemini/update/test_chunk_mgrv2.py
View file @
5be118f4
...
@@ -3,7 +3,7 @@ import colossalai
...
@@ -3,7 +3,7 @@ import colossalai
import
pytest
import
pytest
import
torch.multiprocessing
as
mp
import
torch.multiprocessing
as
mp
from
functools
import
partial
from
functools
import
partial
from
colossalai.gemini.
update
import
ChunkManager
V2
from
colossalai.gemini.
chunk
import
ChunkManager
from
colossalai.testing
import
rerun_if_address_is_in_use
,
parameterize
from
colossalai.testing
import
rerun_if_address_is_in_use
,
parameterize
from
colossalai.utils
import
free_port
from
colossalai.utils
import
free_port
from
colossalai.tensor
import
ProcessGroup
,
ColoTensor
,
ColoTensorSpec
from
colossalai.tensor
import
ProcessGroup
,
ColoTensor
,
ColoTensorSpec
...
@@ -19,23 +19,17 @@ CPU_MEM = {True: {True: 0, False: 0}, False: {True: 512, False: 0}}
...
@@ -19,23 +19,17 @@ CPU_MEM = {True: {True: 0, False: 0}, False: {True: 512, False: 0}}
def
exam_chunk_memory
(
keep_gathered
,
pin_memory
):
def
exam_chunk_memory
(
keep_gathered
,
pin_memory
):
pg
=
ProcessGroup
()
pg
=
ProcessGroup
()
debug_print
([
0
],
"keep_gathered: {}, pin_memory: {}"
.
format
(
debug_print
([
0
],
"keep_gathered: {}, pin_memory: {}"
.
format
(
keep_gathered
,
pin_memory
))
keep_gathered
,
pin_memory
))
params
=
[
ColoTensor
(
torch
.
rand
(
8
,
8
),
spec
=
ColoTensorSpec
(
pg
))
for
_
in
range
(
3
)]
params
=
[
ColoTensor
(
torch
.
rand
(
8
,
8
),
spec
=
ColoTensorSpec
(
pg
))
for
_
in
range
(
3
)]
config
=
{
config
=
{
2
:
dict
(
chunk_size
=
128
,
keep_gathered
=
keep_gathered
)}
2
:
dict
(
chunk_size
=
128
,
chunk_manager
=
ChunkManager
(
config
)
keep_gathered
=
keep_gathered
)
}
chunk_manager
=
ChunkManagerV2
(
config
,
pin_memory
=
pin_memory
)
assert
chunk_manager
.
total_mem
[
'cpu'
]
==
0
assert
chunk_manager
.
total_mem
[
'cpu'
]
==
0
assert
chunk_manager
.
total_mem
[
'cuda'
]
==
0
assert
chunk_manager
.
total_mem
[
'cuda'
]
==
0
for
p
in
params
:
for
p
in
params
:
chunk_manager
.
append_tensor
(
p
,
'param'
,
2
)
chunk_manager
.
append_tensor
(
p
,
'param'
,
2
,
pin_memory
=
pin_memory
)
chunk_manager
.
close_all_groups
()
chunk_manager
.
close_all_groups
()
assert
chunk_manager
.
total_mem
[
'cpu'
]
==
CPU_MEM
[
keep_gathered
][
pin_memory
]
assert
chunk_manager
.
total_mem
[
'cpu'
]
==
CPU_MEM
[
keep_gathered
][
pin_memory
]
assert
chunk_manager
.
total_mem
[
'cuda'
]
==
CUDA_MEM_0
[
keep_gathered
]
assert
chunk_manager
.
total_mem
[
'cuda'
]
==
CUDA_MEM_0
[
keep_gathered
]
...
...
tests/test_gemini/update/test_chunkv2.py
View file @
5be118f4
...
@@ -9,7 +9,7 @@ from colossalai.utils import free_port, get_current_device
...
@@ -9,7 +9,7 @@ from colossalai.utils import free_port, get_current_device
from
colossalai.tensor
import
ProcessGroup
as
ColoProcessGroup
from
colossalai.tensor
import
ProcessGroup
as
ColoProcessGroup
from
colossalai.tensor
import
ColoParameter
from
colossalai.tensor
import
ColoParameter
from
colossalai.gemini
import
TensorState
from
colossalai.gemini
import
TensorState
from
colossalai.gemini.
update
import
Chunk
V2
from
colossalai.gemini.
chunk
import
Chunk
def
dist_sum
(
x
):
def
dist_sum
(
x
):
...
@@ -38,14 +38,12 @@ def check_euqal(param, param_cp):
...
@@ -38,14 +38,12 @@ def check_euqal(param, param_cp):
def
exam_chunk_basic
(
init_device
,
keep_gathered
,
pin_memory
):
def
exam_chunk_basic
(
init_device
,
keep_gathered
,
pin_memory
):
world_size
=
torch
.
distributed
.
get_world_size
()
world_size
=
torch
.
distributed
.
get_world_size
()
pg
=
ColoProcessGroup
()
pg
=
ColoProcessGroup
()
my_chunk
=
ChunkV2
(
my_chunk
=
Chunk
(
chunk_size
=
1024
,
chunk_size
=
1024
,
process_group
=
pg
,
process_group
=
pg
,
dtype
=
torch
.
float32
,
dtype
=
torch
.
float32
,
init_device
=
init_device
,
init_device
=
init_device
,
keep_gathered
=
keep_gathered
,
keep_gathered
=
keep_gathered
,
pin_memory
=
pin_memory
)
pin_memory
=
pin_memory
)
param_list
=
[]
param_list
=
[]
param_cp_list
=
[]
param_cp_list
=
[]
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment