Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
8afc001f
Unverified
Commit
8afc001f
authored
Dec 11, 2022
by
Jiarui Fang
Committed by
GitHub
Dec 11, 2022
Browse files
[Gemini] chunk init use OrderedParamGenerator (#2110)
parent
63fbba3c
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
32 additions
and
11 deletions
+32
-11
colossalai/gemini/chunk/__init__.py
colossalai/gemini/chunk/__init__.py
+2
-0
colossalai/gemini/chunk/search_utils.py
colossalai/gemini/chunk/search_utils.py
+9
-4
colossalai/gemini/memory_tracer/__init__.py
colossalai/gemini/memory_tracer/__init__.py
+2
-2
colossalai/gemini/memory_tracer/memory_stats.py
colossalai/gemini/memory_tracer/memory_stats.py
+2
-2
colossalai/gemini/memory_tracer/param_runtime_order.py
colossalai/gemini/memory_tracer/param_runtime_order.py
+16
-2
colossalai/gemini/memory_tracer/runtime_mem_tracer.py
colossalai/gemini/memory_tracer/runtime_mem_tracer.py
+1
-1
No files found.
colossalai/gemini/chunk/__init__.py
View file @
8afc001f
...
...
@@ -2,3 +2,5 @@ from .chunk import Chunk, ChunkFullError, TensorInfo, TensorState
from
.manager
import
ChunkManager
from
.search_utils
import
classify_params_by_dp_degree
,
search_chunk_configuration
from
.utils
import
init_chunk_manager
__all__
=
[
'Chunk'
,
'ChunkManager'
,
'classify_params_by_dp_degree'
,
'search_chunk_configuration'
,
'init_chunk_manager'
]
colossalai/gemini/chunk/search_utils.py
View file @
8afc001f
...
...
@@ -4,6 +4,7 @@ from typing import Dict, List, Tuple
import
numpy
as
np
import
torch.nn
as
nn
from
colossalai.gemini.memory_tracer
import
OrderedParamGenerator
from
colossalai.tensor
import
ColoParameter
...
...
@@ -40,20 +41,20 @@ def _get_unused_byte(size_list: List[int], chunk_size: int) -> int:
return
left
+
acc
def
classify_params_by_dp_degree
(
mo
de
l
:
nn
.
Module
)
->
Dict
[
int
,
List
[
ColoParameter
]]:
def
classify_params_by_dp_degree
(
param_or
de
r
:
OrderedParamGenerator
)
->
Dict
[
int
,
List
[
ColoParameter
]]:
"""classify_params_by_dp_degree
Classify the parameters by their dp degree
Args:
mo
de
l
(
nn.Module): model
param_or
de
r
(
OrderedParamGenerator): the order in which params are visited
Returns:
Dict[int, List[ColoParameter]]: a dict contains the classification results.
The keys are dp_degrees and the values are parameters.
"""
params_dict
:
Dict
[
int
,
List
[
ColoParameter
]]
=
dict
()
for
param
in
model
.
parameters
():
for
param
in
param_order
.
generate
():
assert
isinstance
(
param
,
ColoParameter
),
"please init model in the ColoInitContext"
if
not
in_ddp
(
param
):
continue
...
...
@@ -85,11 +86,15 @@ def search_chunk_configuration(
Tuple[Dict, int]: chunk config and its memory chunk waste in byte.
"""
param_order
=
OrderedParamGenerator
()
for
p
in
model
.
parameters
():
param_order
.
append
(
p
)
search_range_byte
=
round
(
search_range_mb
*
1024
**
2
)
min_chunk_size_byte
=
round
(
min_chunk_size_mb
*
1024
**
2
)
assert
search_range_byte
>=
0
params_dict
=
classify_params_by_dp_degree
(
mo
de
l
)
params_dict
=
classify_params_by_dp_degree
(
param_or
de
r
)
config_dict
:
Dict
[
int
,
Dict
]
=
dict
()
size_dict
:
Dict
[
int
,
List
[
int
]]
=
dict
()
...
...
colossalai/gemini/memory_tracer/__init__.py
View file @
8afc001f
from
.param_runtime_order
import
ParamRuntimeOrde
r
# isort:skip
from
.param_runtime_order
import
OrderedParamGenerato
r
# isort:skip
from
.memory_stats
import
MemStats
# isort:skip
from
.memory_monitor
import
AsyncMemoryMonitor
,
SyncCudaMemoryMonitor
# isort:skip
from
.memstats_collector
import
MemStatsCollector
# isort:skip
...
...
@@ -7,5 +7,5 @@ from .static_memstats_collector import StaticMemStatsCollector # isort:skip
__all__
=
[
'AsyncMemoryMonitor'
,
'SyncCudaMemoryMonitor'
,
'MemStatsCollector'
,
'ChunkMemStatsCollector'
,
'StaticMemStatsCollector'
,
'MemStats'
,
'
ParamRuntimeOrde
r'
'StaticMemStatsCollector'
,
'MemStats'
,
'
OrderedParamGenerato
r'
]
colossalai/gemini/memory_tracer/memory_stats.py
View file @
8afc001f
from
typing
import
Any
,
Dict
,
List
from
colossalai.gemini.memory_tracer
import
ParamRuntimeOrde
r
from
colossalai.gemini.memory_tracer
import
OrderedParamGenerato
r
class
MemStats
(
object
):
...
...
@@ -21,7 +21,7 @@ class MemStats(object):
self
.
_non_model_data_cuda_list
=
[]
self
.
_non_model_data_cpu_list
=
[]
self
.
_param_runtime_order
=
ParamRuntimeOrde
r
()
self
.
_param_runtime_order
=
OrderedParamGenerato
r
()
def
append_overall_data
(
self
,
device_type
:
str
,
val
:
float
):
if
device_type
==
'cuda'
:
...
...
colossalai/gemini/memory_tracer/param_runtime_order.py
View file @
8afc001f
from
abc
import
ABC
import
torch
class
ParamRuntimeOrder
(
object
):
"""ParamRuntimeOrder
class
ParamGenerator
(
ABC
):
def
append
(
self
,
param
:
torch
.
nn
.
Parameter
):
pass
def
generate
(
self
):
pass
def
clear
(
self
):
pass
class
OrderedParamGenerator
(
ParamGenerator
):
"""OrderedParamGenerator
Contains the order of parameters visited during runtime.
"""
...
...
colossalai/gemini/memory_tracer/runtime_mem_tracer.py
View file @
8afc001f
import
torch.nn
from
colossalai.gemini.memory_tracer
import
MemStats
,
ParamRuntimeOrder
from
colossalai.gemini.memory_tracer
import
MemStats
from
colossalai.gemini.ophooks.runtime_mem_tracer_hook
import
GradMemStats
,
GradMemTracerHook
,
ParamMemTracerHook
from
colossalai.nn.parallel.data_parallel
import
_cast_float
from
colossalai.tensor.param_op_hook
import
ColoParamOpHookManager
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment