Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
e99edfcb
"container/git@developer.sourcefind.cn:OpenDAS/dynamo.git" did not exist on "2153ee81759d45e27ff6846c0195b8e8029c2529"
Unverified
Commit
e99edfcb
authored
Dec 12, 2022
by
Jiarui Fang
Committed by
GitHub
Dec 12, 2022
Browse files
[NFC] polish comments for Chunk class (#2116)
parent
09d69e1c
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
94 additions
and
82 deletions
+94
-82
colossalai/gemini/chunk/chunk.py
colossalai/gemini/chunk/chunk.py
+52
-42
colossalai/nn/parallel/data_parallel.py
colossalai/nn/parallel/data_parallel.py
+2
-2
colossalai/nn/parallel/utils.py
colossalai/nn/parallel/utils.py
+21
-20
colossalai/tensor/param_op_hook.py
colossalai/tensor/param_op_hook.py
+6
-4
colossalai/zero/utils/zero_hook.py
colossalai/zero/utils/zero_hook.py
+6
-7
tests/test_gemini/update/test_chunkv2.py
tests/test_gemini/update/test_chunkv2.py
+7
-7
No files found.
colossalai/gemini/chunk/chunk.py
View file @
e99edfcb
...
@@ -71,8 +71,9 @@ class Chunk:
...
@@ -71,8 +71,9 @@ class Chunk:
chunk_size (int): the number of elements in the chunk
chunk_size (int): the number of elements in the chunk
process_group (ColoProcessGroup): the process group of this chunk
process_group (ColoProcessGroup): the process group of this chunk
dtype (torch.dtype): the data type of the chunk
dtype (torch.dtype): the data type of the chunk
init_device (torch.device): optional,
the device
where the tensor is
initializ
ed
init_device (torch.device): optional,
During the chunk construction process,
where the tensor is
stor
ed
.
The default value is None, which is the current GPU
The default value is None, which is the current GPU
cpu_shard_init (bool): a flag indicates the local chunk shard is resident on CPU.
keep_gathered (bool): optional, if True, this chunk is always gathered in CUDA memory
keep_gathered (bool): optional, if True, this chunk is always gathered in CUDA memory
pin_memory (bool): optional, if True, this chunk always has a shard copied in pinned CPU memory
pin_memory (bool): optional, if True, this chunk always has a shard copied in pinned CPU memory
"""
"""
...
@@ -81,13 +82,12 @@ class Chunk:
...
@@ -81,13 +82,12 @@ class Chunk:
self
.
chunk_size
=
chunk_size
self
.
chunk_size
=
chunk_size
self
.
utilized_size
=
0
self
.
utilized_size
=
0
# Here, we use torch process group,
# since ColoProcessGroup might get deprecated soon
self
.
torch_pg
=
process_group
.
dp_process_group
()
self
.
torch_pg
=
process_group
.
dp_process_group
()
self
.
pg_size
=
dist
.
get_world_size
(
self
.
torch_pg
)
self
.
pg_size
=
dist
.
get_world_size
(
self
.
torch_pg
)
self
.
pg_rank
=
dist
.
get_rank
(
self
.
torch_pg
)
self
.
pg_rank
=
dist
.
get_rank
(
self
.
torch_pg
)
# the chunk size should be
able to be divied by the size of GPU
# the chunk size should be
divisible by the dp degree
if
not
keep_gathered
:
if
not
keep_gathered
:
assert
chunk_size
%
self
.
pg_size
==
0
assert
chunk_size
%
self
.
pg_size
==
0
self
.
shard_size
=
chunk_size
//
self
.
pg_size
self
.
shard_size
=
chunk_size
//
self
.
pg_size
...
@@ -97,13 +97,21 @@ class Chunk:
...
@@ -97,13 +97,21 @@ class Chunk:
self
.
dtype
=
dtype
self
.
dtype
=
dtype
device
=
init_device
or
get_current_device
()
device
=
init_device
or
get_current_device
()
# chunk_temp is a global chunk, which only exists during building the chunks.
self
.
chunk_temp
=
torch
.
zeros
(
chunk_size
,
dtype
=
dtype
,
device
=
device
)
# keep all zero
self
.
chunk_temp
=
torch
.
zeros
(
chunk_size
,
dtype
=
dtype
,
device
=
device
)
# keep all zero
self
.
chunk_total
=
None
# we force chunk_total located in CUDA
self
.
cuda_shard
=
None
# using two attributes for the better interpretation
self
.
cuda_global_chunk
=
None
# we force cuda_global_chunk located in CUDA
# cuda local chunk, which is sharded on GPUs
self
.
cuda_shard
=
None
# cpu local chunk, which is sharded on CPUs
self
.
cpu_shard
=
None
self
.
cpu_shard
=
None
# is the chunks gathers, which means chunks are duplicated on each process,
# and we should use the cuda_global_chunk.
self
.
is_gathered
=
True
self
.
is_gathered
=
True
# configure the init de
i
vce of the shard
# configure the init dev
i
ce of the shard
# no-offload default: fp16, fp32 -> CUDA
# no-offload default: fp16, fp32 -> CUDA
# offload default: fp16, fp32 -> CPU
# offload default: fp16, fp32 -> CPU
self
.
shard_device
=
torch
.
device
(
"cpu"
)
if
cpu_shard_init
else
get_current_device
()
self
.
shard_device
=
torch
.
device
(
"cpu"
)
if
cpu_shard_init
else
get_current_device
()
...
@@ -111,17 +119,19 @@ class Chunk:
...
@@ -111,17 +119,19 @@ class Chunk:
self
.
chunk_mem
=
self
.
chunk_size
*
self
.
chunk_temp
.
element_size
()
self
.
chunk_mem
=
self
.
chunk_size
*
self
.
chunk_temp
.
element_size
()
self
.
shard_mem
=
self
.
chunk_mem
//
self
.
pg_size
self
.
shard_mem
=
self
.
chunk_mem
//
self
.
pg_size
# each tensor is associated with a TensorInfo to track meta info
# each tensor is associated with a TensorInfo to track its meta info
# (state, offset, end)
self
.
tensors_info
:
Dict
[
torch
.
Tensor
,
TensorInfo
]
=
{}
self
.
tensors_info
:
Dict
[
torch
.
Tensor
,
TensorInfo
]
=
{}
# the total number of
all
tensors
# the total number of tensors
in the chunk
self
.
num_tensors
=
0
self
.
num_tensors
=
0
# monitor the states of all tensors
self
.
tensors_state_monitor
:
Dict
[
TensorState
,
int
]
=
dict
()
# Record the number of tensors in different states
self
.
tensor_state_cnter
:
Dict
[
TensorState
,
int
]
=
dict
()
for
state
in
TensorState
:
for
state
in
TensorState
:
self
.
tensor
s
_state_
monito
r
[
state
]
=
0
self
.
tensor_state_
cnte
r
[
state
]
=
0
#
some
chunk
s can ke
ep gathered
all the time
#
If a
chunk
is k
ep
t
gathered
,
#
so their computation patterns are
the same as that of the parameters in DDP
#
they are treated
the same as that of the parameters in DDP
during training.
self
.
keep_gathered
=
keep_gathered
self
.
keep_gathered
=
keep_gathered
if
self
.
keep_gathered
:
if
self
.
keep_gathered
:
pin_memory
=
False
# since this chunk is gathered, it doesn't need to pin
pin_memory
=
False
# since this chunk is gathered, it doesn't need to pin
...
@@ -182,7 +192,7 @@ class Chunk:
...
@@ -182,7 +192,7 @@ class Chunk:
assert
self
.
chunk_temp
is
None
assert
self
.
chunk_temp
is
None
if
self
.
is_gathered
:
if
self
.
is_gathered
:
return
self
.
c
hunk_total
return
self
.
c
uda_global_chunk
elif
self
.
cuda_shard
is
not
None
:
elif
self
.
cuda_shard
is
not
None
:
return
self
.
cuda_shard
return
self
.
cuda_shard
else
:
else
:
...
@@ -207,19 +217,19 @@ class Chunk:
...
@@ -207,19 +217,19 @@ class Chunk:
if
self
.
keep_gathered
:
if
self
.
keep_gathered
:
return
False
return
False
else
:
else
:
return
self
.
tensor
s
_state_
monito
r
[
TensorState
.
HOLD
]
+
\
return
self
.
tensor_state_
cnte
r
[
TensorState
.
HOLD
]
+
\
self
.
tensor
s
_state_
monito
r
[
TensorState
.
HOLD_AFTER_BWD
]
==
self
.
num_tensors
self
.
tensor_state_
cnte
r
[
TensorState
.
HOLD_AFTER_BWD
]
==
self
.
num_tensors
@
property
@
property
def
can_reduce
(
self
):
def
can_reduce
(
self
):
return
self
.
tensor
s
_state_
monito
r
[
TensorState
.
READY_FOR_REDUCE
]
==
self
.
num_tensors
return
self
.
tensor_state_
cnte
r
[
TensorState
.
READY_FOR_REDUCE
]
==
self
.
num_tensors
@
property
@
property
def
has_inf_or_nan
(
self
)
->
bool
:
def
has_inf_or_nan
(
self
)
->
bool
:
"""Check if the chunk has inf or nan values on CUDA.
"""Check if the chunk has inf or nan values on CUDA.
"""
"""
if
self
.
is_gathered
:
if
self
.
is_gathered
:
valid_tensor
=
self
.
c
hunk_total
[:
self
.
utilized_size
]
valid_tensor
=
self
.
c
uda_global_chunk
[:
self
.
utilized_size
]
else
:
else
:
assert
self
.
cuda_shard
is
not
None
# only check on CUDA
assert
self
.
cuda_shard
is
not
None
# only check on CUDA
valid_tensor
=
self
.
cuda_shard
[:
self
.
valid_end
]
valid_tensor
=
self
.
cuda_shard
[:
self
.
valid_end
]
...
@@ -231,7 +241,7 @@ class Chunk:
...
@@ -231,7 +241,7 @@ class Chunk:
"""
"""
assert
self
.
l2_norm
is
None
,
"you are calculating the l2 norm twice"
assert
self
.
l2_norm
is
None
,
"you are calculating the l2 norm twice"
if
self
.
is_gathered
:
if
self
.
is_gathered
:
valid_tensor
=
self
.
c
hunk_total
[:
self
.
utilized_size
]
valid_tensor
=
self
.
c
uda_global_chunk
[:
self
.
utilized_size
]
else
:
else
:
assert
self
.
cuda_shard
is
not
None
# calculate on CUDA
assert
self
.
cuda_shard
is
not
None
# calculate on CUDA
valid_tensor
=
self
.
cuda_shard
[:
self
.
valid_end
]
valid_tensor
=
self
.
cuda_shard
[:
self
.
valid_end
]
...
@@ -261,7 +271,7 @@ class Chunk:
...
@@ -261,7 +271,7 @@ class Chunk:
self
.
num_tensors
+=
1
self
.
num_tensors
+=
1
tensor_state
=
TensorState
.
HOLD
tensor_state
=
TensorState
.
HOLD
self
.
tensors_info
[
tensor
]
=
TensorInfo
(
tensor_state
,
self
.
utilized_size
,
new_utilized_size
)
self
.
tensors_info
[
tensor
]
=
TensorInfo
(
tensor_state
,
self
.
utilized_size
,
new_utilized_size
)
self
.
tensor
s
_state_
monito
r
[
tensor_state
]
+=
1
self
.
tensor_state_
cnte
r
[
tensor_state
]
+=
1
self
.
utilized_size
=
new_utilized_size
self
.
utilized_size
=
new_utilized_size
def
close_chunk
(
self
):
def
close_chunk
(
self
):
...
@@ -277,10 +287,10 @@ class Chunk:
...
@@ -277,10 +287,10 @@ class Chunk:
self
.
valid_end
=
self
.
utilized_size
-
self
.
shard_begin
self
.
valid_end
=
self
.
utilized_size
-
self
.
shard_begin
if
self
.
chunk_temp
.
device
.
type
==
'cpu'
:
if
self
.
chunk_temp
.
device
.
type
==
'cpu'
:
self
.
c
hunk_total
=
self
.
chunk_temp
.
to
(
get_current_device
())
self
.
c
uda_global_chunk
=
self
.
chunk_temp
.
to
(
get_current_device
())
self
.
__update_tensors_ptr
()
self
.
__update_tensors_ptr
()
else
:
else
:
self
.
c
hunk_total
=
self
.
chunk_temp
self
.
c
uda_global_chunk
=
self
.
chunk_temp
self
.
chunk_temp
=
None
self
.
chunk_temp
=
None
self
.
__scatter
()
self
.
__scatter
()
...
@@ -366,19 +376,19 @@ class Chunk:
...
@@ -366,19 +376,19 @@ class Chunk:
if
self
.
pg_size
==
1
:
if
self
.
pg_size
==
1
:
# tricky code here
# tricky code here
# just move c
hunk_total
to cuda_shard
# just move c
uda_global_chunk
to cuda_shard
# the communication is not necessary
# the communication is not necessary
self
.
__scatter
()
self
.
__scatter
()
elif
self
.
keep_gathered
:
elif
self
.
keep_gathered
:
# we use all-reduce here
# we use all-reduce here
dist
.
all_reduce
(
self
.
c
hunk_total
,
group
=
self
.
torch_pg
)
dist
.
all_reduce
(
self
.
c
uda_global_chunk
,
group
=
self
.
torch_pg
)
else
:
else
:
self
.
cuda_shard
=
torch
.
empty
(
self
.
shard_size
,
dtype
=
self
.
dtype
,
device
=
get_current_device
())
self
.
cuda_shard
=
torch
.
empty
(
self
.
shard_size
,
dtype
=
self
.
dtype
,
device
=
get_current_device
())
input_list
=
list
(
torch
.
chunk
(
self
.
c
hunk_total
,
chunks
=
self
.
pg_size
,
dim
=
0
))
input_list
=
list
(
torch
.
chunk
(
self
.
c
uda_global_chunk
,
chunks
=
self
.
pg_size
,
dim
=
0
))
dist
.
reduce_scatter
(
self
.
cuda_shard
,
input_list
,
group
=
self
.
torch_pg
)
dist
.
reduce_scatter
(
self
.
cuda_shard
,
input_list
,
group
=
self
.
torch_pg
)
free_storage
(
self
.
c
hunk_total
)
free_storage
(
self
.
c
uda_global_chunk
)
self
.
is_gathered
=
False
self
.
is_gathered
=
False
self
.
__update_tensors_state
(
TensorState
.
HOLD
)
self
.
__update_tensors_state
(
TensorState
.
HOLD
)
...
@@ -413,8 +423,8 @@ class Chunk:
...
@@ -413,8 +423,8 @@ class Chunk:
assert
self
.
is_gathered
assert
self
.
is_gathered
tensor_info
=
self
.
tensors_info
[
tensor
]
tensor_info
=
self
.
tensors_info
[
tensor
]
self
.
c
hunk_total
[
tensor_info
.
offset
:
tensor_info
.
end
].
copy_
(
data_slice
.
data
.
flatten
())
self
.
c
uda_global_chunk
[
tensor_info
.
offset
:
tensor_info
.
end
].
copy_
(
data_slice
.
data
.
flatten
())
tensor
.
data
=
self
.
c
hunk_total
[
tensor_info
.
offset
:
tensor_info
.
end
].
view
(
tensor
.
shape
)
tensor
.
data
=
self
.
c
uda_global_chunk
[
tensor_info
.
offset
:
tensor_info
.
end
].
view
(
tensor
.
shape
)
def
get_valid_length
(
self
)
->
int
:
def
get_valid_length
(
self
)
->
int
:
"""Get the valid length of the chunk's payload.
"""Get the valid length of the chunk's payload.
...
@@ -443,7 +453,7 @@ class Chunk:
...
@@ -443,7 +453,7 @@ class Chunk:
friend_chunk
=
self
.
paired_chunk
friend_chunk
=
self
.
paired_chunk
if
self
.
is_gathered
is
True
:
if
self
.
is_gathered
is
True
:
assert
friend_chunk
.
is_gathered
is
True
assert
friend_chunk
.
is_gathered
is
True
self
.
c
hunk_total
.
copy_
(
friend_chunk
.
c
hunk_total
)
self
.
c
uda_global_chunk
.
copy_
(
friend_chunk
.
c
uda_global_chunk
)
self
.
optim_sync_flag
=
True
self
.
optim_sync_flag
=
True
elif
friend_chunk
.
device_type
==
'cuda'
and
self
.
device_type
==
'cuda'
:
elif
friend_chunk
.
device_type
==
'cuda'
and
self
.
device_type
==
'cuda'
:
self
.
cuda_shard
.
copy_
(
friend_chunk
.
cuda_shard
)
self
.
cuda_shard
.
copy_
(
friend_chunk
.
cuda_shard
)
...
@@ -465,8 +475,8 @@ class Chunk:
...
@@ -465,8 +475,8 @@ class Chunk:
# sanity check
# sanity check
assert
self
.
cuda_shard
is
not
None
assert
self
.
cuda_shard
is
not
None
alloc_storage
(
self
.
c
hunk_total
)
alloc_storage
(
self
.
c
uda_global_chunk
)
gather_list
=
list
(
torch
.
chunk
(
input
=
self
.
c
hunk_total
,
chunks
=
self
.
pg_size
,
dim
=
0
))
gather_list
=
list
(
torch
.
chunk
(
input
=
self
.
c
uda_global_chunk
,
chunks
=
self
.
pg_size
,
dim
=
0
))
dist
.
all_gather
(
gather_list
,
self
.
cuda_shard
,
self
.
torch_pg
)
dist
.
all_gather
(
gather_list
,
self
.
cuda_shard
,
self
.
torch_pg
)
self
.
cuda_shard
=
None
self
.
cuda_shard
=
None
...
@@ -480,11 +490,11 @@ class Chunk:
...
@@ -480,11 +490,11 @@ class Chunk:
# sanity check
# sanity check
assert
self
.
cuda_shard
is
None
assert
self
.
cuda_shard
is
None
self
.
cuda_shard
=
torch
.
empty
(
self
.
shard_size
,
dtype
=
self
.
dtype
,
device
=
self
.
c
hunk_total
.
device
)
self
.
cuda_shard
=
torch
.
empty
(
self
.
shard_size
,
dtype
=
self
.
dtype
,
device
=
self
.
c
uda_global_chunk
.
device
)
self
.
cuda_shard
.
copy_
(
self
.
c
hunk_total
[
self
.
shard_begin
:
self
.
shard_end
])
self
.
cuda_shard
.
copy_
(
self
.
c
uda_global_chunk
[
self
.
shard_begin
:
self
.
shard_end
])
free_storage
(
self
.
c
hunk_total
)
free_storage
(
self
.
c
uda_global_chunk
)
self
.
is_gathered
=
False
self
.
is_gathered
=
False
def
__paired_shard_move
(
self
):
def
__paired_shard_move
(
self
):
...
@@ -505,15 +515,15 @@ class Chunk:
...
@@ -505,15 +515,15 @@ class Chunk:
def
__update_tensors_ptr
(
self
)
->
None
:
def
__update_tensors_ptr
(
self
)
->
None
:
# sanity check
# sanity check
assert
self
.
is_gathered
assert
self
.
is_gathered
assert
type
(
self
.
c
hunk_total
)
==
torch
.
Tensor
assert
type
(
self
.
c
uda_global_chunk
)
==
torch
.
Tensor
for
tensor
,
tensor_info
in
self
.
tensors_info
.
items
():
for
tensor
,
tensor_info
in
self
.
tensors_info
.
items
():
tensor
.
data
=
self
.
c
hunk_total
[
tensor_info
.
offset
:
tensor_info
.
end
].
view
(
tensor
.
shape
)
tensor
.
data
=
self
.
c
uda_global_chunk
[
tensor_info
.
offset
:
tensor_info
.
end
].
view
(
tensor
.
shape
)
def
__update_one_tensor_info
(
self
,
tensor_info
:
TensorInfo
,
next_state
:
TensorState
):
def
__update_one_tensor_info
(
self
,
tensor_info
:
TensorInfo
,
next_state
:
TensorState
):
self
.
tensor
s
_state_
monito
r
[
tensor_info
.
state
]
-=
1
self
.
tensor_state_
cnte
r
[
tensor_info
.
state
]
-=
1
tensor_info
.
state
=
next_state
tensor_info
.
state
=
next_state
self
.
tensor
s
_state_
monito
r
[
tensor_info
.
state
]
+=
1
self
.
tensor_state_
cnte
r
[
tensor_info
.
state
]
+=
1
def
__update_tensors_state
(
self
,
next_state
:
TensorState
,
prev_state
:
Optional
[
TensorState
]
=
None
):
def
__update_tensors_state
(
self
,
next_state
:
TensorState
,
prev_state
:
Optional
[
TensorState
]
=
None
):
for
tensor_info
in
self
.
tensors_info
.
values
():
for
tensor_info
in
self
.
tensors_info
.
values
():
...
@@ -543,9 +553,9 @@ class Chunk:
...
@@ -543,9 +553,9 @@ class Chunk:
output
.
append
(
"
\t
chunk temp:
\n
"
)
output
.
append
(
"
\t
chunk temp:
\n
"
)
print_tensor
(
tensor
=
self
.
chunk_temp
,
prefix
=
'
\t\t
'
)
print_tensor
(
tensor
=
self
.
chunk_temp
,
prefix
=
'
\t\t
'
)
if
self
.
c
hunk_total
is
not
None
and
self
.
c
hunk_total
.
storage
().
size
()
>
0
:
if
self
.
c
uda_global_chunk
is
not
None
and
self
.
c
uda_global_chunk
.
storage
().
size
()
>
0
:
output
.
append
(
"
\t
chunk total:
\n
"
)
output
.
append
(
"
\t
chunk total:
\n
"
)
print_tensor
(
tensor
=
self
.
c
hunk_total
,
prefix
=
'
\t\t
'
)
print_tensor
(
tensor
=
self
.
c
uda_global_chunk
,
prefix
=
'
\t\t
'
)
if
self
.
cuda_shard
is
not
None
:
if
self
.
cuda_shard
is
not
None
:
output
.
append
(
"
\t
cuda shard:
\n
"
)
output
.
append
(
"
\t
cuda shard:
\n
"
)
...
@@ -561,6 +571,6 @@ class Chunk:
...
@@ -561,6 +571,6 @@ class Chunk:
if
detailed
:
if
detailed
:
output
.
append
(
"
\t
tensor state monitor:
\n
"
)
output
.
append
(
"
\t
tensor state monitor:
\n
"
)
for
st
in
TensorState
:
for
st
in
TensorState
:
output
.
append
(
"
\t\t
# of {}: {}
\n
"
.
format
(
st
,
self
.
tensor
s
_state_
monito
r
[
st
]))
output
.
append
(
"
\t\t
# of {}: {}
\n
"
.
format
(
st
,
self
.
tensor_state_
cnte
r
[
st
]))
return
''
.
join
(
output
)
return
''
.
join
(
output
)
colossalai/nn/parallel/data_parallel.py
View file @
e99edfcb
...
@@ -299,7 +299,7 @@ class ZeroDDP(ColoDDP):
...
@@ -299,7 +299,7 @@ class ZeroDDP(ColoDDP):
reduced
=
self
.
chunk_manager
.
reduce_chunk
(
chunk
)
reduced
=
self
.
chunk_manager
.
reduce_chunk
(
chunk
)
if
reduced
:
if
reduced
:
if
chunk
.
is_gathered
:
if
chunk
.
is_gathered
:
chunk
.
c
hunk_total
.
div_
(
chunk
.
pg_size
)
chunk
.
c
uda_global_chunk
.
div_
(
chunk
.
pg_size
)
else
:
else
:
chunk
.
cuda_shard
.
div_
(
chunk
.
pg_size
)
chunk
.
cuda_shard
.
div_
(
chunk
.
pg_size
)
# check overflow elements
# check overflow elements
...
@@ -529,7 +529,7 @@ class ZeroDDP(ColoDDP):
...
@@ -529,7 +529,7 @@ class ZeroDDP(ColoDDP):
load
(
parameter_name
,
tensor
,
partial
(
load_fp32_parameter
,
parameter_slice
))
load
(
parameter_name
,
tensor
,
partial
(
load_fp32_parameter
,
parameter_slice
))
if
chunk
.
is_gathered
:
if
chunk
.
is_gathered
:
chunk
.
c
hunk_total
.
copy_
(
temp_chunk
)
chunk
.
c
uda_global_chunk
.
copy_
(
temp_chunk
)
elif
chunk
.
cuda_shard
is
not
None
:
elif
chunk
.
cuda_shard
is
not
None
:
chunk
.
cuda_shard
.
copy_
(
temp_chunk
[
chunk
.
shard_begin
:
chunk
.
shard_end
])
chunk
.
cuda_shard
.
copy_
(
temp_chunk
[
chunk
.
shard_begin
:
chunk
.
shard_end
])
else
:
else
:
...
...
colossalai/nn/parallel/utils.py
View file @
e99edfcb
import
torch
import
torch
import
torch.distributed
as
dist
import
torch.distributed
as
dist
from
colossalai.gemini.chunk
import
Chunk
from
colossalai.utils
import
get_current_device
from
colossalai.gemini.chunk
import
Chunk
from
colossalai.utils
import
get_current_device
def
get_temp_total_chunk_on_cuda
(
chunk
:
Chunk
):
if
chunk
.
is_gathered
:
def
get_temp_total_chunk_on_cuda
(
chunk
:
Chunk
):
return
chunk
.
chunk_total
if
chunk
.
is_gathered
:
return
chunk
.
cuda_global_chunk
if
chunk
.
cuda_shard
is
not
None
:
shard_temp
=
chunk
.
cuda_shard
if
chunk
.
cuda_shard
is
not
None
:
else
:
shard_temp
=
chunk
.
cuda_shard
shard_temp
=
chunk
.
cpu_shard
.
to
(
get_current_device
())
else
:
shard_temp
=
chunk
.
cpu_shard
.
to
(
get_current_device
())
total_temp
=
torch
.
zeros
(
chunk
.
chunk_size
,
dtype
=
chunk
.
dtype
,
device
=
get_current_device
())
gather_list
=
list
(
torch
.
chunk
(
input
=
total_temp
,
chunks
=
chunk
.
pg_size
,
dim
=
0
))
total_temp
=
torch
.
zeros
(
chunk
.
chunk_size
,
dtype
=
chunk
.
dtype
,
device
=
get_current_device
())
dist
.
all_gather
(
tensor_list
=
gather_list
,
tensor
=
shard_temp
,
group
=
chunk
.
torch_pg
)
gather_list
=
list
(
torch
.
chunk
(
input
=
total_temp
,
chunks
=
chunk
.
pg_size
,
dim
=
0
))
dist
.
all_gather
(
tensor_list
=
gather_list
,
tensor
=
shard_temp
,
group
=
chunk
.
torch_pg
)
return
total_temp
return
total_temp
colossalai/tensor/param_op_hook.py
View file @
e99edfcb
...
@@ -9,10 +9,11 @@ from colossalai.tensor.tensor_spec import ColoTensorSpec
...
@@ -9,10 +9,11 @@ from colossalai.tensor.tensor_spec import ColoTensorSpec
class
ColoParamOpHook
(
ABC
):
class
ColoParamOpHook
(
ABC
):
"""Hook which is triggered by each operation when operands contain ColoParameter.
"""
Hook which is triggered by each operation when operands contain ColoParameter.
To customize it, you must inherit this abstract class, and implement ``pre_forward``,
To customize it, you must inherit this abstract class, and implement ``pre_forward``,
``post_forward``, ``pre_backward`` and ``post_backward``.
These four methods take a list
``post_forward``, ``pre_backward`` and ``post_backward``.
of ColoParameter.
These four methods apply a list
of ColoParameter
as input args
.
"""
"""
@
abstractmethod
@
abstractmethod
...
@@ -33,7 +34,8 @@ class ColoParamOpHook(ABC):
...
@@ -33,7 +34,8 @@ class ColoParamOpHook(ABC):
class
ColoParamOpHookManager
:
class
ColoParamOpHookManager
:
"""Manage your param op hooks. It only has static methods.
"""
Manage your param op hooks. It only has static methods.
The only static method you should call is ``use_hooks(*hooks)``.
The only static method you should call is ``use_hooks(*hooks)``.
"""
"""
hooks
:
Tuple
[
ColoParamOpHook
,
...]
=
tuple
()
hooks
:
Tuple
[
ColoParamOpHook
,
...]
=
tuple
()
...
...
colossalai/zero/utils/zero_hook.py
View file @
e99edfcb
...
@@ -2,23 +2,22 @@ from typing import Optional
...
@@ -2,23 +2,22 @@ from typing import Optional
import
torch
import
torch
import
torch.distributed
as
dist
import
torch.distributed
as
dist
from
colossalai.gemini.memory_tracer
import
MemStatsCollector
from
colossalai.gemini.ophooks
import
BaseOpHook
from
colossalai.gemini.stateful_tensor
import
TensorState
from
colossalai.gemini.stateful_tensor_mgr
import
StatefulTensorMgr
from
colossalai.logging
import
get_dist_logger
from
colossalai.logging
import
get_dist_logger
from
colossalai.registry
import
OPHOOKS
from
colossalai.registry
import
OPHOOKS
from
colossalai.utils
import
get_current_device
from
colossalai.utils
import
get_current_device
from
colossalai.zero.shard_utils
import
BaseShardStrategy
from
colossalai.zero.shard_utils
import
BaseShardStrategy
from
colossalai.gemini.ophooks
import
BaseOpHook
from
colossalai.gemini.stateful_tensor_mgr
import
StatefulTensorMgr
from
colossalai.gemini.memory_tracer
import
MemStatsCollector
from
colossalai.gemini.stateful_tensor
import
TensorState
@
OPHOOKS
.
register_module
@
OPHOOKS
.
register_module
class
ZeroHook
(
BaseOpHook
):
class
ZeroHook
(
BaseOpHook
):
"""
"""
A hook to process sharded param for ZeRO method.
A hook to process sharded param for ZeRO method.
Warning: this class has been deprecated after version 0.1.12
"""
"""
def
__init__
(
self
,
def
__init__
(
self
,
...
...
tests/test_gemini/update/test_chunkv2.py
View file @
e99edfcb
...
@@ -69,7 +69,7 @@ def exam_chunk_basic(init_device, keep_gathered, pin_memory):
...
@@ -69,7 +69,7 @@ def exam_chunk_basic(init_device, keep_gathered, pin_memory):
assert
my_chunk
.
can_move
assert
my_chunk
.
can_move
my_chunk
.
shard_move
(
get_current_device
())
my_chunk
.
shard_move
(
get_current_device
())
else
:
else
:
assert
my_chunk
.
c
hunk_total
.
size
(
0
)
==
1024
assert
my_chunk
.
c
uda_global_chunk
.
size
(
0
)
==
1024
assert
my_chunk
.
device_type
==
'cuda'
assert
my_chunk
.
device_type
==
'cuda'
assert
not
my_chunk
.
can_move
assert
not
my_chunk
.
can_move
...
@@ -82,27 +82,27 @@ def exam_chunk_basic(init_device, keep_gathered, pin_memory):
...
@@ -82,27 +82,27 @@ def exam_chunk_basic(init_device, keep_gathered, pin_memory):
for
param
,
param_cp
in
zip
(
param_list
,
param_cp_list
):
for
param
,
param_cp
in
zip
(
param_list
,
param_cp_list
):
check_euqal
(
param
,
param_cp
)
check_euqal
(
param
,
param_cp
)
assert
my_chunk
.
tensor
s
_state_
monito
r
[
TensorState
.
HOLD
]
==
4
assert
my_chunk
.
tensor_state_
cnte
r
[
TensorState
.
HOLD
]
==
4
my_chunk
.
tensor_trans_state
(
param_list
[
0
],
TensorState
.
COMPUTE
)
my_chunk
.
tensor_trans_state
(
param_list
[
0
],
TensorState
.
COMPUTE
)
assert
my_chunk
.
tensor
s
_state_
monito
r
[
TensorState
.
HOLD
]
==
3
assert
my_chunk
.
tensor_state_
cnte
r
[
TensorState
.
HOLD
]
==
3
assert
my_chunk
.
tensor
s
_state_
monito
r
[
TensorState
.
COMPUTE
]
==
1
assert
my_chunk
.
tensor_state_
cnte
r
[
TensorState
.
COMPUTE
]
==
1
assert
not
my_chunk
.
can_release
assert
not
my_chunk
.
can_release
for
param
in
param_list
:
for
param
in
param_list
:
my_chunk
.
tensor_trans_state
(
param
,
TensorState
.
COMPUTE
)
my_chunk
.
tensor_trans_state
(
param
,
TensorState
.
COMPUTE
)
my_chunk
.
tensor_trans_state
(
param
,
TensorState
.
READY_FOR_REDUCE
)
my_chunk
.
tensor_trans_state
(
param
,
TensorState
.
READY_FOR_REDUCE
)
assert
my_chunk
.
tensor
s
_state_
monito
r
[
TensorState
.
READY_FOR_REDUCE
]
==
4
assert
my_chunk
.
tensor_state_
cnte
r
[
TensorState
.
READY_FOR_REDUCE
]
==
4
assert
my_chunk
.
can_reduce
assert
my_chunk
.
can_reduce
my_chunk
.
reduce
()
my_chunk
.
reduce
()
assert
my_chunk
.
tensor
s
_state_
monito
r
[
TensorState
.
HOLD
]
==
4
assert
my_chunk
.
tensor_state_
cnte
r
[
TensorState
.
HOLD
]
==
4
if
keep_gathered
is
False
:
if
keep_gathered
is
False
:
assert
my_chunk
.
cuda_shard
.
size
(
0
)
==
1024
//
world_size
assert
my_chunk
.
cuda_shard
.
size
(
0
)
==
1024
//
world_size
assert
my_chunk
.
device_type
==
'cuda'
assert
my_chunk
.
device_type
==
'cuda'
assert
my_chunk
.
can_move
assert
my_chunk
.
can_move
else
:
else
:
assert
my_chunk
.
c
hunk_total
.
size
(
0
)
==
1024
assert
my_chunk
.
c
uda_global_chunk
.
size
(
0
)
==
1024
assert
my_chunk
.
device_type
==
'cuda'
assert
my_chunk
.
device_type
==
'cuda'
assert
not
my_chunk
.
can_move
assert
not
my_chunk
.
can_move
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment