Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
e1922ea4
"examples/git@developer.sourcefind.cn:OpenDAS/colossalai.git" did not exist on "d43a671ad6db17b59a926b992106385356a86b7b"
Unverified
Commit
e1922ea4
authored
Jun 02, 2022
by
ver217
Committed by
GitHub
Jun 02, 2022
Browse files
[zero] add chunk size search for chunk manager (#1052)
parent
2c42b230
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
38 additions
and
0 deletions
+38
-0
colossalai/tensor/chunk.py
colossalai/tensor/chunk.py
+38
-0
No files found.
colossalai/tensor/chunk.py
View file @
e1922ea4
...
@@ -268,3 +268,41 @@ class ChunkManager:
...
@@ -268,3 +268,41 @@ class ChunkManager:
for
i
,
chunk
in
enumerate
(
group
):
for
i
,
chunk
in
enumerate
(
group
):
msg
+=
f
'[
{
i
}
]
{
chunk
}
\n
'
msg
+=
f
'[
{
i
}
]
{
chunk
}
\n
'
return
msg
return
msg
@
staticmethod
def
get_chunk_util
(
chunk_size
:
int
,
params_numel
:
List
[
int
])
->
float
:
assert
len
(
params_numel
)
>
0
total_size
=
0
total_utilized_size
=
0
cur_chunk_utilized_size
=
0
for
size
in
params_numel
:
assert
chunk_size
>=
size
total_utilized_size
+=
size
if
total_size
==
0
or
cur_chunk_utilized_size
+
size
>
chunk_size
:
total_size
+=
chunk_size
cur_chunk_utilized_size
=
0
cur_chunk_utilized_size
+=
size
return
total_utilized_size
/
total_size
@
staticmethod
def
search_chunk_size
(
module
:
torch
.
nn
.
Module
,
search_range
:
int
,
n_grids
:
int
,
min_chunk_size
:
Optional
[
int
]
=
None
)
->
int
:
assert
search_range
%
n_grids
==
0
# TODO(ver217): sort params and filter unused ones
params_numel
=
[
p
.
numel
()
for
p
in
module
.
parameters
()]
max_param_numel
=
max
(
params_numel
)
if
min_chunk_size
is
not
None
:
assert
min_chunk_size
>=
max_param_numel
else
:
min_chunk_size
=
max_param_numel
step_size
=
search_range
//
n_grids
max_chunk_util
=
-
1
best_chunk_size
=
-
1
for
chunk_size
in
range
(
min_chunk_size
,
min_chunk_size
+
search_range
+
1
,
step_size
):
chunk_util
=
ChunkManager
.
get_chunk_util
(
chunk_size
,
params_numel
)
if
chunk_util
>
max_chunk_util
:
max_chunk_util
=
chunk_util
best_chunk_size
=
chunk_size
return
best_chunk_size
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment