Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
0aab5230
Unverified
Commit
0aab5230
authored
Apr 03, 2022
by
Jiarui Fang
Committed by
GitHub
Apr 03, 2022
Browse files
[hotfix] fix a bug in model data stats tracing (#655)
parent
ade05a5d
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
15 additions
and
12 deletions
+15
-12
colossalai/utils/memory_tracer/memstats_collector.py
colossalai/utils/memory_tracer/memstats_collector.py
+9
-3
colossalai/utils/memory_tracer/model_data_memtracer.py
colossalai/utils/memory_tracer/model_data_memtracer.py
+1
-1
colossalai/zero/sharded_optim/sharded_optim_v2.py
colossalai/zero/sharded_optim/sharded_optim_v2.py
+4
-8
tests/test_moe/test_moe_zero_init.py
tests/test_moe/test_moe_zero_init.py
+1
-0
No files found.
colossalai/utils/memory_tracer/memstats_collector.py
View file @
0aab5230
from
colossalai.utils.memory_tracer.model_data_memtracer
import
GLOBAL_MODEL_DATA_TRACER
from
colossalai.utils.memory_utils.utils
import
colo_device_memory_used
from
colossalai.utils
import
get_current_device
from
colossalai.utils.memory_tracer.async_memtracer
import
AsyncMemoryMonitor
import
torch
import
time
from
typing
import
List
...
...
@@ -37,6 +37,7 @@ class MemStatsCollector:
def
__init__
(
self
)
->
None
:
self
.
_sampling_cnter
=
SamplingCounter
()
self
.
_mem_monitor
=
AsyncMemoryMonitor
()
self
.
_model_data_cuda_list
=
[]
self
.
_overall_cuda_list
=
[]
...
...
@@ -101,6 +102,7 @@ class MemStatsCollector:
def
start_collection
(
self
):
self
.
_start_flag
=
True
self
.
_mem_monitor
.
start
()
def
finish_collection
(
self
):
self
.
_start_flag
=
False
...
...
@@ -115,17 +117,20 @@ class MemStatsCollector:
sampling_cnt
=
self
.
_sampling_cnter
.
sampling_cnt
assert
sampling_cnt
==
len
(
self
.
_overall_cuda_list
)
self
.
_model_data_cuda_list
.
append
(
GLOBAL_MODEL_DATA_TRACER
.
cuda_usage
)
self
.
_overall_cuda_list
.
append
(
colo_device_memory_used
(
get_current_device
()
))
self
.
_overall_cuda_list
.
append
(
self
.
_mem_monitor
.
finish
(
))
self
.
_model_data_cpu_list
.
append
(
GLOBAL_MODEL_DATA_TRACER
.
cpu_usage
)
# FIXME() the overall CPU memory usage should also be returned by self._mem_monitor
self
.
_overall_cpu_list
.
append
(
colo_device_memory_used
(
torch
.
device
(
f
'cpu'
)))
self
.
_sampling_time
.
append
(
time
.
time
())
self
.
_mem_monitor
.
start
()
self
.
_sampling_cnter
.
advance
()
def
reset_sampling_cnter
(
self
)
->
None
:
self
.
_sampling_cnter
.
reset
()
self
.
_mem_monitor
.
finish
()
def
clear
(
self
)
->
None
:
self
.
_model_data_cuda_list
=
[]
...
...
@@ -136,3 +141,4 @@ class MemStatsCollector:
self
.
_start_flag
=
False
self
.
_sampling_cnter
.
reset
()
self
.
_mem_monitor
.
finish
()
\ No newline at end of file
colossalai/utils/memory_tracer/model_data_memtracer.py
View file @
0aab5230
...
...
@@ -33,7 +33,7 @@ def colo_model_mem_usage(model: torch.nn.Module) -> Tuple[int, int]:
def
_get_tensor_mem_use
(
t
:
Optional
[
torch
.
Tensor
]):
if
t
is
None
:
return
return
0
,
0
assert
isinstance
(
t
,
torch
.
Tensor
)
_cpu_mem_usage
,
_cuda_mem_usage
=
0
,
0
if
t
.
device
.
type
==
'cpu'
:
...
...
colossalai/zero/sharded_optim/sharded_optim_v2.py
View file @
0aab5230
...
...
@@ -139,10 +139,6 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
if
self
.
_use_memory_tracer
:
GLOBAL_MODEL_DATA_TRACER
.
register_optimizer
(
self
)
self
.
_use_memory_tracer
=
self
.
model
.
use_memory_tracer
if
self
.
_use_memory_tracer
:
GLOBAL_MODEL_DATA_TRACER
.
register_optimizer
(
self
)
def
get_memory_usage
(
self
)
->
Tuple
[
int
,
int
]:
""" Get the memory usage of the optimizer. Including master_params (param fp32),
momentum (``self.state[p]['exp_avg']``) variance (``self.state[p]['exp_avg_sq']``)
...
...
@@ -186,7 +182,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
self
.
_zero_grad
(
recover_data
=
True
)
return
self
.
_p
repare_data
()
self
.
_p
oint_param_fp16_to_master_param
()
self
.
_logger
.
debug
(
f
"Before step ShardedOptimizerV2 consumes
{
self
.
get_memory_usage
()[
0
]
/
1e6
}
MB CUDA Memory,
{
self
.
get_memory_usage
()[
1
]
/
1e6
}
MB CUDA Memory!"
,
...
...
@@ -197,7 +193,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
self
.
_logger
.
debug
(
f
"After step ShardedOptimizerV2 consumes
{
self
.
get_memory_usage
()[
0
]
/
1e6
}
MB CUDA Memory,
{
self
.
get_memory_usage
()[
1
]
/
1e6
}
MB CUDA Memory!"
,
ranks
=
[
0
])
self
.
_
write_back_data
()
self
.
_
copy_master_param_to_param_fp16
()
return
ret
def
backward
(
self
,
loss
:
Tensor
)
->
None
:
...
...
@@ -319,7 +315,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
# Set p.data to an empty tensor to avoid leaking memory
p
.
colo_attr
.
remove_torch_payload
()
def
_p
repare_data
(
self
):
def
_p
oint_param_fp16_to_master_param
(
self
):
# assign master param pointers to p.data.
# We will not trigger data copy here.
for
group
in
self
.
optim
.
param_groups
:
...
...
@@ -329,7 +325,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
# Now p.data is sharded
# So optimizer states are sharded naturally
def
_
write_back_data
(
self
):
def
_
copy_master_param_to_param_fp16
(
self
):
# Copy master param data (fp32) to payload of colo_attr (fp16)
# TODO() improve efficiency by gathering tensors into a chunk and transferring
# a chunk.
...
...
tests/test_moe/test_moe_zero_init.py
View file @
0aab5230
...
...
@@ -91,6 +91,7 @@ def _run_dist(rank, world_size, port):
@
pytest
.
mark
.
dist
@
pytest
.
mark
.
parametrize
(
"world_size"
,
[
2
,
4
])
@
pytest
.
mark
.
skip
(
"Under development"
)
@
rerun_on_exception
(
exception_type
=
mp
.
ProcessRaisedException
,
pattern
=
".*Address already in use.*"
)
def
test_moe_zero_init
(
world_size
):
run_func
=
partial
(
_run_dist
,
world_size
=
world_size
,
port
=
free_port
())
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment