Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
372f7914
Unverified
Commit
372f7914
authored
Jun 29, 2022
by
Jiarui Fang
Committed by
GitHub
Jun 29, 2022
Browse files
[refactor] move chunk and chunkmgr to directory gemini (#1182)
parent
6b2f2ab9
Changes
15
Hide whitespace changes
Inline
Side-by-side
Showing
15 changed files
with
335 additions
and
331 deletions
+335
-331
colossalai/gemini/__init__.py
colossalai/gemini/__init__.py
+6
-1
colossalai/gemini/chunk.py
colossalai/gemini/chunk.py
+315
-0
colossalai/gemini/chunk_mgr.py
colossalai/gemini/chunk_mgr.py
+3
-310
colossalai/gemini/gemini_mgr.py
colossalai/gemini/gemini_mgr.py
+2
-2
colossalai/gemini/memory_tracer/memstats_collector.py
colossalai/gemini/memory_tracer/memstats_collector.py
+1
-1
colossalai/gemini/placement_policy.py
colossalai/gemini/placement_policy.py
+1
-1
colossalai/nn/parallel/data_parallel.py
colossalai/nn/parallel/data_parallel.py
+1
-1
colossalai/tensor/__init__.py
colossalai/tensor/__init__.py
+0
-1
colossalai/zero/utils/zero_hook_v2.py
colossalai/zero/utils/zero_hook_v2.py
+1
-1
tests/test_ddp/test_ddp_ignore_params.py
tests/test_ddp/test_ddp_ignore_params.py
+1
-1
tests/test_ddp/test_ddp_state_dict.py
tests/test_ddp/test_ddp_state_dict.py
+1
-1
tests/test_ddp/test_reducer.py
tests/test_ddp/test_reducer.py
+0
-7
tests/test_tensor/test_chunk.py
tests/test_tensor/test_chunk.py
+1
-1
tests/test_tensor/test_zero_optim.py
tests/test_tensor/test_zero_optim.py
+1
-1
tests/test_zero/test_zero_optim_state_dict.py
tests/test_zero/test_zero_optim_state_dict.py
+1
-2
No files found.
colossalai/gemini/__init__.py
View file @
372f7914
from
.chunk
import
TensorInfo
,
Chunk
,
TensorState
from
.chunk_mgr
import
ChunkManager
from
.stateful_tensor_mgr
import
StatefulTensorMgr
from
.stateful_tensor_mgr
import
StatefulTensorMgr
from
.tensor_placement_policy
import
TensorPlacementPolicyFactory
from
.tensor_placement_policy
import
TensorPlacementPolicyFactory
from
.gemini_mgr
import
GeminiManager
from
.gemini_mgr
import
GeminiManager
__all__
=
[
'StatefulTensorMgr'
,
'TensorPlacementPolicyFactory'
,
'GeminiManager'
]
__all__
=
[
'StatefulTensorMgr'
,
'TensorPlacementPolicyFactory'
,
'GeminiManager'
,
'ChunkManager'
,
'TensorInfo'
,
'Chunk'
,
'TensorState'
]
colossalai/gemini/chunk.py
0 → 100644
View file @
372f7914
import
torch
import
torch.distributed
as
dist
from
dataclasses
import
dataclass
from
enum
import
Enum
from
typing
import
Optional
,
Dict
,
List
from
colossalai.core
import
global_context
as
gpc
from
colossalai.context
import
ParallelMode
from
colossalai.utils
import
get_current_device
class TensorState(Enum):
    """Lifecycle state of a tensor that lives inside a chunk.

    Transitions between states are validated against ``STATE_TRANS`` by
    ``Chunk.tensor_trans_state``.
    """
    # not yet materialized / released on this process
    FREE = 0
    # currently used by forward/backward computation
    COMPUTE = 1
    # materialized and idle
    HOLD = 2
    # backward for this tensor finished, gradient not yet reduced
    HOLD_AFTER_BWD = 3
    # gradient is ready to be (all-)reduced
    READY_FOR_REDUCE = 4
# Whitelist of legal (previous_state, next_state) pairs for tensors in a chunk.
# Any pair not listed here is treated as an invalid transition and silently
# ignored by Chunk.tensor_trans_state (see the comment there about gradient
# hooks firing in either order around post-backward).
STATE_TRANS = ((TensorState.FREE, TensorState.HOLD), (TensorState.FREE, TensorState.COMPUTE),
               (TensorState.HOLD, TensorState.FREE), (TensorState.HOLD, TensorState.COMPUTE),
               (TensorState.COMPUTE, TensorState.HOLD), (TensorState.COMPUTE, TensorState.HOLD_AFTER_BWD),
               (TensorState.COMPUTE, TensorState.READY_FOR_REDUCE), (TensorState.HOLD_AFTER_BWD, TensorState.COMPUTE),
               (TensorState.HOLD_AFTER_BWD, TensorState.READY_FOR_REDUCE),
               (TensorState.READY_FOR_REDUCE, TensorState.HOLD))
@dataclass
class TensorInfo:
    """Per-tensor metadata tracked by a Chunk."""
    # current lifecycle state of the tensor within its chunk
    state: TensorState
    # start offset (in elements) of the tensor inside the chunk buffer
    offset: int
    # end offset (exclusive, in elements) of the tensor inside the chunk buffer
    end: int
class ChunkFullError(Exception):
    """Raised by Chunk.append when a tensor does not fit in the remaining space."""
    pass
def is_storage_empty(tensor: torch.Tensor) -> bool:
    """Return True if *tensor*'s underlying storage currently holds zero elements."""
    storage_size = tensor.storage().size()
    return storage_size == 0
def free_storage(tensor: torch.Tensor) -> None:
    """Release *tensor*'s underlying storage; a no-op if it is already empty."""
    storage = tensor.storage()
    if storage.size() > 0:
        storage.resize_(0)
def alloc_storage(tensor: torch.Tensor) -> None:
    """(Re)allocate *tensor*'s storage to hold ``tensor.numel()`` elements.

    A no-op when the storage is already materialized.
    """
    storage = tensor.storage()
    if storage.size() == 0:
        storage.resize_(tensor.numel())
class Chunk:
    """
    A chunk is a contiguous memory space which contains multiple tensors.

    Args:
        chunk_size (int): the number of elements in a chunk
        src_rank (int): the data-parallel local rank of the process which owns the chunk
        dtype (torch.dtype): the data type of the chunk
        init_device (torch.device): optional, the device where the tensor is initialized.
            The default value is None, which is the current GPU.
        force_data_on_cuda (bool): optional, if True, chunk.data is always on cuda.
            Defaults to False.
    """

    def __init__(self,
                 chunk_size: int,
                 src_rank: int,
                 dtype: torch.dtype,
                 init_device: Optional[torch.device] = None,
                 force_data_on_cuda: bool = False) -> None:
        self.size = chunk_size
        self.utilized_size = 0
        self.src_rank = src_rank
        self.is_src_rank = gpc.get_local_rank(ParallelMode.DATA) == src_rank
        self.global_src_rank = gpc.get_ranks_in_group(ParallelMode.DATA)[src_rank]
        self.dtype = dtype
        device = init_device or get_current_device()
        if force_data_on_cuda:
            # keep a CUDA buffer alive at all times; the CPU buffer is a shadow
            # used only when the chunk payload is offloaded to the host
            self.data = torch.empty(chunk_size, dtype=dtype, device=get_current_device())
            self._cpu_data = torch.empty(chunk_size, dtype=dtype)
            if device.type == 'cuda':
                free_storage(self._cpu_data)
            else:
                free_storage(self.data)
        else:
            self.data = torch.empty(chunk_size, dtype=dtype, device=device)
            self._cpu_data = None

        # we only keep the chunk in full in the process by which the tensor is owned
        if not self.is_src_rank:
            free_storage(self._payload)

        # each tensor is associated with a TensorInfo to track its meta info
        self.tensors_info: Dict[torch.Tensor, TensorInfo] = {}
        # chunk size in bytes
        self.mem = self.size * self.data.element_size()

    def append(self, tensor: torch.Tensor) -> None:
        """
        Add a tensor to the chunk.

        Args:
            tensor (torch.Tensor): a tensor to be added to the chunk

        Raises:
            ChunkFullError: if the tensor does not fit into the remaining space
        """
        assert tensor.dtype == self.dtype
        new_utilized_size = self.utilized_size + tensor.numel()
        # raise exception when the chunk size is exceeded
        if new_utilized_size > self.size:
            raise ChunkFullError()
        # set tensor state
        tensor_state = TensorState.FREE
        # if the process owns the rank, then copy the tensor to its chunk buffer
        # otherwise set its storage size to 0 to reduce memory consumption
        if self.is_src_rank:
            self._payload[self.utilized_size:new_utilized_size].copy_(tensor.flatten())
            tensor_state = TensorState.HOLD
            assert isinstance(self._payload, torch.Tensor), \
                "copy_tensor_to_chunk_slice must use a torch tensor"
            # the tensor now aliases its slice of the chunk buffer
            tensor.data = self._payload[self.utilized_size:new_utilized_size].view(tensor.shape)
        else:
            # same effect as the inline resize_(0) the original used
            free_storage(tensor)
        self.tensors_info[tensor] = TensorInfo(tensor_state, self.utilized_size, new_utilized_size)
        self.utilized_size = new_utilized_size

    def release(self) -> None:
        """
        Release the memory space on processes which do not own the chunk.
        """
        if not self.is_src_rank:
            free_storage(self._payload)
            self._update_tensors_state(TensorState.FREE)

    def _update_tensors_ptr(self) -> None:
        # re-point every registered tensor at its slice of the (possibly moved) payload
        assert isinstance(self._payload, torch.Tensor)
        for tensor, tensor_info in self.tensors_info.items():
            tensor.data = self._payload[tensor_info.offset:tensor_info.end].view(tensor.shape)

    def _update_tensors_state(self, next_state: TensorState, prev_state: Optional[TensorState] = None):
        # move every tensor (or only those currently in prev_state, if given) to next_state
        for tensor_info in self.tensors_info.values():
            if prev_state is None or tensor_info.state == prev_state:
                tensor_info.state = next_state

    def access(self) -> None:
        """
        Broadcast the chunk to synchronize the tensors across data parallel processes.
        """
        # recover the chunk on non-owner processes
        # and broadcast the chunk from the source to all processes
        if not self.is_src_rank:
            alloc_storage(self._payload)
        self.move_device(get_current_device(), update_ptr=False)
        dist.broadcast(self.data, self.global_src_rank, group=gpc.get_group(ParallelMode.DATA))

        # update tensor meta info
        self._update_tensors_ptr()
        if not self.is_src_rank:
            self._update_tensors_state(TensorState.HOLD, prev_state=TensorState.FREE)

    def move_device(self, device: torch.device, update_ptr: bool = True) -> None:
        """
        Move the chunk to a target device.

        Args:
            device (torch.device): the target device for data movement.
            update_ptr (bool): optional, whether to re-point the registered tensors
                at the moved payload. Defaults to True.
        """
        if self._payload.device == device:
            return
        if self._cpu_data is None:
            self.data.data = self.data.to(device)
        else:
            # with force_data_on_cuda, self.data stays pinned to CUDA and the
            # CPU shadow is swapped in/out instead of moving the tensor itself
            if device.type == 'cuda':
                # cpu -> cuda
                src = self._cpu_data
                dest = self.data
            else:
                # cuda -> cpu
                src = self.data
                dest = self._cpu_data
            alloc_storage(dest)
            dest.copy_(src)
            free_storage(src)

        if update_ptr:
            self._update_tensors_ptr()

    def reduce(self, is_all_reduce: bool = False) -> None:
        """
        Reduce or all-reduce the chunk.

        Args:
            is_all_reduce (bool): optional, whether to all-reduce the chunk. The default is false.
        """
        self.move_device(get_current_device(), update_ptr=False)
        if is_all_reduce:
            dist.all_reduce(self.data, group=gpc.get_group(ParallelMode.DATA))
        else:
            dist.reduce(self.data, self.global_src_rank, group=gpc.get_group(ParallelMode.DATA))
        self._update_tensors_ptr()
        self._update_tensors_state(TensorState.HOLD)

    def tensor_trans_state(self, tensor: torch.Tensor, tensor_state: TensorState) -> None:
        """
        Make a transition of the tensor into the next state.

        Args:
            tensor (torch.Tensor): a torch Tensor object.
            tensor_state (TensorState): the target state for transition.
        """
        # BUGFIX: the original asserted `tensor != TensorState.FREE`, comparing a
        # torch.Tensor against an enum member, so the guard never fired. The
        # intent is to forbid FREE as a *target* state for a single tensor.
        assert tensor_state != TensorState.FREE, 'Can only set a chunk of tensors to FREE'
        # As the gradient hook can be triggered either before or after post-backward
        # tensor's state can be compute -> hold_after_bwd -> ready_for_reduce
        # or compute -> ready_for_reduce -> hold_after_bwd
        # the second one is invalid, we just ignore ready_for_reduce -> hold_after_bwd
        # this function only apply valid state transformation
        # invalid calls will be ignored and nothing changes
        if (self.tensors_info[tensor].state, tensor_state) not in STATE_TRANS:
            return
        self.tensors_info[tensor].state = tensor_state

    def copy_tensor_to_chunk_slice(self, tensor: torch.Tensor, data_slice: torch.Tensor) -> None:
        """
        Copy data slice to the memory space indexed by the input tensor in the chunk.

        Args:
            tensor (torch.Tensor): the tensor used to retrieve meta information
            data_slice (torch.Tensor): the tensor to be copied to the chunk
        """
        tensor_info = self.tensors_info[tensor]
        self._payload[tensor_info.offset:tensor_info.end].copy_(data_slice.flatten())
        tensor.data = self._payload[tensor_info.offset:tensor_info.end].view(tensor.shape)

    @property
    def can_release(self) -> bool:
        """
        Check whether the chunk can be released: every tensor must be in HOLD.
        """
        return all(info.state == TensorState.HOLD for info in self.tensors_info.values())

    @property
    def can_move_device(self) -> bool:
        """
        Check whether the chunk can be moved across devices: no tensor may be
        in use by computation or pending reduction.
        """
        busy_states = (TensorState.COMPUTE, TensorState.READY_FOR_REDUCE)
        return all(info.state not in busy_states for info in self.tensors_info.values())

    @property
    def can_reduce(self) -> bool:
        """
        Check whether the chunk can be reduced: every tensor must be READY_FOR_REDUCE.
        """
        return all(info.state == TensorState.READY_FOR_REDUCE for info in self.tensors_info.values())

    @property
    def is_empty(self) -> bool:
        """
        Check whether the chunk's payload storage has been freed.
        """
        return is_storage_empty(self._payload)

    def __repr__(self) -> str:
        return f'Chunk: src rank={self.src_rank},size={self.size}, utilization={self.utilized_size/self.size*100:.2f}%, freed={self.is_empty}, tensor states={[info.state.name for info in self.tensors_info.values()]}'

    @property
    def has_inf_or_nan(self) -> bool:
        """
        Check if the used part of the chunk has inf or nan values.
        """
        return torch.isinf(self._payload[:self.utilized_size]).any().item() or \
            torch.isnan(self._payload[:self.utilized_size]).any().item()

    def copy_(self, dest_chunk: 'Chunk'):
        """
        Copy the payload of ``dest_chunk`` into this chunk.

        NOTE(review): despite the parameter name, ``dest_chunk`` acts as the
        *source* of the copy (``self._payload.copy_(dest_chunk._payload)``);
        the name is kept for backward compatibility with existing callers.
        """
        assert not self.is_empty
        assert not dest_chunk.is_empty
        assert self.size == dest_chunk.size
        assert self.utilized_size == dest_chunk.utilized_size
        self._payload.copy_(dest_chunk._payload)
        self._update_tensors_ptr()

    @property
    def device_type(self) -> str:
        """
        Get the device type of the chunk.
        """
        return self._payload.device.type

    def __hash__(self) -> int:
        # identity semantics: a chunk equals only itself (see __eq__)
        return hash(id(self))

    def __eq__(self, __o: object) -> bool:
        return self is __o

    def get_tensors(self) -> List[torch.Tensor]:
        """Return the tensors registered in this chunk."""
        return list(self.tensors_info.keys())

    @property
    def _payload(self) -> torch.Tensor:
        # the live buffer: the CPU shadow when it holds data, otherwise self.data
        if self._cpu_data is None or is_storage_empty(self._cpu_data):
            return self.data
        return self._cpu_data
colossalai/
tensor
/chunk.py
→
colossalai/
gemini
/chunk
_mgr
.py
View file @
372f7914
import
torch
import
torch
import
torch.distributed
as
dist
from
dataclasses
import
dataclass
from
enum
import
Enum
from
typing
import
Optional
,
Dict
,
Deque
,
Set
,
List
,
Tuple
,
Iterable
from
typing
import
Optional
,
Dict
,
Deque
,
Set
,
List
,
Tuple
,
Iterable
from
collections
import
deque
from
collections
import
deque
from
colossalai.core
import
global_context
as
gpc
from
colossalai.context
import
ParallelMode
from
colossalai.context
import
ParallelMode
from
colossalai.core
import
global_context
as
gpc
from
colossalai.utils
import
get_current_device
from
colossalai.utils
import
get_current_device
from
.chunk
import
Chunk
,
ChunkFullError
,
TensorState
class
TensorState
(
Enum
):
FREE
=
0
COMPUTE
=
1
HOLD
=
2
HOLD_AFTER_BWD
=
3
READY_FOR_REDUCE
=
4
STATE_TRANS
=
((
TensorState
.
FREE
,
TensorState
.
HOLD
),
(
TensorState
.
FREE
,
TensorState
.
COMPUTE
),
(
TensorState
.
HOLD
,
TensorState
.
FREE
),
(
TensorState
.
HOLD
,
TensorState
.
COMPUTE
),
(
TensorState
.
COMPUTE
,
TensorState
.
HOLD
),
(
TensorState
.
COMPUTE
,
TensorState
.
HOLD_AFTER_BWD
),
(
TensorState
.
COMPUTE
,
TensorState
.
READY_FOR_REDUCE
),
(
TensorState
.
HOLD_AFTER_BWD
,
TensorState
.
COMPUTE
),
(
TensorState
.
HOLD_AFTER_BWD
,
TensorState
.
READY_FOR_REDUCE
),
(
TensorState
.
READY_FOR_REDUCE
,
TensorState
.
HOLD
))
@
dataclass
class
TensorInfo
:
state
:
TensorState
offset
:
int
end
:
int
class
ChunkFullError
(
Exception
):
pass
def
is_storage_empty
(
tensor
:
torch
.
Tensor
)
->
bool
:
return
tensor
.
storage
().
size
()
==
0
def
free_storage
(
tensor
:
torch
.
Tensor
)
->
None
:
if
not
is_storage_empty
(
tensor
):
tensor
.
storage
().
resize_
(
0
)
def
alloc_storage
(
tensor
:
torch
.
Tensor
)
->
None
:
if
is_storage_empty
(
tensor
):
tensor
.
storage
().
resize_
(
tensor
.
numel
())
class
Chunk
:
"""
A chunk is a contiguous memory space which contains multiple tensors.
Args:
chunk_size (int): the number of elements in a chunk
src_rank (int): the process which owns the chunk
dtype (torch.dtype): the data type of the chunk
init_device (torch.device): optional, the device where the tensor is initialized. The default value is None, which is the current GPU.
force_data_on_cuda (bool): optional, if True, chunk.data is always on cuda. Defaults to False.
"""
def
__init__
(
self
,
chunk_size
:
int
,
src_rank
:
int
,
dtype
:
torch
.
dtype
,
init_device
:
Optional
[
torch
.
device
]
=
None
,
force_data_on_cuda
:
bool
=
False
)
->
None
:
self
.
size
=
chunk_size
self
.
utilized_size
=
0
self
.
src_rank
=
src_rank
self
.
is_src_rank
=
gpc
.
get_local_rank
(
ParallelMode
.
DATA
)
==
src_rank
self
.
global_src_rank
=
gpc
.
get_ranks_in_group
(
ParallelMode
.
DATA
)[
src_rank
]
self
.
dtype
=
dtype
device
=
init_device
or
get_current_device
()
if
force_data_on_cuda
:
self
.
data
=
torch
.
empty
(
chunk_size
,
dtype
=
dtype
,
device
=
get_current_device
())
self
.
_cpu_data
=
torch
.
empty
(
chunk_size
,
dtype
=
dtype
)
if
device
.
type
==
'cuda'
:
free_storage
(
self
.
_cpu_data
)
else
:
free_storage
(
self
.
data
)
else
:
self
.
data
=
torch
.
empty
(
chunk_size
,
dtype
=
dtype
,
device
=
device
)
self
.
_cpu_data
=
None
# we only keep the chunk in full in the process by which the tensor is owned
if
not
self
.
is_src_rank
:
free_storage
(
self
.
_payload
)
# each tensor is associated with a TensorInfo to track meta info
self
.
tensors_info
:
Dict
[
torch
.
Tensor
,
TensorInfo
]
=
{}
self
.
mem
=
self
.
size
*
self
.
data
.
element_size
()
def
append
(
self
,
tensor
:
torch
.
Tensor
)
->
None
:
"""
Add a tensor to the chunk.
Args:
tensor (torch.Tensor): a tensor to be added to the chunk
"""
assert
tensor
.
dtype
==
self
.
dtype
new_utilized_size
=
self
.
utilized_size
+
tensor
.
numel
()
# raise exception when the chunk size is exceeded
if
new_utilized_size
>
self
.
size
:
raise
ChunkFullError
# set tensor state
tensor_state
=
TensorState
.
FREE
# if the process owns the rank, then copy the tensor to its chunk buffer
# otherwise set its storage size to 0 to reduce memory consumption
if
self
.
is_src_rank
:
self
.
_payload
[
self
.
utilized_size
:
new_utilized_size
].
copy_
(
tensor
.
flatten
())
tensor_state
=
TensorState
.
HOLD
assert
type
(
self
.
_payload
)
==
torch
.
Tensor
,
"copy_tensor_to_chunk_slice must use a torch tensor"
tensor
.
data
=
self
.
_payload
[
self
.
utilized_size
:
new_utilized_size
].
view
(
tensor
.
shape
)
else
:
tensor
.
storage
().
resize_
(
0
)
self
.
tensors_info
[
tensor
]
=
TensorInfo
(
tensor_state
,
self
.
utilized_size
,
new_utilized_size
)
self
.
utilized_size
=
new_utilized_size
def
release
(
self
)
->
None
:
"""
Release the memory space on processes which do not own the chunk.
"""
if
not
self
.
is_src_rank
:
free_storage
(
self
.
_payload
)
self
.
_update_tensors_state
(
TensorState
.
FREE
)
def
_update_tensors_ptr
(
self
)
->
None
:
assert
type
(
self
.
_payload
)
==
torch
.
Tensor
for
tensor
,
tensor_info
in
self
.
tensors_info
.
items
():
tensor
.
data
=
self
.
_payload
[
tensor_info
.
offset
:
tensor_info
.
end
].
view
(
tensor
.
shape
)
def
_update_tensors_state
(
self
,
next_state
:
TensorState
,
prev_state
:
Optional
[
TensorState
]
=
None
):
for
tensor_info
in
self
.
tensors_info
.
values
():
if
prev_state
is
None
or
tensor_info
.
state
==
prev_state
:
tensor_info
.
state
=
next_state
def
access
(
self
)
->
None
:
"""
Broadcast the chunk to synchronize the tensors across data parallel processes.
"""
# recover the chunk on non-owner processes
# and broadcast the chunk from the source to all processes
if
not
self
.
is_src_rank
:
alloc_storage
(
self
.
_payload
)
self
.
move_device
(
get_current_device
(),
update_ptr
=
False
)
dist
.
broadcast
(
self
.
data
,
self
.
global_src_rank
,
group
=
gpc
.
get_group
(
ParallelMode
.
DATA
))
# update tensor meta info
self
.
_update_tensors_ptr
()
if
not
self
.
is_src_rank
:
self
.
_update_tensors_state
(
TensorState
.
HOLD
,
prev_state
=
TensorState
.
FREE
)
def
move_device
(
self
,
device
:
torch
.
device
,
update_ptr
:
bool
=
True
)
->
None
:
"""
Move the chunk to a target device.
Args:
device (torch.device): the target device for data movement.
"""
if
self
.
_payload
.
device
==
device
:
return
if
self
.
_cpu_data
is
None
:
self
.
data
.
data
=
self
.
data
.
to
(
device
)
else
:
if
device
.
type
==
'cuda'
:
# cpu -> cuda
src
=
self
.
_cpu_data
dest
=
self
.
data
else
:
# cuda -> cpu
src
=
self
.
data
dest
=
self
.
_cpu_data
alloc_storage
(
dest
)
dest
.
copy_
(
src
)
free_storage
(
src
)
if
update_ptr
:
self
.
_update_tensors_ptr
()
def
reduce
(
self
,
is_all_reduce
:
bool
=
False
)
->
None
:
"""
Reduce or all-reduce the chunk.
Args:
is_all_reduce (bool): optional, whether to all-reduce the chunk. The default is false.
"""
self
.
move_device
(
get_current_device
(),
update_ptr
=
False
)
if
is_all_reduce
:
dist
.
all_reduce
(
self
.
data
,
group
=
gpc
.
get_group
(
ParallelMode
.
DATA
))
else
:
dist
.
reduce
(
self
.
data
,
self
.
global_src_rank
,
group
=
gpc
.
get_group
(
ParallelMode
.
DATA
))
self
.
_update_tensors_ptr
()
self
.
_update_tensors_state
(
TensorState
.
HOLD
)
def
tensor_trans_state
(
self
,
tensor
:
torch
.
Tensor
,
tensor_state
:
TensorState
)
->
None
:
"""
Make a transition of the tensor into the next state.
Args:
tensor (torch.Tensor): a torch Tensor object.
tensor_state (TensorState): the target state for transition.
"""
assert
tensor
!=
TensorState
.
FREE
,
'Can only set a chunk of tensors to FREE'
# As the gradient hook can be triggered either before or after post-backward
# tensor's state can be compute -> hold_after_bwd -> ready_for_reduce
# or compute -> ready_for_reduce -> hold_after_bwd
# the second one is invalid, we just ignore ready_for_reduce -> hold_after_bwd
# this function only apply valid state transformation
# invalid calls will be ignored and nothing changes
if
(
self
.
tensors_info
[
tensor
].
state
,
tensor_state
)
not
in
STATE_TRANS
:
# print(
# f'WARNING: Rank{gpc.get_global_rank()} apply invalid state trans: {self.tensors_info[tensor].state} to {tensor_state}'
# )
return
self
.
tensors_info
[
tensor
].
state
=
tensor_state
def
copy_tensor_to_chunk_slice
(
self
,
tensor
:
torch
.
Tensor
,
data_slice
:
torch
.
Tensor
)
->
None
:
"""
Copy data slice to the memory space indexed by the input tensor in the chunk.
Args:
tensor (torch.Tensor): the tensor used to retrive meta information
data_slice (torch.Tensor): the tensor to be copied to the chunk
"""
tensor_info
=
self
.
tensors_info
[
tensor
]
self
.
_payload
[
tensor_info
.
offset
:
tensor_info
.
end
].
copy_
(
data_slice
.
flatten
())
tensor
.
data
=
self
.
_payload
[
tensor_info
.
offset
:
tensor_info
.
end
].
view
(
tensor
.
shape
)
@
property
def
can_release
(
self
)
->
bool
:
"""
Check whether the chunk can be released.
"""
for
tensor_info
in
self
.
tensors_info
.
values
():
if
tensor_info
.
state
!=
TensorState
.
HOLD
:
return
False
return
True
@
property
def
can_move_device
(
self
)
->
bool
:
"""
Check whether the chunk can be moved across devices.
"""
for
tensor_info
in
self
.
tensors_info
.
values
():
if
tensor_info
.
state
in
(
TensorState
.
COMPUTE
,
TensorState
.
READY_FOR_REDUCE
):
return
False
return
True
@
property
def
can_reduce
(
self
)
->
bool
:
"""
Check whether the chunk can be reduced.
"""
for
tensor_info
in
self
.
tensors_info
.
values
():
if
tensor_info
.
state
!=
TensorState
.
READY_FOR_REDUCE
:
return
False
return
True
@
property
def
is_empty
(
self
)
->
bool
:
"""
Check whether the chunk is empty.
"""
return
is_storage_empty
(
self
.
_payload
)
def
__repr__
(
self
)
->
str
:
return
f
'Chunk: src rank=
{
self
.
src_rank
}
,size=
{
self
.
size
}
, utilization=
{
self
.
utilized_size
/
self
.
size
*
100
:.
2
f
}
%, freed=
{
self
.
is_empty
}
, tensor states=
{
[
info
.
state
.
name
for
info
in
self
.
tensors_info
.
values
()]
}
'
@
property
def
has_inf_or_nan
(
self
)
->
bool
:
"""
Check if the chunk has inf or nan values.
"""
return
torch
.
isinf
(
self
.
_payload
[:
self
.
utilized_size
]).
any
().
item
()
or
\
torch
.
isnan
(
self
.
_payload
[:
self
.
utilized_size
]).
any
().
item
()
def
copy_
(
self
,
dest_chunk
:
'Chunk'
):
"""
Copy the data of this chunk to a destination chunk.
"""
assert
not
self
.
is_empty
assert
not
dest_chunk
.
is_empty
assert
self
.
size
==
dest_chunk
.
size
assert
self
.
utilized_size
==
dest_chunk
.
utilized_size
self
.
_payload
.
copy_
(
dest_chunk
.
_payload
)
self
.
_update_tensors_ptr
()
@
property
def
device_type
(
self
)
->
str
:
"""
Get the device type of the chunk.
"""
return
self
.
_payload
.
device
.
type
def
__hash__
(
self
)
->
int
:
return
hash
(
id
(
self
))
def
__eq__
(
self
,
__o
:
object
)
->
bool
:
return
self
is
__o
def
get_tensors
(
self
)
->
List
[
torch
.
Tensor
]:
return
list
(
self
.
tensors_info
.
keys
())
@
property
def
_payload
(
self
)
->
torch
.
Tensor
:
if
self
.
_cpu_data
is
None
or
is_storage_empty
(
self
.
_cpu_data
):
return
self
.
data
return
self
.
_cpu_data
class
ChunkManager
:
class
ChunkManager
:
...
...
colossalai/gemini/gemini_mgr.py
View file @
372f7914
...
@@ -3,8 +3,8 @@ import functools
...
@@ -3,8 +3,8 @@ import functools
from
.memory_tracer.memstats_collector
import
MemStatsCollectorV2
from
.memory_tracer.memstats_collector
import
MemStatsCollectorV2
from
typing
import
List
,
Optional
,
Tuple
from
typing
import
List
,
Optional
,
Tuple
from
time
import
time
from
time
import
time
from
colossalai.
tensor.chunk
import
Chunk
,
ChunkManager
from
colossalai.
gemini
import
Chunk
,
ChunkManager
from
.placement_policy
import
PlacementPolicy
,
PlacementPolicyFactory
from
.placement_policy
import
PlacementPolicyFactory
class
GeminiManager
:
class
GeminiManager
:
...
...
colossalai/gemini/memory_tracer/memstats_collector.py
View file @
372f7914
...
@@ -2,7 +2,7 @@ from colossalai.gemini.memory_tracer import SyncCudaMemoryMonitor
...
@@ -2,7 +2,7 @@ from colossalai.gemini.memory_tracer import SyncCudaMemoryMonitor
from
colossalai.utils.memory
import
colo_device_memory_used
,
colo_device_memory_capacity
from
colossalai.utils.memory
import
colo_device_memory_used
,
colo_device_memory_capacity
from
colossalai.utils
import
get_current_device
from
colossalai.utils
import
get_current_device
from
colossalai.gemini.stateful_tensor
import
StatefulTensor
from
colossalai.gemini.stateful_tensor
import
StatefulTensor
from
colossalai.
tensor
import
ChunkManager
from
colossalai.
gemini
import
ChunkManager
import
torch
import
torch
import
time
import
time
...
...
colossalai/gemini/placement_policy.py
View file @
372f7914
...
@@ -8,7 +8,7 @@ from colossalai.utils.memory import colo_device_memory_capacity
...
@@ -8,7 +8,7 @@ from colossalai.utils.memory import colo_device_memory_capacity
from
colossalai.gemini.memory_tracer.memstats_collector
import
MemStatsCollectorV2
from
colossalai.gemini.memory_tracer.memstats_collector
import
MemStatsCollectorV2
from
typing
import
Type
from
typing
import
Type
import
functools
import
functools
from
colossalai.
tensor.chunk
import
Chunk
,
ChunkManager
from
colossalai.
gemini
import
Chunk
,
ChunkManager
class
PlacementPolicy
(
ABC
):
class
PlacementPolicy
(
ABC
):
...
...
colossalai/nn/parallel/data_parallel.py
View file @
372f7914
...
@@ -5,7 +5,7 @@ from colossalai.core import global_context as gpc
...
@@ -5,7 +5,7 @@ from colossalai.core import global_context as gpc
from
colossalai.context
import
ParallelMode
from
colossalai.context
import
ParallelMode
from
functools
import
partial
from
functools
import
partial
from
colossalai.zero.utils.zero_hook_v2
import
ZeROHookV2
from
colossalai.zero.utils.zero_hook_v2
import
ZeROHookV2
from
colossalai.
tensor
.chunk
import
TensorState
,
Chunk
from
colossalai.
gemini
.chunk
import
TensorState
,
Chunk
from
colossalai.tensor.param_op_hook
import
ParamOpHookManager
from
colossalai.tensor.param_op_hook
import
ParamOpHookManager
from
colossalai.gemini.gemini_mgr
import
GeminiManager
from
colossalai.gemini.gemini_mgr
import
GeminiManager
from
typing
import
Dict
,
Iterable
,
List
,
Optional
from
typing
import
Dict
,
Iterable
,
List
,
Optional
...
...
colossalai/tensor/__init__.py
View file @
372f7914
...
@@ -5,7 +5,6 @@ from .colo_parameter import ColoParameter
...
@@ -5,7 +5,6 @@ from .colo_parameter import ColoParameter
from
.utils
import
convert_parameter
,
named_params_with_colotensor
from
.utils
import
convert_parameter
,
named_params_with_colotensor
from
.dist_spec_mgr
import
DistSpecManager
from
.dist_spec_mgr
import
DistSpecManager
from
.param_op_hook
import
ParamOpHook
,
ParamOpHookManager
from
.param_op_hook
import
ParamOpHook
,
ParamOpHookManager
from
.chunk
import
ChunkManager
,
TensorState
from
.
import
distspec
from
.
import
distspec
from
.process_group
import
ProcessGroup
from
.process_group
import
ProcessGroup
...
...
colossalai/zero/utils/zero_hook_v2.py
View file @
372f7914
import
torch
import
torch
from
colossalai.tensor.param_op_hook
import
ParamOpHook
from
colossalai.tensor.param_op_hook
import
ParamOpHook
from
colossalai.
tensor.chunk
import
ChunkManager
,
TensorState
from
colossalai.
gemini
import
TensorState
from
enum
import
Enum
from
enum
import
Enum
from
typing
import
List
from
typing
import
List
from
contextlib
import
contextmanager
from
contextlib
import
contextmanager
...
...
tests/test_ddp/test_ddp_ignore_params.py
View file @
372f7914
...
@@ -6,7 +6,7 @@ from colossalai.testing import rerun_if_address_is_in_use
...
@@ -6,7 +6,7 @@ from colossalai.testing import rerun_if_address_is_in_use
from
colossalai.utils.cuda
import
get_current_device
from
colossalai.utils.cuda
import
get_current_device
from
colossalai.utils
import
free_port
from
colossalai.utils
import
free_port
from
colossalai.utils.model.colo_init_context
import
ColoInitContext
from
colossalai.utils.model.colo_init_context
import
ColoInitContext
from
colossalai.
tensor
import
ChunkManager
from
colossalai.
gemini
import
ChunkManager
from
functools
import
partial
from
functools
import
partial
from
colossalai.nn.parallel
import
ColoDDP
,
ZeroDDP
from
colossalai.nn.parallel
import
ColoDDP
,
ZeroDDP
from
colossalai.gemini.gemini_mgr
import
GeminiManager
from
colossalai.gemini.gemini_mgr
import
GeminiManager
...
...
tests/test_ddp/test_ddp_state_dict.py
View file @
372f7914
...
@@ -6,7 +6,7 @@ from colossalai.testing import rerun_if_address_is_in_use
...
@@ -6,7 +6,7 @@ from colossalai.testing import rerun_if_address_is_in_use
from
colossalai.utils.cuda
import
get_current_device
from
colossalai.utils.cuda
import
get_current_device
from
colossalai.utils
import
free_port
from
colossalai.utils
import
free_port
from
colossalai.utils.model.colo_init_context
import
ColoInitContext
from
colossalai.utils.model.colo_init_context
import
ColoInitContext
from
colossalai.
tensor
import
ChunkManager
from
colossalai.
gemini
import
ChunkManager
from
functools
import
partial
from
functools
import
partial
from
tests.components_to_test.registry
import
non_distributed_component_funcs
from
tests.components_to_test.registry
import
non_distributed_component_funcs
from
colossalai.nn.parallel
import
ZeroDDP
,
ColoDDP
from
colossalai.nn.parallel
import
ZeroDDP
,
ColoDDP
...
...
tests/test_ddp/test_reducer.py
View file @
372f7914
...
@@ -5,14 +5,7 @@ import torch.multiprocessing as mp
...
@@ -5,14 +5,7 @@ import torch.multiprocessing as mp
from
colossalai.testing
import
rerun_if_address_is_in_use
from
colossalai.testing
import
rerun_if_address_is_in_use
from
colossalai.utils.cuda
import
get_current_device
from
colossalai.utils.cuda
import
get_current_device
from
colossalai.utils
import
free_port
from
colossalai.utils
import
free_port
from
colossalai.utils.model.colo_init_context
import
ColoInitContext
from
colossalai.tensor
import
ChunkManager
from
functools
import
partial
from
functools
import
partial
from
tests.components_to_test.registry
import
non_distributed_component_funcs
from
colossalai.nn.parallel
import
ZeroDDP
,
ColoDDP
from
colossalai.gemini.gemini_mgr
import
GeminiManager
from
typing
import
Callable
from
collections
import
OrderedDict
from
colossalai.nn.parallel.reducer
import
Reducer
from
colossalai.nn.parallel.reducer
import
Reducer
import
torch.distributed
as
dist
import
torch.distributed
as
dist
from
torch.distributed.distributed_c10d
import
_get_default_group
from
torch.distributed.distributed_c10d
import
_get_default_group
...
...
tests/test_tensor/test_chunk.py
View file @
372f7914
...
@@ -4,7 +4,7 @@ import pytest
...
@@ -4,7 +4,7 @@ import pytest
import
torch.multiprocessing
as
mp
import
torch.multiprocessing
as
mp
from
typing
import
List
from
typing
import
List
from
functools
import
partial
from
functools
import
partial
from
colossalai.
tensor
import
ChunkManager
from
colossalai.
gemini
import
ChunkManager
from
colossalai.testing
import
rerun_if_address_is_in_use
,
parameterize
from
colossalai.testing
import
rerun_if_address_is_in_use
,
parameterize
from
colossalai.utils
import
free_port
from
colossalai.utils
import
free_port
from
colossalai.core
import
global_context
as
gpc
from
colossalai.core
import
global_context
as
gpc
...
...
tests/test_tensor/test_zero_optim.py
View file @
372f7914
...
@@ -7,7 +7,7 @@ from colossalai.testing import rerun_if_address_is_in_use
...
@@ -7,7 +7,7 @@ from colossalai.testing import rerun_if_address_is_in_use
from
colossalai.utils.cuda
import
get_current_device
from
colossalai.utils.cuda
import
get_current_device
from
colossalai.utils
import
free_port
from
colossalai.utils
import
free_port
from
colossalai.utils.model.colo_init_context
import
ColoInitContext
from
colossalai.utils.model.colo_init_context
import
ColoInitContext
from
colossalai.
tensor
import
ChunkManager
from
colossalai.
gemini
import
ChunkManager
from
colossalai.core
import
global_context
as
gpc
from
colossalai.core
import
global_context
as
gpc
from
functools
import
partial
from
functools
import
partial
from
_utils
import
tensor_equal
,
set_seed
,
tensor_shard_equal
from
_utils
import
tensor_equal
,
set_seed
,
tensor_shard_equal
...
...
tests/test_zero/test_zero_optim_state_dict.py
View file @
372f7914
...
@@ -7,13 +7,12 @@ from colossalai.testing import rerun_if_address_is_in_use
...
@@ -7,13 +7,12 @@ from colossalai.testing import rerun_if_address_is_in_use
from
colossalai.utils.cuda
import
get_current_device
from
colossalai.utils.cuda
import
get_current_device
from
colossalai.utils
import
free_port
from
colossalai.utils
import
free_port
from
colossalai.utils.model.colo_init_context
import
ColoInitContext
from
colossalai.utils.model.colo_init_context
import
ColoInitContext
from
colossalai.tensor
import
ChunkManager
from
colossalai.core
import
global_context
as
gpc
from
colossalai.core
import
global_context
as
gpc
from
functools
import
partial
from
functools
import
partial
from
tests.test_tensor._utils
import
set_seed
from
tests.test_tensor._utils
import
set_seed
from
tests.components_to_test.registry
import
non_distributed_component_funcs
from
tests.components_to_test.registry
import
non_distributed_component_funcs
from
colossalai.nn.parallel.data_parallel
import
ZeroDDP
from
colossalai.nn.parallel.data_parallel
import
ZeroDDP
from
colossalai.gemini
import
GeminiManager
from
colossalai.gemini
import
ChunkManager
,
GeminiManager
from
colossalai.testing
import
parameterize
from
colossalai.testing
import
parameterize
from
colossalai.nn.optimizer
import
HybridAdam
from
colossalai.nn.optimizer
import
HybridAdam
from
colossalai.zero
import
ZeroOptimizer
from
colossalai.zero
import
ZeroOptimizer
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment