Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
372f7914
Unverified
Commit
372f7914
authored
Jun 29, 2022
by
Jiarui Fang
Committed by
GitHub
Jun 29, 2022
Browse files
[refactor] move chunk and chunkmgr to directory gemini (#1182)
parent
6b2f2ab9
Changes
15
Hide whitespace changes
Inline
Side-by-side
Showing
15 changed files
with
335 additions
and
331 deletions
+335
-331
colossalai/gemini/__init__.py
colossalai/gemini/__init__.py
+6
-1
colossalai/gemini/chunk.py
colossalai/gemini/chunk.py
+315
-0
colossalai/gemini/chunk_mgr.py
colossalai/gemini/chunk_mgr.py
+3
-310
colossalai/gemini/gemini_mgr.py
colossalai/gemini/gemini_mgr.py
+2
-2
colossalai/gemini/memory_tracer/memstats_collector.py
colossalai/gemini/memory_tracer/memstats_collector.py
+1
-1
colossalai/gemini/placement_policy.py
colossalai/gemini/placement_policy.py
+1
-1
colossalai/nn/parallel/data_parallel.py
colossalai/nn/parallel/data_parallel.py
+1
-1
colossalai/tensor/__init__.py
colossalai/tensor/__init__.py
+0
-1
colossalai/zero/utils/zero_hook_v2.py
colossalai/zero/utils/zero_hook_v2.py
+1
-1
tests/test_ddp/test_ddp_ignore_params.py
tests/test_ddp/test_ddp_ignore_params.py
+1
-1
tests/test_ddp/test_ddp_state_dict.py
tests/test_ddp/test_ddp_state_dict.py
+1
-1
tests/test_ddp/test_reducer.py
tests/test_ddp/test_reducer.py
+0
-7
tests/test_tensor/test_chunk.py
tests/test_tensor/test_chunk.py
+1
-1
tests/test_tensor/test_zero_optim.py
tests/test_tensor/test_zero_optim.py
+1
-1
tests/test_zero/test_zero_optim_state_dict.py
tests/test_zero/test_zero_optim_state_dict.py
+1
-2
No files found.
colossalai/gemini/__init__.py
View file @
372f7914
from
.chunk
import
TensorInfo
,
Chunk
,
TensorState
from
.chunk_mgr
import
ChunkManager
from
.stateful_tensor_mgr
import
StatefulTensorMgr
from
.tensor_placement_policy
import
TensorPlacementPolicyFactory
from
.gemini_mgr
import
GeminiManager
__all__
=
[
'StatefulTensorMgr'
,
'TensorPlacementPolicyFactory'
,
'GeminiManager'
]
__all__
=
[
'StatefulTensorMgr'
,
'TensorPlacementPolicyFactory'
,
'GeminiManager'
,
'ChunkManager'
,
'TensorInfo'
,
'Chunk'
,
'TensorState'
]
colossalai/gemini/chunk.py
0 → 100644
View file @
372f7914
import
torch
import
torch.distributed
as
dist
from
dataclasses
import
dataclass
from
enum
import
Enum
from
typing
import
Optional
,
Dict
,
List
from
colossalai.core
import
global_context
as
gpc
from
colossalai.context
import
ParallelMode
from
colossalai.utils
import
get_current_device
class TensorState(Enum):
    """Lifecycle states tracked for every tensor stored in a chunk.

    The permitted moves between these states are listed in ``STATE_TRANS``.
    """
    FREE = 0                # payload released / not yet materialised on this rank
    COMPUTE = 1
    HOLD = 2                # tensor data lives in the chunk payload
    HOLD_AFTER_BWD = 3
    READY_FOR_REDUCE = 4    # every tensor must reach this before a chunk can be reduced
# Whitelist of legal (current_state, next_state) pairs.
# Any pair that is not listed here is treated as an invalid transition.
STATE_TRANS = (
    # leaving FREE
    (TensorState.FREE, TensorState.HOLD),
    (TensorState.FREE, TensorState.COMPUTE),
    # leaving HOLD
    (TensorState.HOLD, TensorState.FREE),
    (TensorState.HOLD, TensorState.COMPUTE),
    # leaving COMPUTE
    (TensorState.COMPUTE, TensorState.HOLD),
    (TensorState.COMPUTE, TensorState.HOLD_AFTER_BWD),
    (TensorState.COMPUTE, TensorState.READY_FOR_REDUCE),
    # leaving HOLD_AFTER_BWD
    (TensorState.HOLD_AFTER_BWD, TensorState.COMPUTE),
    (TensorState.HOLD_AFTER_BWD, TensorState.READY_FOR_REDUCE),
    # leaving READY_FOR_REDUCE
    (TensorState.READY_FOR_REDUCE, TensorState.HOLD),
)
@dataclass
class TensorInfo:
    """Bookkeeping record for one tensor placed inside a chunk.

    ``offset`` and ``end`` delimit the half-open element range
    ``payload[offset:end]`` that the tensor occupies in the chunk buffer.
    """
    state: TensorState    # current lifecycle state of the tensor
    offset: int    # index of the tensor's first element in the chunk
    end: int    # index one past the tensor's last element in the chunk
class ChunkFullError(Exception):
    """Raised when a tensor does not fit into the remaining space of a chunk."""
def is_storage_empty(tensor: torch.Tensor) -> bool:
    """Return True when the storage backing ``tensor`` holds zero elements."""
    storage_size = tensor.storage().size()
    return storage_size == 0
def free_storage(tensor: torch.Tensor) -> None:
    """Shrink the storage backing ``tensor`` to zero elements.

    The tensor object itself is kept; nothing happens when its storage is
    already empty.
    """
    if is_storage_empty(tensor):
        return
    tensor.storage().resize_(0)
def alloc_storage(tensor: torch.Tensor) -> None:
    """Grow an empty storage back to ``tensor.numel()`` elements.

    Nothing happens when the storage already holds elements. The contents of
    the resized storage are whatever ``resize_`` leaves there; callers are
    expected to fill it afterwards.
    """
    if not is_storage_empty(tensor):
        return
    tensor.storage().resize_(tensor.numel())
class Chunk:
    """
    A chunk is a contiguous memory space which contains multiple tensors.

    Args:
        chunk_size (int): the number of elements in a chunk
        src_rank (int): the process which owns the chunk
        dtype (torch.dtype): the data type of the chunk
        init_device (torch.device): optional, the device where the tensor is initialized. The default value is None, which is the current GPU.
        force_data_on_cuda (bool): optional, if True, chunk.data is always on cuda. Defaults to False.
    """

    def __init__(self,
                 chunk_size: int,
                 src_rank: int,
                 dtype: torch.dtype,
                 init_device: Optional[torch.device] = None,
                 force_data_on_cuda: bool = False) -> None:
        self.size = chunk_size
        self.utilized_size = 0
        self.src_rank = src_rank
        self.is_src_rank = gpc.get_local_rank(ParallelMode.DATA) == src_rank
        self.global_src_rank = gpc.get_ranks_in_group(ParallelMode.DATA)[src_rank]
        self.dtype = dtype
        device = init_device or get_current_device()
        if force_data_on_cuda:
            # Two buffers are created (one on the current device, one without an
            # explicit device) and the storage of the one that is not the initial
            # device is released, so only one of them holds memory at a time.
            self.data = torch.empty(chunk_size, dtype=dtype, device=get_current_device())
            self._cpu_data = torch.empty(chunk_size, dtype=dtype)
            if device.type == 'cuda':
                free_storage(self._cpu_data)
            else:
                free_storage(self.data)
        else:
            self.data = torch.empty(chunk_size, dtype=dtype, device=device)
            self._cpu_data = None

        # we only keep the chunk in full in the process by which the tensor is owned
        if not self.is_src_rank:
            free_storage(self._payload)

        # each tensor is associated with a TensorInfo to track meta info
        self.tensors_info: Dict[torch.Tensor, TensorInfo] = {}
        # size of the chunk buffer in bytes
        self.mem = self.size * self.data.element_size()

    def append(self, tensor: torch.Tensor) -> None:
        """
        Add a tensor to the chunk.

        Args:
            tensor (torch.Tensor): a tensor to be added to the chunk

        Raises:
            ChunkFullError: if the tensor does not fit into the remaining space of the chunk
        """
        assert tensor.dtype == self.dtype
        new_utilized_size = self.utilized_size + tensor.numel()
        # raise exception when the chunk size is exceeded
        if new_utilized_size > self.size:
            raise ChunkFullError
        # set tensor state
        tensor_state = TensorState.FREE

        # if the process owns the rank, then copy the tensor to its chunk buffer
        # otherwise set its storage size to 0 to reduce memory consumption
        if self.is_src_rank:
            self._payload[self.utilized_size:new_utilized_size].copy_(tensor.flatten())
            tensor_state = TensorState.HOLD
            assert type(self._payload) == torch.Tensor, "copy_tensor_to_chunk_slice must use a torch tensor"
            # re-point the tensor at its slice of the chunk buffer
            tensor.data = self._payload[self.utilized_size:new_utilized_size].view(tensor.shape)
        else:
            tensor.storage().resize_(0)
        self.tensors_info[tensor] = TensorInfo(tensor_state, self.utilized_size, new_utilized_size)
        self.utilized_size = new_utilized_size

    def release(self) -> None:
        """
        Release the memory space on processes which do not own the chunk.
        """
        if not self.is_src_rank:
            free_storage(self._payload)
            self._update_tensors_state(TensorState.FREE)

    def _update_tensors_ptr(self) -> None:
        """Re-point every registered tensor at its slice of the current payload."""
        assert type(self._payload) == torch.Tensor
        for tensor, tensor_info in self.tensors_info.items():
            tensor.data = self._payload[tensor_info.offset:tensor_info.end].view(tensor.shape)

    def _update_tensors_state(self, next_state: TensorState, prev_state: Optional[TensorState] = None):
        """
        Set the state of the registered tensors to ``next_state``.

        Args:
            next_state (TensorState): the state to assign.
            prev_state (TensorState): optional, when given only tensors currently in this state are updated.
        """
        for tensor_info in self.tensors_info.values():
            if prev_state is None or tensor_info.state == prev_state:
                tensor_info.state = next_state

    def access(self) -> None:
        """
        Broadcast the chunk to synchronize the tensors across data parallel processes.
        """
        # recover the chunk on non-owner processes
        # and broadcast the chunk from the source to all processes
        if not self.is_src_rank:
            alloc_storage(self._payload)
        self.move_device(get_current_device(), update_ptr=False)
        dist.broadcast(self.data, self.global_src_rank, group=gpc.get_group(ParallelMode.DATA))

        # update tensor meta info
        self._update_tensors_ptr()
        if not self.is_src_rank:
            self._update_tensors_state(TensorState.HOLD, prev_state=TensorState.FREE)

    def move_device(self, device: torch.device, update_ptr: bool = True) -> None:
        """
        Move the chunk to a target device.

        Args:
            device (torch.device): the target device for data movement.
            update_ptr (bool): optional, whether to re-point the registered tensors
                at the moved payload afterwards. Defaults to True.
        """
        if self._payload.device == device:
            return
        if self._cpu_data is None:
            self.data.data = self.data.to(device)
        else:
            if device.type == 'cuda':
                # cpu -> cuda
                src = self._cpu_data
                dest = self.data
            else:
                # cuda -> cpu
                src = self.data
                dest = self._cpu_data
            # materialise the destination, copy, then release the source buffer
            alloc_storage(dest)
            dest.copy_(src)
            free_storage(src)

        if update_ptr:
            self._update_tensors_ptr()

    def reduce(self, is_all_reduce: bool = False) -> None:
        """
        Reduce or all-reduce the chunk.

        Args:
            is_all_reduce (bool): optional, whether to all-reduce the chunk. The default is false.
        """
        self.move_device(get_current_device(), update_ptr=False)
        if is_all_reduce:
            dist.all_reduce(self.data, group=gpc.get_group(ParallelMode.DATA))
        else:
            dist.reduce(self.data, self.global_src_rank, group=gpc.get_group(ParallelMode.DATA))
        self._update_tensors_ptr()
        self._update_tensors_state(TensorState.HOLD)

    def tensor_trans_state(self, tensor: torch.Tensor, tensor_state: TensorState) -> None:
        """
        Make a transition of the tensor into the next state.

        Args:
            tensor (torch.Tensor): a torch Tensor object.
            tensor_state (TensorState): the target state for transition.
        """
        # The guard has to inspect the requested state: comparing the tensor
        # object itself against the enum member can never fail.
        assert tensor_state != TensorState.FREE, 'Can only set a chunk of tensors to FREE'
        # As the gradient hook can be triggered either before or after post-backward
        # tensor's state can be compute -> hold_after_bwd -> ready_for_reduce
        # or compute -> ready_for_reduce -> hold_after_bwd
        # the second one is invalid, we just ignore ready_for_reduce -> hold_after_bwd
        # this function only apply valid state transformation
        # invalid calls will be ignored and nothing changes
        if (self.tensors_info[tensor].state, tensor_state) not in STATE_TRANS:
            return
        self.tensors_info[tensor].state = tensor_state

    def copy_tensor_to_chunk_slice(self, tensor: torch.Tensor, data_slice: torch.Tensor) -> None:
        """
        Copy data slice to the memory space indexed by the input tensor in the chunk.

        Args:
            tensor (torch.Tensor): the tensor used to retrieve meta information
            data_slice (torch.Tensor): the tensor to be copied to the chunk
        """
        tensor_info = self.tensors_info[tensor]
        self._payload[tensor_info.offset:tensor_info.end].copy_(data_slice.flatten())
        tensor.data = self._payload[tensor_info.offset:tensor_info.end].view(tensor.shape)

    @property
    def can_release(self) -> bool:
        """
        Check whether the chunk can be released.
        """
        # releasable only when every tensor is in HOLD (vacuously true when empty)
        return all(info.state == TensorState.HOLD for info in self.tensors_info.values())

    @property
    def can_move_device(self) -> bool:
        """
        Check whether the chunk can be moved across devices.
        """
        blocking_states = (TensorState.COMPUTE, TensorState.READY_FOR_REDUCE)
        return all(info.state not in blocking_states for info in self.tensors_info.values())

    @property
    def can_reduce(self) -> bool:
        """
        Check whether the chunk can be reduced.
        """
        return all(info.state == TensorState.READY_FOR_REDUCE for info in self.tensors_info.values())

    @property
    def is_empty(self) -> bool:
        """
        Check whether the chunk is empty.
        """
        return is_storage_empty(self._payload)

    def __repr__(self) -> str:
        return f'Chunk: src rank={self.src_rank},size={self.size}, utilization={self.utilized_size/self.size*100:.2f}%, freed={self.is_empty}, tensor states={[info.state.name for info in self.tensors_info.values()]}'

    @property
    def has_inf_or_nan(self) -> bool:
        """
        Check if the chunk has inf or nan values.
        """
        # only the utilized prefix of the buffer holds tensor data
        used = self._payload[:self.utilized_size]
        return torch.isinf(used).any().item() or torch.isnan(used).any().item()

    def copy_(self, dest_chunk: 'Chunk'):
        """
        Overwrite the data of this chunk with the data of ``dest_chunk``.

        Despite the parameter name, ``dest_chunk`` is the source of the copy:
        its payload is copied into this chunk's payload, and the tensors
        registered in this chunk are re-pointed afterwards.

        Args:
            dest_chunk (Chunk): the chunk whose payload is copied into this one. It must be
                non-empty and have the same size and utilized size as this chunk.
        """
        assert not self.is_empty
        assert not dest_chunk.is_empty
        assert self.size == dest_chunk.size
        assert self.utilized_size == dest_chunk.utilized_size
        self._payload.copy_(dest_chunk._payload)
        self._update_tensors_ptr()

    @property
    def device_type(self) -> str:
        """
        Get the device type of the chunk.
        """
        return self._payload.device.type

    def __hash__(self) -> int:
        # chunks are hashed and compared by identity
        return hash(id(self))

    def __eq__(self, __o: object) -> bool:
        return self is __o

    def get_tensors(self) -> List[torch.Tensor]:
        """Return the tensors registered in this chunk, in insertion order."""
        return list(self.tensors_info.keys())

    @property
    def _payload(self) -> torch.Tensor:
        """The buffer currently in use: ``_cpu_data`` when it exists and has storage, else ``data``."""
        if self._cpu_data is None or is_storage_empty(self._cpu_data):
            return self.data
        return self._cpu_data
colossalai/
tensor
/chunk.py
→
colossalai/
gemini
/chunk
_mgr
.py
View file @
372f7914
import
torch
import
torch.distributed
as
dist
from
dataclasses
import
dataclass
from
enum
import
Enum
from
typing
import
Optional
,
Dict
,
Deque
,
Set
,
List
,
Tuple
,
Iterable
from
collections
import
deque
from
colossalai.core
import
global_context
as
gpc
from
colossalai.context
import
ParallelMode
from
colossalai.core
import
global_context
as
gpc
from
colossalai.utils
import
get_current_device
class TensorState(Enum):
    """Lifecycle states tracked for every tensor stored in a chunk.

    The permitted moves between these states are listed in ``STATE_TRANS``.
    """
    FREE = 0                # payload released / not yet materialised on this rank
    COMPUTE = 1
    HOLD = 2                # tensor data lives in the chunk payload
    HOLD_AFTER_BWD = 3
    READY_FOR_REDUCE = 4    # every tensor must reach this before a chunk can be reduced
# Whitelist of legal (current_state, next_state) pairs.
# Any pair that is not listed here is treated as an invalid transition.
STATE_TRANS = (
    # leaving FREE
    (TensorState.FREE, TensorState.HOLD),
    (TensorState.FREE, TensorState.COMPUTE),
    # leaving HOLD
    (TensorState.HOLD, TensorState.FREE),
    (TensorState.HOLD, TensorState.COMPUTE),
    # leaving COMPUTE
    (TensorState.COMPUTE, TensorState.HOLD),
    (TensorState.COMPUTE, TensorState.HOLD_AFTER_BWD),
    (TensorState.COMPUTE, TensorState.READY_FOR_REDUCE),
    # leaving HOLD_AFTER_BWD
    (TensorState.HOLD_AFTER_BWD, TensorState.COMPUTE),
    (TensorState.HOLD_AFTER_BWD, TensorState.READY_FOR_REDUCE),
    # leaving READY_FOR_REDUCE
    (TensorState.READY_FOR_REDUCE, TensorState.HOLD),
)
@dataclass
class TensorInfo:
    """Bookkeeping record for one tensor placed inside a chunk.

    ``offset`` and ``end`` delimit the half-open element range
    ``payload[offset:end]`` that the tensor occupies in the chunk buffer.
    """
    state: TensorState    # current lifecycle state of the tensor
    offset: int    # index of the tensor's first element in the chunk
    end: int    # index one past the tensor's last element in the chunk
class ChunkFullError(Exception):
    """Raised when a tensor does not fit into the remaining space of a chunk."""
def is_storage_empty(tensor: torch.Tensor) -> bool:
    """Return True when the storage backing ``tensor`` holds zero elements."""
    storage_size = tensor.storage().size()
    return storage_size == 0
def free_storage(tensor: torch.Tensor) -> None:
    """Shrink the storage backing ``tensor`` to zero elements.

    The tensor object itself is kept; nothing happens when its storage is
    already empty.
    """
    if is_storage_empty(tensor):
        return
    tensor.storage().resize_(0)
def alloc_storage(tensor: torch.Tensor) -> None:
    """Grow an empty storage back to ``tensor.numel()`` elements.

    Nothing happens when the storage already holds elements. The contents of
    the resized storage are whatever ``resize_`` leaves there; callers are
    expected to fill it afterwards.
    """
    if not is_storage_empty(tensor):
        return
    tensor.storage().resize_(tensor.numel())
class Chunk:
    """
    A chunk is a contiguous memory space which contains multiple tensors.

    Args:
        chunk_size (int): the number of elements in a chunk
        src_rank (int): the process which owns the chunk
        dtype (torch.dtype): the data type of the chunk
        init_device (torch.device): optional, the device where the tensor is initialized. The default value is None, which is the current GPU.
        force_data_on_cuda (bool): optional, if True, chunk.data is always on cuda. Defaults to False.
    """

    def __init__(self,
                 chunk_size: int,
                 src_rank: int,
                 dtype: torch.dtype,
                 init_device: Optional[torch.device] = None,
                 force_data_on_cuda: bool = False) -> None:
        self.size = chunk_size
        self.utilized_size = 0
        self.src_rank = src_rank
        self.is_src_rank = gpc.get_local_rank(ParallelMode.DATA) == src_rank
        self.global_src_rank = gpc.get_ranks_in_group(ParallelMode.DATA)[src_rank]
        self.dtype = dtype
        device = init_device or get_current_device()
        if force_data_on_cuda:
            # Two buffers are created (one on the current device, one without an
            # explicit device) and the storage of the one that is not the initial
            # device is released, so only one of them holds memory at a time.
            self.data = torch.empty(chunk_size, dtype=dtype, device=get_current_device())
            self._cpu_data = torch.empty(chunk_size, dtype=dtype)
            if device.type == 'cuda':
                free_storage(self._cpu_data)
            else:
                free_storage(self.data)
        else:
            self.data = torch.empty(chunk_size, dtype=dtype, device=device)
            self._cpu_data = None

        # we only keep the chunk in full in the process by which the tensor is owned
        if not self.is_src_rank:
            free_storage(self._payload)

        # each tensor is associated with a TensorInfo to track meta info
        self.tensors_info: Dict[torch.Tensor, TensorInfo] = {}
        # size of the chunk buffer in bytes
        self.mem = self.size * self.data.element_size()

    def append(self, tensor: torch.Tensor) -> None:
        """
        Add a tensor to the chunk.

        Args:
            tensor (torch.Tensor): a tensor to be added to the chunk

        Raises:
            ChunkFullError: if the tensor does not fit into the remaining space of the chunk
        """
        assert tensor.dtype == self.dtype
        new_utilized_size = self.utilized_size + tensor.numel()
        # raise exception when the chunk size is exceeded
        if new_utilized_size > self.size:
            raise ChunkFullError
        # set tensor state
        tensor_state = TensorState.FREE

        # if the process owns the rank, then copy the tensor to its chunk buffer
        # otherwise set its storage size to 0 to reduce memory consumption
        if self.is_src_rank:
            self._payload[self.utilized_size:new_utilized_size].copy_(tensor.flatten())
            tensor_state = TensorState.HOLD
            assert type(self._payload) == torch.Tensor, "copy_tensor_to_chunk_slice must use a torch tensor"
            # re-point the tensor at its slice of the chunk buffer
            tensor.data = self._payload[self.utilized_size:new_utilized_size].view(tensor.shape)
        else:
            tensor.storage().resize_(0)
        self.tensors_info[tensor] = TensorInfo(tensor_state, self.utilized_size, new_utilized_size)
        self.utilized_size = new_utilized_size

    def release(self) -> None:
        """
        Release the memory space on processes which do not own the chunk.
        """
        if not self.is_src_rank:
            free_storage(self._payload)
            self._update_tensors_state(TensorState.FREE)

    def _update_tensors_ptr(self) -> None:
        """Re-point every registered tensor at its slice of the current payload."""
        assert type(self._payload) == torch.Tensor
        for tensor, tensor_info in self.tensors_info.items():
            tensor.data = self._payload[tensor_info.offset:tensor_info.end].view(tensor.shape)

    def _update_tensors_state(self, next_state: TensorState, prev_state: Optional[TensorState] = None):
        """
        Set the state of the registered tensors to ``next_state``.

        Args:
            next_state (TensorState): the state to assign.
            prev_state (TensorState): optional, when given only tensors currently in this state are updated.
        """
        for tensor_info in self.tensors_info.values():
            if prev_state is None or tensor_info.state == prev_state:
                tensor_info.state = next_state

    def access(self) -> None:
        """
        Broadcast the chunk to synchronize the tensors across data parallel processes.
        """
        # recover the chunk on non-owner processes
        # and broadcast the chunk from the source to all processes
        if not self.is_src_rank:
            alloc_storage(self._payload)
        self.move_device(get_current_device(), update_ptr=False)
        dist.broadcast(self.data, self.global_src_rank, group=gpc.get_group(ParallelMode.DATA))

        # update tensor meta info
        self._update_tensors_ptr()
        if not self.is_src_rank:
            self._update_tensors_state(TensorState.HOLD, prev_state=TensorState.FREE)

    def move_device(self, device: torch.device, update_ptr: bool = True) -> None:
        """
        Move the chunk to a target device.

        Args:
            device (torch.device): the target device for data movement.
            update_ptr (bool): optional, whether to re-point the registered tensors
                at the moved payload afterwards. Defaults to True.
        """
        if self._payload.device == device:
            return
        if self._cpu_data is None:
            self.data.data = self.data.to(device)
        else:
            if device.type == 'cuda':
                # cpu -> cuda
                src = self._cpu_data
                dest = self.data
            else:
                # cuda -> cpu
                src = self.data
                dest = self._cpu_data
            # materialise the destination, copy, then release the source buffer
            alloc_storage(dest)
            dest.copy_(src)
            free_storage(src)

        if update_ptr:
            self._update_tensors_ptr()

    def reduce(self, is_all_reduce: bool = False) -> None:
        """
        Reduce or all-reduce the chunk.

        Args:
            is_all_reduce (bool): optional, whether to all-reduce the chunk. The default is false.
        """
        self.move_device(get_current_device(), update_ptr=False)
        if is_all_reduce:
            dist.all_reduce(self.data, group=gpc.get_group(ParallelMode.DATA))
        else:
            dist.reduce(self.data, self.global_src_rank, group=gpc.get_group(ParallelMode.DATA))
        self._update_tensors_ptr()
        self._update_tensors_state(TensorState.HOLD)

    def tensor_trans_state(self, tensor: torch.Tensor, tensor_state: TensorState) -> None:
        """
        Make a transition of the tensor into the next state.

        Args:
            tensor (torch.Tensor): a torch Tensor object.
            tensor_state (TensorState): the target state for transition.
        """
        # The guard has to inspect the requested state: comparing the tensor
        # object itself against the enum member can never fail.
        assert tensor_state != TensorState.FREE, 'Can only set a chunk of tensors to FREE'
        # As the gradient hook can be triggered either before or after post-backward
        # tensor's state can be compute -> hold_after_bwd -> ready_for_reduce
        # or compute -> ready_for_reduce -> hold_after_bwd
        # the second one is invalid, we just ignore ready_for_reduce -> hold_after_bwd
        # this function only apply valid state transformation
        # invalid calls will be ignored and nothing changes
        if (self.tensors_info[tensor].state, tensor_state) not in STATE_TRANS:
            return
        self.tensors_info[tensor].state = tensor_state

    def copy_tensor_to_chunk_slice(self, tensor: torch.Tensor, data_slice: torch.Tensor) -> None:
        """
        Copy data slice to the memory space indexed by the input tensor in the chunk.

        Args:
            tensor (torch.Tensor): the tensor used to retrieve meta information
            data_slice (torch.Tensor): the tensor to be copied to the chunk
        """
        tensor_info = self.tensors_info[tensor]
        self._payload[tensor_info.offset:tensor_info.end].copy_(data_slice.flatten())
        tensor.data = self._payload[tensor_info.offset:tensor_info.end].view(tensor.shape)

    @property
    def can_release(self) -> bool:
        """
        Check whether the chunk can be released.
        """
        # releasable only when every tensor is in HOLD (vacuously true when empty)
        return all(info.state == TensorState.HOLD for info in self.tensors_info.values())

    @property
    def can_move_device(self) -> bool:
        """
        Check whether the chunk can be moved across devices.
        """
        blocking_states = (TensorState.COMPUTE, TensorState.READY_FOR_REDUCE)
        return all(info.state not in blocking_states for info in self.tensors_info.values())

    @property
    def can_reduce(self) -> bool:
        """
        Check whether the chunk can be reduced.
        """
        return all(info.state == TensorState.READY_FOR_REDUCE for info in self.tensors_info.values())

    @property
    def is_empty(self) -> bool:
        """
        Check whether the chunk is empty.
        """
        return is_storage_empty(self._payload)

    def __repr__(self) -> str:
        return f'Chunk: src rank={self.src_rank},size={self.size}, utilization={self.utilized_size/self.size*100:.2f}%, freed={self.is_empty}, tensor states={[info.state.name for info in self.tensors_info.values()]}'

    @property
    def has_inf_or_nan(self) -> bool:
        """
        Check if the chunk has inf or nan values.
        """
        # only the utilized prefix of the buffer holds tensor data
        used = self._payload[:self.utilized_size]
        return torch.isinf(used).any().item() or torch.isnan(used).any().item()

    def copy_(self, dest_chunk: 'Chunk'):
        """
        Overwrite the data of this chunk with the data of ``dest_chunk``.

        Despite the parameter name, ``dest_chunk`` is the source of the copy:
        its payload is copied into this chunk's payload, and the tensors
        registered in this chunk are re-pointed afterwards.

        Args:
            dest_chunk (Chunk): the chunk whose payload is copied into this one. It must be
                non-empty and have the same size and utilized size as this chunk.
        """
        assert not self.is_empty
        assert not dest_chunk.is_empty
        assert self.size == dest_chunk.size
        assert self.utilized_size == dest_chunk.utilized_size
        self._payload.copy_(dest_chunk._payload)
        self._update_tensors_ptr()

    @property
    def device_type(self) -> str:
        """
        Get the device type of the chunk.
        """
        return self._payload.device.type

    def __hash__(self) -> int:
        # chunks are hashed and compared by identity
        return hash(id(self))

    def __eq__(self, __o: object) -> bool:
        return self is __o

    def get_tensors(self) -> List[torch.Tensor]:
        """Return the tensors registered in this chunk, in insertion order."""
        return list(self.tensors_info.keys())

    @property
    def _payload(self) -> torch.Tensor:
        """The buffer currently in use: ``_cpu_data`` when it exists and has storage, else ``data``."""
        if self._cpu_data is None or is_storage_empty(self._cpu_data):
            return self.data
        return self._cpu_data
from
.chunk
import
Chunk
,
ChunkFullError
,
TensorState
class
ChunkManager
:
...
...
colossalai/gemini/gemini_mgr.py
View file @
372f7914
...
...
@@ -3,8 +3,8 @@ import functools
from
.memory_tracer.memstats_collector
import
MemStatsCollectorV2
from
typing
import
List
,
Optional
,
Tuple
from
time
import
time
from
colossalai.
tensor.chunk
import
Chunk
,
ChunkManager
from
.placement_policy
import
PlacementPolicy
,
PlacementPolicyFactory
from
colossalai.
gemini
import
Chunk
,
ChunkManager
from
.placement_policy
import
PlacementPolicyFactory
class
GeminiManager
:
...
...
colossalai/gemini/memory_tracer/memstats_collector.py
View file @
372f7914
...
...
@@ -2,7 +2,7 @@ from colossalai.gemini.memory_tracer import SyncCudaMemoryMonitor
from
colossalai.utils.memory
import
colo_device_memory_used
,
colo_device_memory_capacity
from
colossalai.utils
import
get_current_device
from
colossalai.gemini.stateful_tensor
import
StatefulTensor
from
colossalai.
tensor
import
ChunkManager
from
colossalai.
gemini
import
ChunkManager
import
torch
import
time
...
...
colossalai/gemini/placement_policy.py
View file @
372f7914
...
...
@@ -8,7 +8,7 @@ from colossalai.utils.memory import colo_device_memory_capacity
from
colossalai.gemini.memory_tracer.memstats_collector
import
MemStatsCollectorV2
from
typing
import
Type
import
functools
from
colossalai.
tensor.chunk
import
Chunk
,
ChunkManager
from
colossalai.
gemini
import
Chunk
,
ChunkManager
class
PlacementPolicy
(
ABC
):
...
...
colossalai/nn/parallel/data_parallel.py
View file @
372f7914
...
...
@@ -5,7 +5,7 @@ from colossalai.core import global_context as gpc
from
colossalai.context
import
ParallelMode
from
functools
import
partial
from
colossalai.zero.utils.zero_hook_v2
import
ZeROHookV2
from
colossalai.
tensor
.chunk
import
TensorState
,
Chunk
from
colossalai.
gemini
.chunk
import
TensorState
,
Chunk
from
colossalai.tensor.param_op_hook
import
ParamOpHookManager
from
colossalai.gemini.gemini_mgr
import
GeminiManager
from
typing
import
Dict
,
Iterable
,
List
,
Optional
...
...
colossalai/tensor/__init__.py
View file @
372f7914
...
...
@@ -5,7 +5,6 @@ from .colo_parameter import ColoParameter
from
.utils
import
convert_parameter
,
named_params_with_colotensor
from
.dist_spec_mgr
import
DistSpecManager
from
.param_op_hook
import
ParamOpHook
,
ParamOpHookManager
from
.chunk
import
ChunkManager
,
TensorState
from
.
import
distspec
from
.process_group
import
ProcessGroup
...
...
colossalai/zero/utils/zero_hook_v2.py
View file @
372f7914
import
torch
from
colossalai.tensor.param_op_hook
import
ParamOpHook
from
colossalai.
tensor.chunk
import
ChunkManager
,
TensorState
from
colossalai.
gemini
import
TensorState
from
enum
import
Enum
from
typing
import
List
from
contextlib
import
contextmanager
...
...
tests/test_ddp/test_ddp_ignore_params.py
View file @
372f7914
...
...
@@ -6,7 +6,7 @@ from colossalai.testing import rerun_if_address_is_in_use
from
colossalai.utils.cuda
import
get_current_device
from
colossalai.utils
import
free_port
from
colossalai.utils.model.colo_init_context
import
ColoInitContext
from
colossalai.
tensor
import
ChunkManager
from
colossalai.
gemini
import
ChunkManager
from
functools
import
partial
from
colossalai.nn.parallel
import
ColoDDP
,
ZeroDDP
from
colossalai.gemini.gemini_mgr
import
GeminiManager
...
...
tests/test_ddp/test_ddp_state_dict.py
View file @
372f7914
...
...
@@ -6,7 +6,7 @@ from colossalai.testing import rerun_if_address_is_in_use
from
colossalai.utils.cuda
import
get_current_device
from
colossalai.utils
import
free_port
from
colossalai.utils.model.colo_init_context
import
ColoInitContext
from
colossalai.
tensor
import
ChunkManager
from
colossalai.
gemini
import
ChunkManager
from
functools
import
partial
from
tests.components_to_test.registry
import
non_distributed_component_funcs
from
colossalai.nn.parallel
import
ZeroDDP
,
ColoDDP
...
...
tests/test_ddp/test_reducer.py
View file @
372f7914
...
...
@@ -5,14 +5,7 @@ import torch.multiprocessing as mp
from
colossalai.testing
import
rerun_if_address_is_in_use
from
colossalai.utils.cuda
import
get_current_device
from
colossalai.utils
import
free_port
from
colossalai.utils.model.colo_init_context
import
ColoInitContext
from
colossalai.tensor
import
ChunkManager
from
functools
import
partial
from
tests.components_to_test.registry
import
non_distributed_component_funcs
from
colossalai.nn.parallel
import
ZeroDDP
,
ColoDDP
from
colossalai.gemini.gemini_mgr
import
GeminiManager
from
typing
import
Callable
from
collections
import
OrderedDict
from
colossalai.nn.parallel.reducer
import
Reducer
import
torch.distributed
as
dist
from
torch.distributed.distributed_c10d
import
_get_default_group
...
...
tests/test_tensor/test_chunk.py
View file @
372f7914
...
...
@@ -4,7 +4,7 @@ import pytest
import
torch.multiprocessing
as
mp
from
typing
import
List
from
functools
import
partial
from
colossalai.
tensor
import
ChunkManager
from
colossalai.
gemini
import
ChunkManager
from
colossalai.testing
import
rerun_if_address_is_in_use
,
parameterize
from
colossalai.utils
import
free_port
from
colossalai.core
import
global_context
as
gpc
...
...
tests/test_tensor/test_zero_optim.py
View file @
372f7914
...
...
@@ -7,7 +7,7 @@ from colossalai.testing import rerun_if_address_is_in_use
from
colossalai.utils.cuda
import
get_current_device
from
colossalai.utils
import
free_port
from
colossalai.utils.model.colo_init_context
import
ColoInitContext
from
colossalai.
tensor
import
ChunkManager
from
colossalai.
gemini
import
ChunkManager
from
colossalai.core
import
global_context
as
gpc
from
functools
import
partial
from
_utils
import
tensor_equal
,
set_seed
,
tensor_shard_equal
...
...
tests/test_zero/test_zero_optim_state_dict.py
View file @
372f7914
...
...
@@ -7,13 +7,12 @@ from colossalai.testing import rerun_if_address_is_in_use
from
colossalai.utils.cuda
import
get_current_device
from
colossalai.utils
import
free_port
from
colossalai.utils.model.colo_init_context
import
ColoInitContext
from
colossalai.tensor
import
ChunkManager
from
colossalai.core
import
global_context
as
gpc
from
functools
import
partial
from
tests.test_tensor._utils
import
set_seed
from
tests.components_to_test.registry
import
non_distributed_component_funcs
from
colossalai.nn.parallel.data_parallel
import
ZeroDDP
from
colossalai.gemini
import
GeminiManager
from
colossalai.gemini
import
ChunkManager
,
GeminiManager
from
colossalai.testing
import
parameterize
from
colossalai.nn.optimizer
import
HybridAdam
from
colossalai.zero
import
ZeroOptimizer
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment