Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
0d212183
Unverified
Commit
0d212183
authored
Aug 10, 2022
by
HELSON
Committed by
GitHub
Aug 10, 2022
Browse files
[zero] add has_inf_or_nan in AgChunk; enhance the unit test of AgChunk (#1426)
parent
33f0744d
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
90 additions
and
11 deletions
+90
-11
colossalai/gemini/ag_chunk.py
colossalai/gemini/ag_chunk.py
+43
-6
tests/test_gemini/chunk/test_agchunk.py
tests/test_gemini/chunk/test_agchunk.py
+47
-5
No files found.
colossalai/gemini/ag_chunk.py
View file @
0d212183
import
torch
import
torch
import
torch.distributed
as
dist
import
torch.distributed
as
dist
from
typing
import
Optional
,
Dict
from
typing
import
Optional
,
Dict
,
List
from
colossalai.utils
import
get_current_device
from
colossalai.utils
import
get_current_device
from
colossalai.tensor
import
ProcessGroup
as
ColoProcessGroup
from
colossalai.tensor
import
ProcessGroup
as
ColoProcessGroup
...
@@ -45,10 +45,11 @@ class AgChunk:
...
@@ -45,10 +45,11 @@ class AgChunk:
self
.
shard_size
=
chunk_size
//
self
.
pg_size
self
.
shard_size
=
chunk_size
//
self
.
pg_size
self
.
shard_begin
=
self
.
shard_size
*
self
.
pg_rank
self
.
shard_begin
=
self
.
shard_size
*
self
.
pg_rank
self
.
shard_end
=
self
.
shard_begin
+
self
.
shard_size
self
.
shard_end
=
self
.
shard_begin
+
self
.
shard_size
self
.
valid_end
=
self
.
shard_size
self
.
dtype
=
dtype
self
.
dtype
=
dtype
device
=
init_device
or
get_current_device
()
device
=
init_device
or
get_current_device
()
self
.
chunk_temp
=
torch
.
empty
(
chunk_size
,
dtype
=
dtype
,
device
=
device
)
self
.
chunk_temp
=
torch
.
zeros
(
chunk_size
,
dtype
=
dtype
,
device
=
device
)
# keep all zero
self
.
chunk_total
=
None
# we force chunk_total located in CUDA
self
.
chunk_total
=
None
# we force chunk_total located in CUDA
self
.
cuda_shard
=
None
# using two attributes for the better interpretation
self
.
cuda_shard
=
None
# using two attributes for the better interpretation
self
.
cpu_shard
=
None
self
.
cpu_shard
=
None
...
@@ -114,7 +115,7 @@ class AgChunk:
...
@@ -114,7 +115,7 @@ class AgChunk:
if
self
.
chunk_temp
is
not
None
:
if
self
.
chunk_temp
is
not
None
:
return
self
.
chunk_temp
.
device
.
type
return
self
.
chunk_temp
.
device
.
type
else
:
else
:
if
self
.
chunk_total
is
not
None
:
if
self
.
is_gathered
:
return
'cuda'
return
'cuda'
elif
self
.
cuda_shard
is
not
None
:
elif
self
.
cuda_shard
is
not
None
:
return
'cuda'
return
'cuda'
...
@@ -153,6 +154,12 @@ class AgChunk:
...
@@ -153,6 +154,12 @@ class AgChunk:
# sanity check
# sanity check
assert
self
.
chunk_temp
is
not
None
assert
self
.
chunk_temp
is
not
None
# calculate the valid end for each shard
if
self
.
utilized_size
<=
self
.
shard_begin
:
self
.
valid_end
=
0
elif
self
.
utilized_size
<
self
.
shard_end
:
self
.
valid_end
=
self
.
utilized_size
-
self
.
shard_begin
if
self
.
chunk_temp
.
device
.
type
==
'cpu'
:
if
self
.
chunk_temp
.
device
.
type
==
'cpu'
:
self
.
chunk_total
=
self
.
chunk_temp
.
to
(
get_current_device
())
self
.
chunk_total
=
self
.
chunk_temp
.
to
(
get_current_device
())
else
:
else
:
...
@@ -257,7 +264,7 @@ class AgChunk:
...
@@ -257,7 +264,7 @@ class AgChunk:
self
.
shard_size
,
dtype
=
self
.
dtype
,
device
=
get_current_device
())
self
.
shard_size
,
dtype
=
self
.
dtype
,
device
=
get_current_device
())
input_list
=
list
(
torch
.
chunk
(
self
.
chunk_total
,
chunks
=
self
.
pg_size
,
dim
=
0
))
input_list
=
list
(
torch
.
chunk
(
self
.
chunk_total
,
chunks
=
self
.
pg_size
,
dim
=
0
))
dist
.
reduce_scatter
(
self
.
cuda_shard
,
input_list
,
self
.
torch_pg
)
dist
.
reduce_scatter
(
self
.
cuda_shard
,
input_list
,
group
=
self
.
torch_pg
)
free_storage
(
self
.
chunk_total
)
free_storage
(
self
.
chunk_total
)
self
.
is_gathered
=
False
self
.
is_gathered
=
False
...
@@ -298,17 +305,38 @@ class AgChunk:
...
@@ -298,17 +305,38 @@ class AgChunk:
assert
self
.
is_gathered
assert
self
.
is_gathered
tensor_info
=
self
.
tensors_info
[
tensor
]
tensor_info
=
self
.
tensors_info
[
tensor
]
self
.
chunk_total
[
tensor_info
.
offset
:
tensor_info
.
end
].
copy_
(
data_slice
.
flatten
())
self
.
chunk_total
[
tensor_info
.
offset
:
tensor_info
.
end
].
copy_
(
data_slice
.
data
.
flatten
())
tensor
.
data
=
self
.
chunk_total
[
tensor_info
.
offset
:
tensor_info
.
end
].
view
(
tensor
.
shape
)
tensor
.
data
=
self
.
chunk_total
[
tensor_info
.
offset
:
tensor_info
.
end
].
view
(
tensor
.
shape
)
@
property
def
can_move
(
self
)
->
bool
:
return
not
self
.
is_gathered
@
property
@
property
def
can_release
(
self
)
->
bool
:
def
can_release
(
self
)
->
bool
:
return
self
.
tensors_state_monitor
[
TensorState
.
HOLD
]
==
self
.
num_tensors
if
self
.
keep_gathered
:
return
False
else
:
return
self
.
tensors_state_monitor
[
TensorState
.
HOLD
]
+
\
self
.
tensors_state_monitor
[
TensorState
.
HOLD_AFTER_BWD
]
==
self
.
num_tensors
@
property
@
property
def
can_reduce
(
self
):
def
can_reduce
(
self
):
return
self
.
tensors_state_monitor
[
TensorState
.
READY_FOR_REDUCE
]
==
self
.
num_tensors
return
self
.
tensors_state_monitor
[
TensorState
.
READY_FOR_REDUCE
]
==
self
.
num_tensors
@
property
def
has_inf_or_nan
(
self
)
->
bool
:
"""
Check if the chunk has inf or nan values in CUDA.
"""
if
self
.
is_gathered
:
valid_tensor
=
self
.
chunk_total
[:
self
.
utilized_size
]
else
:
assert
self
.
cuda_shard
is
not
None
# only check in CUDA
valid_tensor
=
self
.
cuda_shard
[:
self
.
valid_end
]
return
torch
.
isinf
(
valid_tensor
).
any
().
item
()
|
torch
.
isnan
(
valid_tensor
).
any
().
item
()
def
__gather
(
self
):
def
__gather
(
self
):
if
not
self
.
is_gathered
:
if
not
self
.
is_gathered
:
# sanity check
# sanity check
...
@@ -375,6 +403,12 @@ class AgChunk:
...
@@ -375,6 +403,12 @@ class AgChunk:
if
prev_state
is
None
or
tensor_info
.
state
==
prev_state
:
if
prev_state
is
None
or
tensor_info
.
state
==
prev_state
:
self
.
__update_one_tensor_info
(
tensor_info
,
next_state
)
self
.
__update_one_tensor_info
(
tensor_info
,
next_state
)
def
__hash__
(
self
)
->
int
:
return
hash
(
id
(
self
))
def
__eq__
(
self
,
__o
:
object
)
->
bool
:
return
self
is
__o
def
__repr__
(
self
,
detailed
:
bool
=
False
):
def
__repr__
(
self
,
detailed
:
bool
=
False
):
output
=
[
output
=
[
"AgChunk Information:
\n
"
,
"AgChunk Information:
\n
"
,
...
@@ -413,3 +447,6 @@ class AgChunk:
...
@@ -413,3 +447,6 @@ class AgChunk:
output
.
append
(
"
\t\t
# of {}: {}
\n
"
.
format
(
st
,
self
.
tensors_state_monitor
[
st
]))
output
.
append
(
"
\t\t
# of {}: {}
\n
"
.
format
(
st
,
self
.
tensors_state_monitor
[
st
]))
return
''
.
join
(
output
)
return
''
.
join
(
output
)
def
get_tensors
(
self
)
->
List
[
torch
.
Tensor
]:
return
list
(
self
.
tensors_info
.
keys
())
tests/test_gemini/chunk/test_agchunk.py
View file @
0d212183
...
@@ -2,16 +2,24 @@ import torch
...
@@ -2,16 +2,24 @@ import torch
import
colossalai
import
colossalai
import
pytest
import
pytest
import
torch.multiprocessing
as
mp
import
torch.multiprocessing
as
mp
import
torch.distributed
as
dist
from
functools
import
partial
from
functools
import
partial
from
colossalai.testing
import
rerun_if_address_is_in_use
,
parameterize
from
colossalai.testing
import
rerun_if_address_is_in_use
,
parameterize
from
colossalai.utils
import
free_port
,
get_current_device
from
colossalai.utils
import
free_port
,
get_current_device
from
colossalai.tensor
import
ProcessGroup
as
ColoProcessGroup
from
colossalai.tensor
import
ProcessGroup
as
ColoProcessGroup
from
colossalai.tensor
import
ColoParameter
from
colossalai.tensor
import
ColoParameter
from
colossalai.gemini
import
TensorState
from
colossalai.gemini.ag_chunk
import
AgChunk
from
colossalai.gemini.ag_chunk
import
AgChunk
def
dist_sum
(
x
):
temp
=
torch
.
tensor
([
x
],
device
=
get_current_device
())
dist
.
all_reduce
(
temp
)
return
temp
.
item
()
def
add_param
(
param_list
,
param_cp_list
,
*
args
,
**
kwargs
):
def
add_param
(
param_list
,
param_cp_list
,
*
args
,
**
kwargs
):
param
=
ColoParameter
(
torch
.
empty
(
*
args
,
**
kwargs
))
param
=
ColoParameter
(
torch
.
randn
(
*
args
,
**
kwargs
))
param_list
.
append
(
param
)
param_list
.
append
(
param
)
param_cp_list
.
append
(
param
.
clone
())
param_cp_list
.
append
(
param
.
clone
())
...
@@ -27,7 +35,7 @@ def check_euqal(param, param_cp):
...
@@ -27,7 +35,7 @@ def check_euqal(param, param_cp):
@
parameterize
(
'init_device'
,
[
None
,
torch
.
device
(
'cpu'
)])
@
parameterize
(
'init_device'
,
[
None
,
torch
.
device
(
'cpu'
)])
@
parameterize
(
'keep_gathered'
,
[
True
,
False
])
@
parameterize
(
'keep_gathered'
,
[
True
,
False
])
@
parameterize
(
'pin_memory'
,
[
True
,
False
])
@
parameterize
(
'pin_memory'
,
[
True
,
False
])
def
exam_chunk_
init
(
init_device
,
keep_gathered
,
pin_memory
):
def
exam_chunk_
basic
(
init_device
,
keep_gathered
,
pin_memory
):
world_size
=
torch
.
distributed
.
get_world_size
()
world_size
=
torch
.
distributed
.
get_world_size
()
pg
=
ColoProcessGroup
()
pg
=
ColoProcessGroup
()
my_chunk
=
AgChunk
(
my_chunk
=
AgChunk
(
...
@@ -56,17 +64,51 @@ def exam_chunk_init(init_device, keep_gathered, pin_memory):
...
@@ -56,17 +64,51 @@ def exam_chunk_init(init_device, keep_gathered, pin_memory):
if
keep_gathered
is
False
:
if
keep_gathered
is
False
:
assert
my_chunk
.
cpu_shard
.
size
(
0
)
==
1024
//
world_size
assert
my_chunk
.
cpu_shard
.
size
(
0
)
==
1024
//
world_size
assert
my_chunk
.
device_type
==
'cpu'
assert
my_chunk
.
can_move
my_chunk
.
shard_move
(
get_current_device
())
my_chunk
.
shard_move
(
get_current_device
())
else
:
assert
my_chunk
.
chunk_total
.
size
(
0
)
==
1024
assert
my_chunk
.
device_type
==
'cuda'
assert
not
my_chunk
.
can_move
my_chunk
.
access_chunk
()
assert
dist_sum
(
my_chunk
.
valid_end
)
==
my_chunk
.
utilized_size
flag
=
my_chunk
.
has_inf_or_nan
assert
not
flag
,
"has_inf_or_nan is {}"
.
format
(
flag
)
my_chunk
.
access_chunk
()
assert
my_chunk
.
device_type
==
'cuda'
for
param
,
param_cp
in
zip
(
param_list
,
param_cp_list
):
for
param
,
param_cp
in
zip
(
param_list
,
param_cp_list
):
check_euqal
(
param
,
param_cp
)
check_euqal
(
param
,
param_cp
)
assert
my_chunk
.
tensors_state_monitor
[
TensorState
.
HOLD
]
==
4
my_chunk
.
tensor_trans_state
(
param_list
[
0
],
TensorState
.
COMPUTE
)
assert
my_chunk
.
tensors_state_monitor
[
TensorState
.
HOLD
]
==
3
assert
my_chunk
.
tensors_state_monitor
[
TensorState
.
COMPUTE
]
==
1
assert
not
my_chunk
.
can_release
for
param
in
param_list
:
my_chunk
.
tensor_trans_state
(
param
,
TensorState
.
COMPUTE
)
my_chunk
.
tensor_trans_state
(
param
,
TensorState
.
READY_FOR_REDUCE
)
assert
my_chunk
.
tensors_state_monitor
[
TensorState
.
READY_FOR_REDUCE
]
==
4
assert
my_chunk
.
can_reduce
my_chunk
.
reduce
()
assert
my_chunk
.
tensors_state_monitor
[
TensorState
.
HOLD
]
==
4
if
keep_gathered
is
False
:
assert
my_chunk
.
cuda_shard
.
size
(
0
)
==
1024
//
world_size
assert
my_chunk
.
device_type
==
'cuda'
assert
my_chunk
.
can_move
else
:
assert
my_chunk
.
chunk_total
.
size
(
0
)
==
1024
assert
my_chunk
.
device_type
==
'cuda'
assert
not
my_chunk
.
can_move
def
run_dist
(
rank
,
world_size
,
port
):
def
run_dist
(
rank
,
world_size
,
port
):
colossalai
.
launch
(
config
=
{},
rank
=
rank
,
world_size
=
world_size
,
host
=
'localhost'
,
port
=
port
,
backend
=
'nccl'
)
colossalai
.
launch
(
config
=
{},
rank
=
rank
,
world_size
=
world_size
,
host
=
'localhost'
,
port
=
port
,
backend
=
'nccl'
)
exam_chunk_
init
()
exam_chunk_
basic
()
@
pytest
.
mark
.
dist
@
pytest
.
mark
.
dist
...
@@ -78,4 +120,4 @@ def test_chunk_function(world_size):
...
@@ -78,4 +120,4 @@ def test_chunk_function(world_size):
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
test_chunk_function
(
2
)
test_chunk_function
(
4
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment