Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
56b8863b
"ubench/git@developer.sourcefind.cn:gaoqiong/pybind11.git" did not exist on "a886a567a3bf9b8648b5346e07094432af6b5153"
Unverified
Commit
56b8863b
authored
Aug 02, 2022
by
ver217
Committed by
GitHub
Aug 02, 2022
Browse files
[zero] chunk manager allows filtering ex-large params (#1393)
parent
adf5054f
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
17 additions
and
5 deletions
+17
-5
colossalai/gemini/chunk_mgr.py
colossalai/gemini/chunk_mgr.py
+17
-5
No files found.
colossalai/gemini/chunk_mgr.py
View file @
56b8863b
import
torch
import
numpy
as
np
from
typing
import
Optional
,
Dict
,
Deque
,
Set
,
List
,
Tuple
,
Iterable
from
collections
import
deque
...
...
@@ -61,9 +62,6 @@ class ChunkManager:
if
isinstance
(
tensor
,
ColoTensor
):
assert
tensor
.
get_process_group
().
dp_process_group
()
==
self
.
process_group
.
dp_process_group
(
),
f
"Chunk Manager can only manage ColoTensor with the same DP process group"
if
self
.
chunk_size
is
not
None
and
tensor
.
numel
()
>
self
.
chunk_size
:
raise
ValueError
(
f
'Cannot create chunk, got tensor numel (
{
tensor
.
numel
()
}
) > chunk size (
{
self
.
chunk_size
}
)'
)
try
:
# append the tensor to the last chunk
self
.
chunk_groups
[
group_name
][
-
1
].
append
(
tensor
)
...
...
@@ -71,6 +69,9 @@ class ChunkManager:
# the except statement will be triggered when there is no chunk or
# the last chunk in the chunk group is full
# this will create a new chunk and allocate this chunk to its corresponding process
if
self
.
chunk_size
is
not
None
and
tensor
.
numel
()
>
self
.
chunk_size
:
chunk_size
=
tensor
.
numel
()
else
:
chunk_size
=
self
.
chunk_size
or
tensor
.
numel
()
src_rank
=
self
.
_get_next_src_rank
(
group_name
)
chunk
=
Chunk
(
chunk_size
,
...
...
@@ -263,7 +264,8 @@ class ChunkManager:
def
search_chunk_size
(
module
:
torch
.
nn
.
Module
,
search_range
:
int
,
n_grids
:
int
,
min_chunk_size
:
Optional
[
int
]
=
None
)
->
int
:
min_chunk_size
:
Optional
[
int
]
=
None
,
filter_exlarge_params
:
bool
=
True
)
->
int
:
"""
Search for the chunk size for optimal chunk utilization.
...
...
@@ -278,6 +280,8 @@ class ChunkManager:
assert
search_range
%
n_grids
==
0
# TODO(ver217): sort params and filter unused ones
params_numel
=
[
p
.
numel
()
for
p
in
module
.
parameters
()]
if
filter_exlarge_params
:
params_numel
=
_filter_exlarge_params
(
params_numel
)
max_param_numel
=
max
(
params_numel
)
if
min_chunk_size
is
not
None
:
assert
min_chunk_size
>=
max_param_numel
...
...
@@ -330,3 +334,11 @@ class ChunkManager:
"""
assert
tensor
not
in
self
.
tensor_chunk_map
self
.
total_mem
[
tensor
.
device
.
type
]
+=
tensor
.
numel
()
*
tensor
.
element_size
()
def
_filter_exlarge_params
(
params_numel
:
List
[
int
])
->
List
[
int
]:
params_numel_arr
=
np
.
array
(
params_numel
)
std
=
np
.
std
(
params_numel_arr
)
mean
=
np
.
mean
(
params_numel_arr
)
upper_limit
=
mean
+
3
*
std
return
list
(
filter
(
lambda
x
:
x
<=
upper_limit
,
params_numel
))
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment