OpenDAS / ColossalAI

Unverified commit 56bb412e, authored Mar 15, 2022 by Jiarui Fang, committed by GitHub on Mar 15, 2022.

[polish] use GLOBAL_MODEL_DATA_TRACER (#417)

Parent: 23ba3fc4
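Every file in this commit makes the same mechanical substitution: ad-hoc `ModelDataTracer()` constructions are replaced by the module-level `GLOBAL_MODEL_DATA_TRACER` instance that `model_data_memtracer.py` now exports. Because `ModelDataTracer` is built on a `SingletonMeta` metaclass, both spellings already resolved to the same object; naming the global just makes the shared state explicit at every call site. A minimal sketch of how such a singleton metaclass typically works (ColossalAI's actual `SingletonMeta` may differ in detail):

    class SingletonMeta(type):
        """Cache one instance per class; later constructor calls return it."""
        _instances = {}

        def __call__(cls, *args, **kwargs):
            if cls not in cls._instances:
                cls._instances[cls] = super().__call__(*args, **kwargs)
            return cls._instances[cls]

    class Tracer(metaclass=SingletonMeta):
        pass

    assert Tracer() is Tracer()  # every construction aliases the one instance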
Showing 8 changed files with 25 additions and 25 deletions (+25 -25):
    colossalai/engine/ophooks/zero_hook.py                     +3 -5
    colossalai/utils/memory_tracer/allocator.py                +2 -2
    colossalai/utils/memory_tracer/memstats_collector.py       +2 -2
    colossalai/utils/memory_tracer/model_data_memtracer.py     +6 -4
    colossalai/utils/memory_tracer/test_memstats_collector.py  +3 -3
    colossalai/zero/init_ctx/init_context.py                   +3 -3
    tests/test_amp/test_naive_fp16.py                          +3 -3
    tests/test_zero_data_parallel/test_init_context.py         +3 -3
colossalai/engine/ophooks/zero_hook.py

@@ -5,7 +5,7 @@ from colossalai.zero.shard_utils import BaseShardStrategy
 from ._base_ophook import BaseOpHook
 from colossalai.utils.memory_tracer.memstats_collector import MemStatsCollector
-from colossalai.utils.memory_tracer.model_data_memtracer import ModelDataTracer
+from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
 from typing import Optional
@@ -25,7 +25,6 @@ class ZeroHook(BaseOpHook):
     def pre_fwd_exec(self, module: torch.nn.Module, *args):
         tensor_list = []
-        global_model_data_tracer = ModelDataTracer()
         for param in module.parameters():
             assert hasattr(param, 'col_attr')
             tensor_list.append(param.col_attr.data)
@@ -33,7 +32,7 @@ class ZeroHook(BaseOpHook):
         for param in module.parameters():
             if param.col_attr.data.device != self.computing_device:
                 param.col_attr.data.to(self.computing_device)
-                global_model_data_tracer.add_tensor(param.col_attr.data.payload)
+                GLOBAL_MODEL_DATA_TRACER.add_tensor(param.col_attr.data.payload)
             param.data = param.col_attr.data.payload
         if self._memstarts_collector:
@@ -50,7 +49,6 @@ class ZeroHook(BaseOpHook):
     def pre_bwd_exec(self, module: torch.nn.Module, input, output):
         tensor_list = []
-        global_model_data_tracer = ModelDataTracer()
         for param in module.parameters():
             assert hasattr(param, 'col_attr')
             tensor_list.append(param.col_attr.data)
@@ -58,7 +56,7 @@ class ZeroHook(BaseOpHook):
         for param in module.parameters():
             if param.col_attr.data.device != self.computing_device:
                 param.col_attr.data.to(self.computing_device)
-                global_model_data_tracer.add_tensor(param.col_attr.data.payload)
+                GLOBAL_MODEL_DATA_TRACER.add_tensor(param.col_attr.data.payload)
             param.data = param.col_attr.data.payload
         # Store local accumulated grad shard
         if param.grad is not None:
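Because `ModelDataTracer` is a singleton, the deleted local `global_model_data_tracer = ModelDataTracer()` in both hooks was only an alias; the rewritten calls name the shared instance directly and behavior is unchanged. A quick check, hedged on the singleton sketch above:

    from colossalai.utils.memory_tracer.model_data_memtracer import (
        GLOBAL_MODEL_DATA_TRACER, ModelDataTracer)

    # Post-commit, both expressions refer to the one shared tracer.
    assert ModelDataTracer() is GLOBAL_MODEL_DATA_TRACER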
colossalai/utils/memory_tracer/allocator.py

 import torch
-from colossalai.utils.memory_tracer.model_data_memtracer import ModelDataTracer
+from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER

 def col_move_to_cpu(t: torch.Tensor):
@@ -7,7 +7,7 @@ def col_move_to_cpu(t: torch.Tensor):
     if t.device.type == 'cpu':
         return
-    ModelDataTracer().delete_tensor(t)
+    GLOBAL_MODEL_DATA_TRACER.delete_tensor(t)
     t.data = t.data.cpu()
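`col_move_to_cpu` untracks a tensor before moving its storage off the GPU, so the tracer's CUDA count drops as soon as the memory leaves the device. A hedged usage sketch (assuming a CUDA build and only the tracer API visible in this diff):

    import torch
    from colossalai.utils.memory_tracer.allocator import col_move_to_cpu
    from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER

    t = torch.randn(1024, device='cuda')
    GLOBAL_MODEL_DATA_TRACER.add_tensor(t)  # account for it as model data
    col_move_to_cpu(t)                      # delete_tensor(t), then t.data = t.data.cpu()
    assert t.device.type == 'cpu'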
colossalai/utils/memory_tracer/memstats_collector.py

-from colossalai.utils.memory_tracer.model_data_memtracer import ModelDataTracer
+from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
 from .async_memtracer import get_cuda_memory_used
 from colossalai.utils import get_current_device
@@ -54,7 +54,7 @@ class MemStatsCollector:
         if self._start_flag:
             sampling_cnt = self._sampling_cnter.sampling_cnt
             assert sampling_cnt == len(self._overall_cuda)
-            self._model_data_cuda.append(ModelDataTracer().cuda_usage)
+            self._model_data_cuda.append(GLOBAL_MODEL_DATA_TRACER.cuda_usage)
             self._overall_cuda.append(get_cuda_memory_used(torch.device(f'cuda:{get_current_device()}')))
         self._sampling_cnter.advance()
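Each sampling step appends a pair of same-instant readings: the tracer's accounted model data and the device's overall CUDA usage. Their difference estimates non-model memory (activations, buffers, allocator overhead). A hypothetical helper illustrating that relationship (not part of the codebase):

    def non_model_data_usage(overall_cuda, model_data_cuda):
        """Per-sample CUDA memory not attributed to model data.

        Assumes overall_cuda[i] and model_data_cuda[i] are the paired
        readings appended by the i-th sample_memstats() call."""
        return [overall - model for overall, model in zip(overall_cuda, model_data_cuda)]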
colossalai/utils/memory_tracer/model_data_memtracer.py

@@ -5,10 +5,9 @@ import torch
 class ModelDataTracer(metaclass=SingletonMeta):
     """
-    A singleton to trace model data usage during runtime.
-    We have to trigger our API (trace_tensor, detach_tensor) when do model-data memory operation,
-    including allocation, releasing and moving.
+    A tracer singleton to trace model data usage during runtime.
+    The tracer is designed to trace the memory layout change during model-data tensors allocation, releasing, and moving.
+    To achieve this goal, the developers have to call `ModelDataTracer` in the corresponding code explicitly.
     NOTE() now the class only trace cuda memory usage
     """
@@ -32,3 +31,6 @@ class ModelDataTracer(metaclass=SingletonMeta):
     @property
     def cuda_usage(self):
         return self._cuda_usage
+
+
+GLOBAL_MODEL_DATA_TRACER = ModelDataTracer()
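Defining `GLOBAL_MODEL_DATA_TRACER = ModelDataTracer()` at the bottom of the module gives every importer one obvious handle on the shared accounting state. The diff elides `add_tensor` and `delete_tensor`; a rough sketch of plausible bookkeeping, assuming the tracer counts payload bytes via `numel() * element_size()` (an assumption, not the verified implementation):

    import torch

    class ModelDataTracer(metaclass=SingletonMeta):  # SingletonMeta as sketched above
        """Traces CUDA memory held by model-data tensors (sketch)."""

        def __init__(self) -> None:
            self._cuda_usage = 0

        def add_tensor(self, t: torch.Tensor) -> None:
            self._cuda_usage += t.numel() * t.element_size()  # assumed accounting

        def delete_tensor(self, t: torch.Tensor) -> None:
            self._cuda_usage -= t.numel() * t.element_size()

        @property
        def cuda_usage(self):
            return self._cuda_usage

    GLOBAL_MODEL_DATA_TRACER = ModelDataTracer()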
colossalai/utils/memory_tracer/test_memstats_collector.py

 from colossalai.utils.memory_tracer.memstats_collector import MemStatsCollector
-from colossalai.utils.memory_tracer.model_data_memtracer import ModelDataTracer
+from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
 import torch
@@ -14,7 +14,7 @@ def test_mem_collector():
     collector.sample_memstats()
     m_a = torch.randn(10).cuda()
-    ModelDataTracer().add_tensor(m_a)
+    GLOBAL_MODEL_DATA_TRACER.add_tensor(m_a)
     b = torch.randn(10).cuda()
     # sampling at time 1
@@ -26,7 +26,7 @@ def test_mem_collector():
     collector.sample_memstats()
     collector.finish_collection()
-    collector.reset()
+    collector.reset_sampling_cnter()
     # do nothing after collection, just advance sampling cnter
     collector.sample_memstats()
colossalai/zero/init_ctx/init_context.py

@@ -3,7 +3,7 @@ import functools
 import torch
 from colossalai.zero.shard_utils import BaseShardStrategy
 from colossalai.zero.sharded_param import ShardedParamV2
-from colossalai.utils.memory_tracer.model_data_memtracer import ModelDataTracer
+from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER

 # Inserts _post_init_method at the end of init method
@@ -153,7 +153,7 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
         if self.shard_param:
             self.shard_strategy.shard(tensor_list=[param.col_attr._data_sharded_tensor])
-            ModelDataTracer().add_tensor(param.col_attr._data_sharded_tensor.payload)
+            GLOBAL_MODEL_DATA_TRACER.add_tensor(param.col_attr._data_sharded_tensor.payload)
         if param.col_attr.grad and self.shard_grad:
             self.shard_strategy.shard(tensor_list=[param.col_attr._grad_sharded_tensor])
-            ModelDataTracer().add_tensor(param.col_attr._grad_sharded_tensor.payload)
+            GLOBAL_MODEL_DATA_TRACER.add_tensor(param.col_attr._grad_sharded_tensor.payload)
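`ZeroInitContext._post_init_method` runs after each module's `__init__`, shards the freshly created parameters, and registers their payloads with the global tracer, so model-data usage is counted from the moment a sharded model is built. A hypothetical construction sketch; the constructor arguments shown are illustrative assumptions, not the verified signature at this commit:

    import torch
    from colossalai.zero.init_ctx import ZeroInitContext
    from colossalai.zero.shard_utils import TensorShardStrategy

    # Hypothetical: parameters created inside the context are sharded and
    # their payloads reported to GLOBAL_MODEL_DATA_TRACER as they appear.
    with ZeroInitContext(target_device=torch.device('cuda'),
                         shard_strategy=TensorShardStrategy(),
                         shard_param=True):
        model = torch.nn.Linear(1024, 1024)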
tests/test_amp/test_naive_fp16.py

@@ -26,15 +26,15 @@ def run_naive_amp():
     test_models = ['repeated_computed_layers', 'nested_model']
     for test_name in test_models:
         get_component_func = non_distributed_component_funcs.get_callable(test_name)
-        model_builder, train_dataloader, _, optim_builder, _ = get_component_func()
+        model_builder, train_dataloader, _, optim_class, _ = get_component_func()

         # create model
         amp_model = model_builder(checkpoint=True).cuda()
         torch_model = copy.deepcopy(amp_model)

         # create optimizer
-        amp_optimizer = optim_builder(amp_model)
-        torch_optimizer = optim_builder(torch_model)
+        amp_optimizer = optim_class(amp_model.parameters(), lr=1e-3)
+        torch_optimizer = optim_class(torch_model.parameters(), lr=1e-3)

         # inject naive amp
         amp_config = dict(initial_scale=1)
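Beyond the tracer rename, this file switches the fourth component of the registry tuple from a pre-bound optimizer builder to a plain optimizer class, so the test now constructs optimizers the standard PyTorch way. A hedged sketch of the new convention, with `torch.optim.Adam` as an illustrative stand-in for whatever class the registry returns:

    import torch

    model = torch.nn.Linear(8, 8)
    optim_class = torch.optim.Adam  # illustrative stand-in for the registry's class
    optimizer = optim_class(model.parameters(), lr=1e-3)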
tests/test_zero_data_parallel/test_init_context.py

@@ -14,7 +14,7 @@ from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardS
 from tests.components_to_test.registry import non_distributed_component_funcs
 from common import CONFIG
-from colossalai.utils.memory_tracer.model_data_memtracer import ModelDataTracer
+from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER

 def run_dist(rank, world_size, port, init_device, shard_strategy):
@@ -37,10 +37,10 @@ def run_dist(rank, world_size, port, init_device, shard_strategy):
         assert param.col_attr.data.payload.device.type == init_device.type, \
             f'{param.col_attr.data.payload.device.type} vs. {init_device.type}'

-    print(f'cuda usgae {ModelDataTracer().cuda_usage}')
+    print(f'cuda usgae {GLOBAL_MODEL_DATA_TRACER.cuda_usage}')
     print(f'numel {model_numel_tensor}')
     if init_device.type == 'cuda':
-        assert (ModelDataTracer().cuda_usage > 0)
+        assert (GLOBAL_MODEL_DATA_TRACER.cuda_usage > 0)

 @pytest.mark.dist