OpenDAS / ColossalAI / Commits / 0d212183

Unverified commit 0d212183, authored Aug 10, 2022 by HELSON, committed by GitHub on Aug 10, 2022
[zero] add has_inf_or_nan in AgChunk; enhance the unit test of AgChunk (#1426)
parent 33f0744d

Showing 2 changed files with 90 additions and 11 deletions (+90 -11)

    colossalai/gemini/ag_chunk.py              +43  -6
    tests/test_gemini/chunk/test_agchunk.py    +47  -5
colossalai/gemini/ag_chunk.py  (view file @ 0d212183)

 import torch
 import torch.distributed as dist
-from typing import Optional, Dict
+from typing import Optional, Dict, List
 from colossalai.utils import get_current_device
 from colossalai.tensor import ProcessGroup as ColoProcessGroup
...
@@ -45,10 +45,11 @@ class AgChunk:
         self.shard_size = chunk_size // self.pg_size
         self.shard_begin = self.shard_size * self.pg_rank
         self.shard_end = self.shard_begin + self.shard_size
+        self.valid_end = self.shard_size

         self.dtype = dtype
         device = init_device or get_current_device()
-        self.chunk_temp = torch.empty(chunk_size, dtype=dtype, device=device)
+        self.chunk_temp = torch.zeros(chunk_size, dtype=dtype, device=device)    # keep all zero
         self.chunk_total = None    # we force chunk_total located in CUDA
         self.cuda_shard = None    # using two attributes for the better interpretation
         self.cpu_shard = None
...
@@ -114,7 +115,7 @@ class AgChunk:
         if self.chunk_temp is not None:
             return self.chunk_temp.device.type
         else:
-            if self.chunk_total is not None:
+            if self.is_gathered:
                 return 'cuda'
             elif self.cuda_shard is not None:
                 return 'cuda'
...
@@ -153,6 +154,12 @@ class AgChunk:
         # sanity check
         assert self.chunk_temp is not None

+        # calculate the valid end for each shard
+        if self.utilized_size <= self.shard_begin:
+            self.valid_end = 0
+        elif self.utilized_size < self.shard_end:
+            self.valid_end = self.utilized_size - self.shard_begin
+
         if self.chunk_temp.device.type == 'cpu':
             self.chunk_total = self.chunk_temp.to(get_current_device())
         else:
...
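To make the new valid_end bookkeeping concrete, here is a small standalone sketch that redoes the same arithmetic outside the class; the sizes are purely illustrative and do not come from the commit. With chunk_size = 1024 split over two ranks and utilized_size = 900, rank 0's shard stays fully valid (valid_end = 512) while rank 1 only holds 900 - 512 = 388 valid elements, so the per-rank valid ends sum to utilized_size, which is exactly the invariant the enhanced unit test checks via dist_sum further below.

# Illustrative recomputation of valid_end per rank; all values are hypothetical.
chunk_size, pg_size, utilized_size = 1024, 2, 900
shard_size = chunk_size // pg_size
for pg_rank in range(pg_size):
    shard_begin = shard_size * pg_rank
    shard_end = shard_begin + shard_size
    valid_end = shard_size                       # default set in __init__
    if utilized_size <= shard_begin:
        valid_end = 0                            # shard holds only padding
    elif utilized_size < shard_end:
        valid_end = utilized_size - shard_begin  # shard is partially used
    print(pg_rank, valid_end)                    # prints 0 512 and 1 388; 512 + 388 == 900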
@@ -257,7 +264,7 @@ class AgChunk:
                                             self.shard_size, dtype=self.dtype, device=get_current_device())

         input_list = list(torch.chunk(self.chunk_total, chunks=self.pg_size, dim=0))
-        dist.reduce_scatter(self.cuda_shard, input_list, self.torch_pg)
+        dist.reduce_scatter(self.cuda_shard, input_list, group=self.torch_pg)

         free_storage(self.chunk_total)
         self.is_gathered = False
...
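The one-token change above matters because of torch.distributed's argument order: reduce_scatter(output, input_list, op=ReduceOp.SUM, group=None, async_op=False) takes the reduction op as its third positional parameter, so a process group passed positionally would be bound to op rather than group. Restated with the same names as the diff, purely as commentary on the fix:

# before: the process group lands in the `op` slot of reduce_scatter
# dist.reduce_scatter(self.cuda_shard, input_list, self.torch_pg)
# after: the process group is bound explicitly by keyword
dist.reduce_scatter(self.cuda_shard, input_list, group=self.torch_pg)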
@@ -298,17 +305,38 @@ class AgChunk:
         assert self.is_gathered

         tensor_info = self.tensors_info[tensor]
-        self.chunk_total[tensor_info.offset:tensor_info.end].copy_(data_slice.flatten())
+        self.chunk_total[tensor_info.offset:tensor_info.end].copy_(data_slice.data.flatten())
         tensor.data = self.chunk_total[tensor_info.offset:tensor_info.end].view(tensor.shape)

     @property
     def can_move(self) -> bool:
         return not self.is_gathered

     @property
     def can_release(self) -> bool:
-        return self.tensors_state_monitor[TensorState.HOLD] == self.num_tensors
+        if self.keep_gathered:
+            return False
+        else:
+            return self.tensors_state_monitor[TensorState.HOLD] + \
+                self.tensors_state_monitor[TensorState.HOLD_AFTER_BWD] == self.num_tensors

     @property
     def can_reduce(self):
         return self.tensors_state_monitor[TensorState.READY_FOR_REDUCE] == self.num_tensors

+    @property
+    def has_inf_or_nan(self) -> bool:
+        """
+        Check if the chunk has inf or nan values in CUDA.
+        """
+        if self.is_gathered:
+            valid_tensor = self.chunk_total[:self.utilized_size]
+        else:
+            assert self.cuda_shard is not None    # only check in CUDA
+            valid_tensor = self.cuda_shard[:self.valid_end]
+
+        return torch.isinf(valid_tensor).any().item() | torch.isnan(valid_tensor).any().item()
+
     def __gather(self):
         if not self.is_gathered:
             # sanity check
...
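The commit does not show a caller for the new has_inf_or_nan property, but a natural consumer would be an fp16 overflow check that scans every chunk before an optimizer step. The helper below is only a sketch of that idea under assumed names (chunk_list, process_group); it is not part of the commit.

import torch
import torch.distributed as dist

def chunks_have_overflow(chunk_list, process_group=None) -> bool:
    """Hypothetical helper: True if any chunk holds inf/nan on any rank."""
    found = any(chunk.has_inf_or_nan for chunk in chunk_list)
    flag = torch.tensor([1.0 if found else 0.0], device='cuda')
    # each rank may only see its own shard, so the verdict must be agreed on globally
    dist.all_reduce(flag, op=dist.ReduceOp.MAX, group=process_group)
    return flag.item() > 0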
@@ -375,6 +403,12 @@ class AgChunk:
         if prev_state is None or tensor_info.state == prev_state:
             self.__update_one_tensor_info(tensor_info, next_state)

+    def __hash__(self) -> int:
+        return hash(id(self))
+
+    def __eq__(self, __o: object) -> bool:
+        return self is __o
+
     def __repr__(self, detailed: bool = False):
         output = ["AgChunk Information:\n",
...
@@ -413,3 +447,6 @@ class AgChunk:
             output.append("\t\t# of {}: {}\n".format(st, self.tensors_state_monitor[st]))

         return ''.join(output)
+
+    def get_tensors(self) -> List[torch.Tensor]:
+        return list(self.tensors_info.keys())
tests/test_gemini/chunk/test_agchunk.py  (view file @ 0d212183)

...
@@ -2,16 +2,24 @@ import torch
 import colossalai
 import pytest
 import torch.multiprocessing as mp
+import torch.distributed as dist
 from functools import partial
 from colossalai.testing import rerun_if_address_is_in_use, parameterize
 from colossalai.utils import free_port, get_current_device
 from colossalai.tensor import ProcessGroup as ColoProcessGroup
 from colossalai.tensor import ColoParameter
 from colossalai.gemini import TensorState
 from colossalai.gemini.ag_chunk import AgChunk


+def dist_sum(x):
+    temp = torch.tensor([x], device=get_current_device())
+    dist.all_reduce(temp)
+    return temp.item()
+
+
 def add_param(param_list, param_cp_list, *args, **kwargs):
-    param = ColoParameter(torch.empty(*args, **kwargs))
+    param = ColoParameter(torch.randn(*args, **kwargs))
     param_list.append(param)
     param_cp_list.append(param.clone())
...
@@ -27,7 +35,7 @@ def check_euqal(param, param_cp):
 @parameterize('init_device', [None, torch.device('cpu')])
 @parameterize('keep_gathered', [True, False])
 @parameterize('pin_memory', [True, False])
-def exam_chunk_init(init_device, keep_gathered, pin_memory):
+def exam_chunk_basic(init_device, keep_gathered, pin_memory):
     world_size = torch.distributed.get_world_size()
     pg = ColoProcessGroup()
     my_chunk = AgChunk(
...
@@ -56,17 +64,51 @@ def exam_chunk_init(init_device, keep_gathered, pin_memory):
     if keep_gathered is False:
         assert my_chunk.cpu_shard.size(0) == 1024 // world_size
         assert my_chunk.device_type == 'cpu'
         assert my_chunk.can_move
         my_chunk.shard_move(get_current_device())
     else:
         assert my_chunk.chunk_total.size(0) == 1024
         assert my_chunk.device_type == 'cuda'
         assert not my_chunk.can_move

+    assert dist_sum(my_chunk.valid_end) == my_chunk.utilized_size
+    flag = my_chunk.has_inf_or_nan
+    assert not flag, "has_inf_or_nan is {}".format(flag)
+
     my_chunk.access_chunk()
     assert my_chunk.device_type == 'cuda'

     for param, param_cp in zip(param_list, param_cp_list):
         check_euqal(param, param_cp)
     assert my_chunk.tensors_state_monitor[TensorState.HOLD] == 4

+    my_chunk.tensor_trans_state(param_list[0], TensorState.COMPUTE)
+    assert my_chunk.tensors_state_monitor[TensorState.HOLD] == 3
+    assert my_chunk.tensors_state_monitor[TensorState.COMPUTE] == 1
+    assert not my_chunk.can_release
+
+    for param in param_list:
+        my_chunk.tensor_trans_state(param, TensorState.COMPUTE)
+        my_chunk.tensor_trans_state(param, TensorState.READY_FOR_REDUCE)
+
+    assert my_chunk.tensors_state_monitor[TensorState.READY_FOR_REDUCE] == 4
+    assert my_chunk.can_reduce
+    my_chunk.reduce()
+    assert my_chunk.tensors_state_monitor[TensorState.HOLD] == 4
+
+    if keep_gathered is False:
+        assert my_chunk.cuda_shard.size(0) == 1024 // world_size
+        assert my_chunk.device_type == 'cuda'
+        assert my_chunk.can_move
+    else:
+        assert my_chunk.chunk_total.size(0) == 1024
+        assert my_chunk.device_type == 'cuda'
+        assert not my_chunk.can_move


 def run_dist(rank, world_size, port):
     colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-    exam_chunk_init()
+    exam_chunk_basic()


 @pytest.mark.dist
...
@@ -78,4 +120,4 @@ def test_chunk_function(world_size):

 if __name__ == '__main__':
-    test_chunk_function(2)
+    test_chunk_function(4)
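The body of test_chunk_function is collapsed in this view. Judging from the imports (mp, partial, free_port, rerun_if_address_is_in_use) and the @pytest.mark.dist marker, it presumably follows the usual spawn-based launcher pattern of these tests; the sketch below is an assumption about that unshown code, including the parametrized world sizes, not a reproduction of it.

@pytest.mark.dist
@pytest.mark.parametrize('world_size', [2, 4])
@rerun_if_address_is_in_use()
def test_chunk_function(world_size):
    # mp.spawn passes the process rank as the first positional argument of run_dist
    run_func = partial(run_dist, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)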