OpenDAS / ColossalAI · Commits · 4fb3c52c

Unverified commit 4fb3c52c, authored Aug 09, 2022 by HELSON, committed via GitHub on Aug 09, 2022.

[zero] add unit test for AgChunk's append, close, access (#1423)
Parent: c577ed01

Showing 2 changed files with 133 additions and 3 deletions.
colossalai/gemini/ag_chunk.py           +52 -3
tests/test_gemini/chunk/test_agchunk.py +81 -0
colossalai/gemini/ag_chunk.py
@@ -36,7 +36,7 @@ class AgChunk:
         self.utilized_size = 0
         # Here, we use torch process group,
         # since ColoProcessGroup might get deprecated soon
-        self.torch_pg = process_group.dp_process_group
+        self.torch_pg = process_group.dp_process_group()
         self.pg_size = dist.get_world_size(self.torch_pg)
         self.pg_rank = dist.get_rank(self.torch_pg)
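The one-line change here adds a call: `dp_process_group` is evidently a method on the ColossalAI `ProcessGroup` wrapper, and the old code stored the bound method itself rather than the torch process group it returns, which would break the `dist.get_world_size` / `dist.get_rank` calls just below. A minimal sketch of that bug class, using a hypothetical `Group` stand-in rather than the real ColossalAI types:

class Group:
    """Hypothetical stand-in for ColoProcessGroup; only shows the bug shape."""

    def dp_process_group(self):
        return "torch-pg-handle"

g = Group()
pg = g.dp_process_group        # a bound method, not a handle; attribute access still "works"
assert callable(pg)            # so a later dist.get_world_size(pg) would fail, not this line
pg = g.dp_process_group()      # the fix: call it to obtain the actual process group
assert pg == "torch-pg-handle"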
@@ -69,6 +69,8 @@ class AgChunk:
         # some chunks can keep gathered all the time
         # so their computation patterns are the same as that of the parameters in DDP
         self.keep_gathered = keep_gathered
+        if self.keep_gathered:
+            pin_memory = False    # since this chunk is gathered, it doesn't need to pin
         # if pin_memory is True, we allocate a piece of CPU pin-memory
         # for it all the time
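Note on the two added lines: per the surrounding comments, `pin_memory=True` keeps a page-locked CPU buffer alive for the chunk's whole lifetime, which is pure overhead for a chunk that never leaves the GPU, so `keep_gathered` now forces it off. What pinning buys in plain PyTorch terms (a sketch; pinning requires a CUDA build of torch):

import torch

# Page-locked CPU memory enables fast, asynchronous host-to-device copies.
# A keep_gathered chunk never stages through the CPU, so a pinned buffer
# allocated for it would sit unused.
cpu_shard = torch.empty(1024, dtype=torch.float32, pin_memory=True)
assert cpu_shard.is_pinned()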
@@ -134,7 +136,7 @@ class AgChunk:
         if new_utilized_size > self.chunk_size:
             raise ChunkFullError
-        self.chunk_temp[self.utilized_size:new_utilized_size].copy_(tensor.flatten())
+        self.chunk_temp[self.utilized_size:new_utilized_size].copy_(tensor.data.flatten())
         assert type(self.chunk_temp) == torch.Tensor, "copy_tensor_to_chunk_slice must use a torch tensor"
         tensor.data = self.chunk_temp[self.utilized_size:new_utilized_size].view(tensor.shape)
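Two things happen in this append path: the parameter's payload is copied into the chunk buffer, and `tensor.data` is then re-pointed at that slice, so parameter and chunk share one storage afterwards. The fix routes the copy through `tensor.data` (presumably to read the raw payload without autograd involvement; the commit message does not say). The idiom, self-contained outside the class:

import torch

chunk_temp = torch.zeros(16)
t = torch.arange(6, dtype=torch.float32).reshape(2, 3)
utilized_size, new_utilized_size = 0, t.numel()

# copy the flattened payload into the chunk slice ...
chunk_temp[utilized_size:new_utilized_size].copy_(t.data.flatten())
# ... then alias the tensor onto that slice: one storage, two views
t.data = chunk_temp[utilized_size:new_utilized_size].view(t.shape)

assert t.data_ptr() == chunk_temp.data_ptr()    # same underlying memory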
@@ -145,7 +147,7 @@ class AgChunk:
         self.tensors_state_monitor[tensor_state] += 1
         self.utilized_size = new_utilized_size

-    def close_chunk(self, shard_dev: torch.device):
+    def close_chunk(self, shard_dev: Optional[torch.device] = None):
         """Close the chunk. Any tensor can't be appended to a closed chunk.
         """
         # sanity check
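With the new default, `close_chunk()` can be called with no argument, which is exactly what the new test below does. The annotation assumes `Optional` is imported from `typing` in this file; that import sits outside the hunk.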
@@ -159,6 +161,14 @@ class AgChunk:
         self.__scatter()

+        if self.keep_gathered:
+            if shard_dev is None:
+                shard_dev = get_current_device()
+            else:
+                assert shard_dev.type == 'cuda'
+        elif shard_dev is None:
+            shard_dev = torch.device('cpu')
+
         if self.pin_memory or shard_dev.type == 'cpu':
             self.cpu_shard = torch.empty(self.shard_size,
                                          dtype=self.dtype,
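The eight added lines resolve the default shard device: a `keep_gathered` chunk shards onto the current CUDA device and rejects a CPU override, while an ordinary chunk defaults to CPU. The same logic as a standalone pure function (a hypothetical helper, written here only for illustration):

import torch

def resolve_shard_dev(shard_dev, keep_gathered, current_device=torch.device('cuda')):
    if keep_gathered:
        if shard_dev is None:
            return current_device           # gathered chunks stay on the GPU
        assert shard_dev.type == 'cuda'     # a CPU shard_dev is rejected
        return shard_dev
    if shard_dev is None:
        return torch.device('cpu')          # scattered chunks default to CPU
    return shard_dev

assert resolve_shard_dev(None, keep_gathered=True) == torch.device('cuda')
assert resolve_shard_dev(None, keep_gathered=False) == torch.device('cpu')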
@@ -364,3 +374,42 @@ class AgChunk:
         for tensor_info in self.tensors_info.values():
             if prev_state is None or tensor_info.state == prev_state:
                 self.__update_one_tensor_info(tensor_info, next_state)
+
+    def __repr__(self, detailed: bool = False):
+        output = [
+            "AgChunk Information:\n",
+            "\tchunk size: {}, chunk dtype: {}, process group size: {}\n".format(self.chunk_size, self.dtype,
+                                                                                 self.pg_size),
+            "\t# of tensors: {}, utilized size: {}, utilized percentage: {:.2f}\n".format(
+                self.num_tensors, self.utilized_size, self.utilized_size / self.chunk_size)
+        ]
+
+        def print_tensor(tensor, prefix=''):
+            output.append("{}shape: {}, dtype: {}, device: {}\n".format(prefix, tensor.shape, tensor.dtype,
+                                                                        tensor.device))
+
+        if self.chunk_temp is not None:
+            output.append("\tchunk temp:\n")
+            print_tensor(tensor=self.chunk_temp, prefix='\t\t')
+
+        if self.chunk_total is not None and self.chunk_total.storage().size() > 0:
+            output.append("\tchunk total:\n")
+            print_tensor(tensor=self.chunk_total, prefix='\t\t')
+
+        if self.cuda_shard is not None:
+            output.append("\tcuda shard:\n")
+            print_tensor(tensor=self.cuda_shard, prefix='\t\t')
+
+        if self.cpu_shard is not None:
+            output.append("\tcpu shard:\n")
+            print_tensor(tensor=self.cpu_shard, prefix='\t\t')
+
+        memory_info = self.memory_usage
+        output.append("\tmemory usage: cuda {}, cpu {}\n".format(memory_info['cuda'], memory_info['cpu']))
+
+        if detailed:
+            output.append("\ttensor state monitor:\n")
+            for st in TensorState:
+                output.append("\t\t# of {}: {}\n".format(st, self.tensors_state_monitor[st]))
+
+        return ''.join(output)
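`__repr__` takes an extra `detailed` flag here; Python's `repr()` protocol passes no arguments, so the flag must keep its default and the per-state counts are only reachable by calling the method explicitly. Assuming a chunk built as in the test below:

print(my_chunk)                            # summary: size, dtype, pg size, shards, memory
print(my_chunk.__repr__(detailed=True))    # additionally lists counts per TensorState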
tests/test_gemini/chunk/test_agchunk.py (new file, mode 100644)
import torch
import colossalai
import pytest
import torch.multiprocessing as mp
from functools import partial
from colossalai.testing import rerun_if_address_is_in_use, parameterize
from colossalai.utils import free_port, get_current_device
from colossalai.tensor import ProcessGroup as ColoProcessGroup
from colossalai.tensor import ColoParameter
from colossalai.gemini.ag_chunk import AgChunk


def add_param(param_list, param_cp_list, *args, **kwargs):
    param = ColoParameter(torch.empty(*args, **kwargs))
    param_list.append(param)
    param_cp_list.append(param.clone())


def check_euqal(param, param_cp):
    if param.device != param_cp.device:
        temp = param.data.to(param_cp.device)
    else:
        temp = param.data
    return torch.equal(temp, param_cp.data)


@parameterize('init_device', [None, torch.device('cpu')])
@parameterize('keep_gathered', [True, False])
@parameterize('pin_memory', [True, False])
def exam_chunk_init(init_device, keep_gathered, pin_memory):
    world_size = torch.distributed.get_world_size()
    pg = ColoProcessGroup()
    my_chunk = AgChunk(chunk_size=1024,
                       process_group=pg,
                       dtype=torch.float32,
                       init_device=init_device,
                       keep_gathered=keep_gathered,
                       pin_memory=pin_memory)

    param_list = []
    param_cp_list = []

    add_param(param_list, param_cp_list, 8, 8, 8, device='cuda')
    add_param(param_list, param_cp_list, 4, 4)
    add_param(param_list, param_cp_list, 4, 8, 2, device='cuda')
    add_param(param_list, param_cp_list, 1, 1, 5)

    for param in param_list:
        my_chunk.append_tensor(param)
    assert my_chunk.utilized_size == 597
    for param, param_cp in zip(param_list, param_cp_list):
        check_euqal(param, param_cp)

    my_chunk.close_chunk()

    if keep_gathered is False:
        assert my_chunk.cpu_shard.size(0) == 1024 // world_size
        my_chunk.shard_move(get_current_device())

    my_chunk.access_chunk()
    for param, param_cp in zip(param_list, param_cp_list):
        check_euqal(param, param_cp)
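The expected `utilized_size` is just the sum of the four parameters' element counts, packed back to back into the 1024-element chunk; note also that `check_euqal` returns a bool these loops never assert on, so as written they exercise the comparison path without failing on a mismatch.

# where 597 comes from
assert 8 * 8 * 8 + 4 * 4 + 4 * 8 * 2 + 1 * 1 * 5 == 512 + 16 + 64 + 5 == 597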
def run_dist(rank, world_size, port):
    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    exam_chunk_init()


@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 2, 4])
@rerun_if_address_is_in_use()
def test_chunk_function(world_size):
    run_func = partial(run_dist, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_chunk_function(2)
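Since `run_dist` launches with the `nccl` backend and two of the parameters are created on `'cuda'`, the test needs a CUDA machine, with up to 4 GPUs for the largest `world_size`. The `__main__` guard runs the 2-process case directly via `python tests/test_gemini/chunk/test_agchunk.py`; under pytest, the `world_size` parametrization covers 1, 2, and 4 processes.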