Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
d329c294
Unverified
Commit
d329c294
authored
Apr 17, 2023
by
YH
Committed by
GitHub
Apr 17, 2023
Browse files
Add docstr for zero3 chunk search utils (#3572)
parent
9edeadfb
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
31 additions
and
3 deletions
+31
-3
colossalai/zero/gemini/chunk/search_utils.py
colossalai/zero/gemini/chunk/search_utils.py
+31
-3
No files found.
colossalai/zero/gemini/chunk/search_utils.py
View file @
d329c294
...
...
@@ -11,8 +11,13 @@ from colossalai.zero.gemini.memory_tracer import MemStats, OrderedParamGenerator
def
_filter_exlarge_params
(
model
:
nn
.
Module
,
size_dict
:
Dict
[
int
,
List
[
int
]])
->
None
:
"""
"""_filter_exlarge_params
Filter those parameters whose size is too large (more than 3x standard deviations) from others.
Args:
model (nn.Module): the model.
size_dict (Dict[int, List[int]]): the size dict of parameters.
"""
agg_size_list
=
[]
for
key
in
size_dict
:
...
...
@@ -33,7 +38,16 @@ def _filter_exlarge_params(model: nn.Module, size_dict: Dict[int, List[int]]) ->
def
_get_unused_byte
(
size_list
:
List
[
int
],
chunk_size
:
int
)
->
int
:
"""Get unused byte for a certain chunk size.
"""_get_unused_byte
Get unused byte for a certain chunk size.
Args:
size_list (List[int]): the size list of parameters.
chunk_size (int): the chunk size.
Returns:
int: the unused byte.
"""
acc
=
0
left
=
0
...
...
@@ -45,7 +59,18 @@ def _get_unused_byte(size_list: List[int], chunk_size: int) -> int:
return
left
+
acc
def
_tensor_numel
(
local_param
:
ColoParameter
,
strict_ddp_flag
:
bool
):
def
_tensor_numel
(
local_param
:
ColoParameter
,
strict_ddp_flag
:
bool
)
->
int
:
"""_tensor_numel
Get the number of elements of a tensor.
Args:
local_param (ColoParameter): The local parameter.
strict_ddp_flag (bool): whether to enable the strict ddp mode.
Returns:
int: the number of elements.
"""
if
strict_ddp_flag
and
type
(
local_param
)
is
ColoParameter
:
return
local_param
.
numel_global
()
else
:
...
...
@@ -61,6 +86,7 @@ def classify_params_by_dp_degree(param_order: OrderedParamGenerator,
Args:
param_order (OrderedParamGenerator): the order of param be visied
strict_ddp_flag (bool, optional): whether to enable the strict ddp mode. Defaults to False.
Returns:
Dict[int, List[ColoParameter]]: a dict contains the classification results.
...
...
@@ -96,6 +122,8 @@ def search_chunk_configuration(
memstas
:
Optional
[
MemStats
]
=
None
)
->
Tuple
[
Dict
,
int
,
int
]:
"""search_chunk_configuration
Search the chunk configuration for a model.
Args:
model (nn.Module): torch module
search_range_mb (float): searching range in mega byte.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment