Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
21dc54e0
Unverified
Commit
21dc54e0
authored
Mar 14, 2022
by
Jiarui Fang
Committed by
GitHub
Mar 14, 2022
Browse files
[zero] memtracer to record cuda memory usage of model data and overall system (#395)
parent
a37bf1bc
Changes
14
Hide whitespace changes
Inline
Side-by-side
Showing
14 changed files
with
239 additions
and
76 deletions
+239
-76
colossalai/engine/ophooks/zero_hook.py
colossalai/engine/ophooks/zero_hook.py
+15
-1
colossalai/utils/memory_tracer/allocator.py
colossalai/utils/memory_tracer/allocator.py
+9
-50
colossalai/utils/memory_tracer/async_memtracer.py
colossalai/utils/memory_tracer/async_memtracer.py
+2
-2
colossalai/utils/memory_tracer/commons.py
colossalai/utils/memory_tracer/commons.py
+11
-0
colossalai/utils/memory_tracer/memstats_collector.py
colossalai/utils/memory_tracer/memstats_collector.py
+81
-0
colossalai/utils/memory_tracer/model_data_memtracer.py
colossalai/utils/memory_tracer/model_data_memtracer.py
+34
-0
colossalai/utils/memory_tracer/test_memstats_collector.py
colossalai/utils/memory_tracer/test_memstats_collector.py
+43
-0
colossalai/zero/init_ctx/init_context.py
colossalai/zero/init_ctx/init_context.py
+5
-4
colossalai/zero/sharded_model/sharded_model_v2.py
colossalai/zero/sharded_model/sharded_model_v2.py
+25
-4
colossalai/zero/sharded_optim/sharded_optim_v2.py
colossalai/zero/sharded_optim/sharded_optim_v2.py
+1
-1
tests/test_utils/test_activation_checkpointing.py
tests/test_utils/test_activation_checkpointing.py
+1
-0
tests/test_zero_data_parallel/test_init_context.py
tests/test_zero_data_parallel/test_init_context.py
+5
-8
tests/test_zero_data_parallel/test_shard_model_v2.py
tests/test_zero_data_parallel/test_shard_model_v2.py
+6
-2
tests/test_zero_data_parallel/test_sharded_optim_v2.py
tests/test_zero_data_parallel/test_sharded_optim_v2.py
+1
-4
No files found.
colossalai/engine/ophooks/zero_hook.py
View file @
21dc54e0
...
@@ -4,6 +4,9 @@ from colossalai.utils import get_current_device
...
@@ -4,6 +4,9 @@ from colossalai.utils import get_current_device
from
colossalai.zero.shard_utils
import
BaseShardStrategy
from
colossalai.zero.shard_utils
import
BaseShardStrategy
from
._base_ophook
import
BaseOpHook
from
._base_ophook
import
BaseOpHook
from
colossalai.utils.memory_tracer.memstats_collector
import
MemStatsCollector
from
colossalai.utils.memory_tracer.model_data_memtracer
import
ModelDataTracer
from
typing
import
Optional
@
OPHOOKS
.
register_module
@
OPHOOKS
.
register_module
...
@@ -12,14 +15,17 @@ class ZeroHook(BaseOpHook):
...
@@ -12,14 +15,17 @@ class ZeroHook(BaseOpHook):
A hook to process sharded param for ZeRO method.
A hook to process sharded param for ZeRO method.
"""
"""
def
__init__
(
self
,
shard_strategy
:
BaseShardStrategy
):
def
__init__
(
self
,
shard_strategy
:
BaseShardStrategy
,
memstarts_collector
:
Optional
[
MemStatsCollector
]
):
super
().
__init__
()
super
().
__init__
()
self
.
shard_strategy
=
shard_strategy
self
.
shard_strategy
=
shard_strategy
# NOTE(jiaruifang) Now the computing device of FWD and BWD is always on GPU
# NOTE(jiaruifang) Now the computing device of FWD and BWD is always on GPU
self
.
computing_device
=
torch
.
device
(
f
'cuda:
{
get_current_device
()
}
'
)
self
.
computing_device
=
torch
.
device
(
f
'cuda:
{
get_current_device
()
}
'
)
self
.
_memstarts_collector
=
memstarts_collector
def
pre_fwd_exec
(
self
,
module
:
torch
.
nn
.
Module
,
*
args
):
def
pre_fwd_exec
(
self
,
module
:
torch
.
nn
.
Module
,
*
args
):
tensor_list
=
[]
tensor_list
=
[]
global_model_data_tracer
=
ModelDataTracer
()
for
param
in
module
.
parameters
():
for
param
in
module
.
parameters
():
assert
hasattr
(
param
,
'col_attr'
)
assert
hasattr
(
param
,
'col_attr'
)
tensor_list
.
append
(
param
.
col_attr
.
data
)
tensor_list
.
append
(
param
.
col_attr
.
data
)
...
@@ -27,8 +33,12 @@ class ZeroHook(BaseOpHook):
...
@@ -27,8 +33,12 @@ class ZeroHook(BaseOpHook):
for
param
in
module
.
parameters
():
for
param
in
module
.
parameters
():
if
param
.
col_attr
.
data
.
device
!=
self
.
computing_device
:
if
param
.
col_attr
.
data
.
device
!=
self
.
computing_device
:
param
.
col_attr
.
data
.
to
(
self
.
computing_device
)
param
.
col_attr
.
data
.
to
(
self
.
computing_device
)
global_model_data_tracer
.
add_tensor
(
param
.
col_attr
.
data
.
payload
)
param
.
data
=
param
.
col_attr
.
data
.
payload
param
.
data
=
param
.
col_attr
.
data
.
payload
if
self
.
_memstarts_collector
:
self
.
_memstarts_collector
.
sample_memstats
()
def
post_fwd_exec
(
self
,
module
:
torch
.
nn
.
Module
,
*
args
):
def
post_fwd_exec
(
self
,
module
:
torch
.
nn
.
Module
,
*
args
):
tensor_list
=
[]
tensor_list
=
[]
for
param
in
module
.
parameters
():
for
param
in
module
.
parameters
():
...
@@ -40,6 +50,7 @@ class ZeroHook(BaseOpHook):
...
@@ -40,6 +50,7 @@ class ZeroHook(BaseOpHook):
def
pre_bwd_exec
(
self
,
module
:
torch
.
nn
.
Module
,
input
,
output
):
def
pre_bwd_exec
(
self
,
module
:
torch
.
nn
.
Module
,
input
,
output
):
tensor_list
=
[]
tensor_list
=
[]
global_model_data_tracer
=
ModelDataTracer
()
for
param
in
module
.
parameters
():
for
param
in
module
.
parameters
():
assert
hasattr
(
param
,
'col_attr'
)
assert
hasattr
(
param
,
'col_attr'
)
tensor_list
.
append
(
param
.
col_attr
.
data
)
tensor_list
.
append
(
param
.
col_attr
.
data
)
...
@@ -47,6 +58,7 @@ class ZeroHook(BaseOpHook):
...
@@ -47,6 +58,7 @@ class ZeroHook(BaseOpHook):
for
param
in
module
.
parameters
():
for
param
in
module
.
parameters
():
if
param
.
col_attr
.
data
.
device
!=
self
.
computing_device
:
if
param
.
col_attr
.
data
.
device
!=
self
.
computing_device
:
param
.
col_attr
.
data
.
to
(
self
.
computing_device
)
param
.
col_attr
.
data
.
to
(
self
.
computing_device
)
global_model_data_tracer
.
add_tensor
(
param
.
col_attr
.
data
.
payload
)
param
.
data
=
param
.
col_attr
.
data
.
payload
param
.
data
=
param
.
col_attr
.
data
.
payload
# Store local accumulated grad shard
# Store local accumulated grad shard
if
param
.
grad
is
not
None
:
if
param
.
grad
is
not
None
:
...
@@ -60,6 +72,8 @@ class ZeroHook(BaseOpHook):
...
@@ -60,6 +72,8 @@ class ZeroHook(BaseOpHook):
# The grad here must be locally computed full grad in this backward pass
# The grad here must be locally computed full grad in this backward pass
assert
param
.
grad
.
shape
==
param
.
col_attr
.
data
.
origin_shape
assert
param
.
grad
.
shape
==
param
.
col_attr
.
data
.
origin_shape
param
.
col_attr
.
bwd_count
+=
1
param
.
col_attr
.
bwd_count
+=
1
if
self
.
_memstarts_collector
:
self
.
_memstarts_collector
.
sample_memstats
()
def
post_bwd_exec
(
self
,
module
:
torch
.
nn
.
Module
,
input
):
def
post_bwd_exec
(
self
,
module
:
torch
.
nn
.
Module
,
input
):
tensor_list
=
[]
tensor_list
=
[]
...
...
colossalai/utils/memory_tracer/allocator.py
View file @
21dc54e0
import
torch
import
torch
from
colossalai.utils.commons.singleton_meta
import
SingletonMeta
from
colossalai.utils.memory_tracer.model_data_memtracer
import
ModelDataTracer
from
colossalai.zero.sharded_param
import
ShardedTensor
from
typing
import
Union
def
col_move_to_cpu
(
t
:
torch
.
Tensor
):
assert
isinstance
(
t
,
torch
.
Tensor
)
if
t
.
device
.
type
==
'cpu'
:
return
def
col_tensor_mem_usage
(
t
:
Union
[
torch
.
Tensor
,
ShardedTensor
])
->
int
:
ModelDataTracer
().
delete_tensor
(
t
)
if
isinstance
(
t
,
ShardedTensor
):
t
.
data
=
t
.
data
.
cpu
()
target
=
t
.
payload
else
:
target
=
t
return
target
.
numel
()
*
target
.
element_size
()
class
ModelDataTracer
(
metaclass
=
SingletonMeta
):
def
col_modeldata_allocate
(
device
:
torch
.
device
)
->
torch
.
Tensor
:
"""
A singleton to trace model data usage during runtime.
"""
def
__init__
(
self
)
->
None
:
self
.
_cpu_usage
=
0
self
.
_cuda_usage
=
0
def
trace_tensor
(
self
,
t
:
torch
.
Tensor
):
mem_use
=
col_tensor_mem_usage
(
t
)
if
t
.
device
.
type
==
'cpu'
:
self
.
_cpu_usage
+=
mem_use
elif
t
.
device
.
type
==
'cuda'
:
self
.
_cuda_usage
+=
mem_use
else
:
raise
RuntimeError
def
detach_tensor
(
self
,
t
:
torch
.
Tensor
):
mem_use
=
col_tensor_mem_usage
(
t
)
if
t
.
device
.
type
==
'cpu'
:
self
.
_cpu_usage
-=
mem_use
elif
t
.
device
.
type
==
'cuda'
:
self
.
_cuda_usage
-=
mem_use
else
:
raise
RuntimeError
@
property
def
cpu_usage
(
self
):
return
self
.
_cpu_usage
@
property
def
cuda_usage
(
self
):
return
self
.
_cuda_usage
GLOBAL_MODEL_DATA_TRACER
=
ModelDataTracer
()
def
col_allocate_payload
(
device
:
torch
.
device
)
->
torch
.
Tensor
:
pass
pass
def
col_
release_payload
(
t
:
torch
.
Tensor
):
def
col_
modeldata_release
(
t
:
torch
.
Tensor
):
pass
pass
colossalai/utils/memory_tracer/async_memtracer.py
View file @
21dc54e0
...
@@ -6,7 +6,7 @@ from colossalai.utils import get_current_device
...
@@ -6,7 +6,7 @@ from colossalai.utils import get_current_device
import
torch
import
torch
def
_
get_cuda_memory_used
(
device
:
torch
.
device
)
->
int
:
def
get_cuda_memory_used
(
device
:
torch
.
device
)
->
int
:
"""
"""
Get the free memory info of device.
Get the free memory info of device.
:param device: device id
:param device: device id
...
@@ -87,7 +87,7 @@ class AsyncMemoryMonitor:
...
@@ -87,7 +87,7 @@ class AsyncMemoryMonitor:
while
self
.
keep_measuring
:
while
self
.
keep_measuring
:
max_usage
=
max
(
max_usage
=
max
(
max_usage
,
max_usage
,
_
get_cuda_memory_used
(
torch
.
device
(
f
'cuda:
{
get_current_device
()
}
'
)),
get_cuda_memory_used
(
torch
.
device
(
f
'cuda:
{
get_current_device
()
}
'
)),
)
)
sleep
(
self
.
interval
)
sleep
(
self
.
interval
)
return
max_usage
return
max_usage
...
...
colossalai/utils/memory_tracer/commons.py
0 → 100644
View file @
21dc54e0
from
colossalai.zero.sharded_param
import
ShardedTensor
from
typing
import
Union
import
torch
def
col_tensor_mem_usage
(
t
:
Union
[
torch
.
Tensor
,
ShardedTensor
])
->
int
:
if
isinstance
(
t
,
ShardedTensor
):
target
=
t
.
payload
else
:
target
=
t
return
target
.
numel
()
*
target
.
element_size
()
colossalai/utils/memory_tracer/memstats_collector.py
0 → 100644
View file @
21dc54e0
from
colossalai.utils.memory_tracer.model_data_memtracer
import
ModelDataTracer
from
.async_memtracer
import
get_cuda_memory_used
from
colossalai.utils
import
get_current_device
import
torch
class
SamplingCounter
:
def
__init__
(
self
)
->
None
:
self
.
_samplint_cnt
=
0
def
advance
(
self
):
self
.
_samplint_cnt
+=
1
@
property
def
sampling_cnt
(
self
):
return
self
.
_samplint_cnt
def
reset
(
self
):
self
.
_samplint_cnt
=
0
class
MemStatsCollector
:
def
__init__
(
self
)
->
None
:
"""
Collecting Memory Statistics.
It has two phases.
1. Collection Phase: collect memory usage statistics
2. Runtime Phase: do not collect statistics.
"""
self
.
_sampling_cnter
=
SamplingCounter
()
self
.
_model_data_cuda
=
[]
self
.
_overall_cuda
=
[]
# TODO(jiaruifang) Now no cpu mem stats collecting
self
.
_model_data_cpu
=
[]
self
.
_overall_cpu
=
[]
self
.
_start_flag
=
False
def
start_collection
(
self
):
self
.
_start_flag
=
True
def
finish_collection
(
self
):
self
.
_start_flag
=
False
def
sample_memstats
(
self
)
->
None
:
"""
Sampling memory statistics.
Record the current model data CUDA memory usage as well as system CUDA memory usage.
"""
if
self
.
_start_flag
:
sampling_cnt
=
self
.
_sampling_cnter
.
sampling_cnt
assert
sampling_cnt
==
len
(
self
.
_overall_cuda
)
self
.
_model_data_cuda
.
append
(
ModelDataTracer
().
cuda_usage
)
self
.
_overall_cuda
.
append
(
get_cuda_memory_used
(
torch
.
device
(
f
'cuda:
{
get_current_device
()
}
'
)))
self
.
_sampling_cnter
.
advance
()
def
fetch_memstats
(
self
)
->
(
int
,
int
):
"""
returns cuda usage of model data and overall cuda usage.
"""
sampling_cnt
=
self
.
_sampling_cnter
.
sampling_cnt
if
len
(
self
.
_model_data_cuda
)
<
sampling_cnt
:
raise
RuntimeError
return
(
self
.
_model_data_cuda
[
sampling_cnt
],
self
.
_overall_cuda
[
sampling_cnt
])
def
reset_sampling_cnter
(
self
)
->
None
:
self
.
_sampling_cnter
.
reset
()
def
clear
(
self
)
->
None
:
self
.
_model_data_cuda
=
[]
self
.
_overall_cuda
=
[]
self
.
_model_data_cpu
=
[]
self
.
_overall_cpu
=
[]
self
.
_start_flag
=
False
self
.
_sampling_cnter
.
reset
()
colossalai/utils/memory_tracer/model_data_memtracer.py
0 → 100644
View file @
21dc54e0
from
colossalai.utils.commons.singleton_meta
import
SingletonMeta
from
colossalai.utils.memory_tracer.commons
import
col_tensor_mem_usage
import
torch
class
ModelDataTracer
(
metaclass
=
SingletonMeta
):
"""
A singleton to trace model data usage during runtime.
We have to trigger our API (trace_tensor, detach_tensor) when do model-data memory operation,
including allocation, releasing and moving.
NOTE() now the class only trace cuda memory usage
"""
def
__init__
(
self
)
->
None
:
self
.
_cuda_usage
=
0
def
add_tensor
(
self
,
t
:
torch
.
Tensor
):
assert
isinstance
(
t
,
torch
.
Tensor
),
f
"ModelDataTracer add_tensor() should accept a torch.Tensor"
mem_use
=
col_tensor_mem_usage
(
t
)
self
.
_cuda_usage
+=
mem_use
def
delete_tensor
(
self
,
t
:
torch
.
Tensor
):
assert
isinstance
(
t
,
torch
.
Tensor
),
f
"ModelDataTracer delete_tensor() should accept a torch.Tensor"
mem_use
=
col_tensor_mem_usage
(
t
)
self
.
_cuda_usage
-=
mem_use
@
property
def
cpu_usage
(
self
):
return
self
.
_cpu_usage
@
property
def
cuda_usage
(
self
):
return
self
.
_cuda_usage
colossalai/utils/memory_tracer/test_memstats_collector.py
0 → 100644
View file @
21dc54e0
from
colossalai.utils.memory_tracer.memstats_collector
import
MemStatsCollector
from
colossalai.utils.memory_tracer.model_data_memtracer
import
ModelDataTracer
import
torch
def
test_mem_collector
():
collector
=
MemStatsCollector
()
collector
.
start_collection
()
a
=
torch
.
randn
(
10
).
cuda
()
# sampling at time 0
collector
.
sample_memstats
()
m_a
=
torch
.
randn
(
10
).
cuda
()
ModelDataTracer
().
add_tensor
(
m_a
)
b
=
torch
.
randn
(
10
).
cuda
()
# sampling at time 1
collector
.
sample_memstats
()
a
=
b
# sampling at time 2
collector
.
sample_memstats
()
collector
.
finish_collection
()
collector
.
reset
()
# do nothing after collection, just advance sampling cnter
collector
.
sample_memstats
()
collector
.
sample_memstats
()
cuda_use
,
overall_use
=
collector
.
fetch_memstats
()
print
(
cuda_use
,
overall_use
)
print
(
collector
.
_model_data_cuda
)
print
(
collector
.
_overall_cuda
)
if
__name__
==
'__main__'
:
test_mem_collector
()
colossalai/zero/init_ctx/init_context.py
View file @
21dc54e0
...
@@ -3,10 +3,11 @@ import functools
...
@@ -3,10 +3,11 @@ import functools
import
torch
import
torch
from
colossalai.zero.shard_utils
import
BaseShardStrategy
from
colossalai.zero.shard_utils
import
BaseShardStrategy
from
colossalai.zero.sharded_param
import
ShardedParamV2
from
colossalai.zero.sharded_param
import
ShardedParamV2
from
colossalai.utils.memory_tracer.allocator
import
GLOBAL_MODEL_DATA_TRACER
from
colossalai.utils.memory_tracer.model_data_memtracer
import
ModelDataTracer
# Inserts _post_init_method at the end of init method
# Inserts _post_init_method at the end of init method
# for all sub classes of torch.nn.Module
# for all sub classes of torch.nn.Module
class
InsertPostInitMethodToModuleSubClasses
(
object
):
class
InsertPostInitMethodToModuleSubClasses
(
object
):
...
@@ -152,7 +153,7 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
...
@@ -152,7 +153,7 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
if
self
.
shard_param
:
if
self
.
shard_param
:
self
.
shard_strategy
.
shard
(
tensor_list
=
[
param
.
col_attr
.
_data_sharded_tensor
])
self
.
shard_strategy
.
shard
(
tensor_list
=
[
param
.
col_attr
.
_data_sharded_tensor
])
GLOBAL_MODEL_DATA_TRACER
.
trace
_tensor
(
param
.
col_attr
.
_data_sharded_tensor
.
payload
)
ModelDataTracer
().
add
_tensor
(
param
.
col_attr
.
_data_sharded_tensor
.
payload
)
if
param
.
col_attr
.
grad
and
self
.
shard_grad
:
if
param
.
col_attr
.
grad
and
self
.
shard_grad
:
self
.
shard_strategy
.
shard
(
tensor_list
=
[
param
.
col_attr
.
_grad_sharded_tensor
])
self
.
shard_strategy
.
shard
(
tensor_list
=
[
param
.
col_attr
.
_grad_sharded_tensor
])
GLOBAL_MODEL_DATA_TRACER
.
trace
_tensor
(
param
.
col_attr
.
_grad_sharded_tensor
.
payload
)
ModelDataTracer
().
add
_tensor
(
param
.
col_attr
.
_grad_sharded_tensor
.
payload
)
colossalai/zero/sharded_model/sharded_model_v2.py
View file @
21dc54e0
...
@@ -17,7 +17,8 @@ from colossalai.zero.sharded_model.reduce_scatter import ReduceScatterBucketer
...
@@ -17,7 +17,8 @@ from colossalai.zero.sharded_model.reduce_scatter import ReduceScatterBucketer
from
colossalai.zero.sharded_param
import
ShardedParamV2
from
colossalai.zero.sharded_param
import
ShardedParamV2
from
torch.distributed
import
ProcessGroup
from
torch.distributed
import
ProcessGroup
from
torch.nn.parameter
import
Parameter
from
torch.nn.parameter
import
Parameter
from
colossalai.utils.memory_tracer.memstats_collector
import
MemStatsCollector
from
colossalai.utils.memory_tracer.allocator
import
col_move_to_cpu
from
._zero3_utils
import
(
cast_float_arguments
,
cast_tensor_to_fp16
,
cast_tensor_to_fp32
,
chunk_and_pad
,
from
._zero3_utils
import
(
cast_float_arguments
,
cast_tensor_to_fp16
,
cast_tensor_to_fp32
,
chunk_and_pad
,
get_gradient_predivide_factor
)
get_gradient_predivide_factor
)
...
@@ -33,7 +34,8 @@ class ShardedModelV2(nn.Module):
...
@@ -33,7 +34,8 @@ class ShardedModelV2(nn.Module):
fp32_reduce_scatter
:
bool
=
False
,
fp32_reduce_scatter
:
bool
=
False
,
offload_config
:
Optional
[
dict
]
=
None
,
offload_config
:
Optional
[
dict
]
=
None
,
gradient_predivide_factor
:
Optional
[
float
]
=
1.0
,
gradient_predivide_factor
:
Optional
[
float
]
=
1.0
,
shard_param
:
bool
=
True
):
shard_param
:
bool
=
True
,
use_memory_tracer
:
bool
=
False
):
r
"""
r
"""
A demo to reconfigure zero1 shared_model.
A demo to reconfigure zero1 shared_model.
Currently do not consider the Optimizer States.
Currently do not consider the Optimizer States.
...
@@ -59,8 +61,16 @@ class ShardedModelV2(nn.Module):
...
@@ -59,8 +61,16 @@ class ShardedModelV2(nn.Module):
if
self
.
shard_param
:
if
self
.
shard_param
:
self
.
shard_strategy
.
shard
([
param
.
col_attr
.
data
])
self
.
shard_strategy
.
shard
([
param
.
col_attr
.
data
])
# Init Memory Statistics Collector
self
.
_use_memory_tracer
=
use_memory_tracer
if
self
.
_use_memory_tracer
:
self
.
_memstats_collector
=
MemStatsCollector
()
else
:
self
.
_memstats_collector
=
None
self
.
_iter_cnter
=
0
# Register hooks
# Register hooks
register_ophooks_recursively
(
self
.
module
,
[
ZeroHook
(
self
.
shard_strategy
)])
register_ophooks_recursively
(
self
.
module
,
[
ZeroHook
(
self
.
shard_strategy
,
self
.
_memstats_collector
)])
self
.
param_hook_mgr
=
BaseParamHookMgr
(
list
(
self
.
module
.
parameters
()))
self
.
param_hook_mgr
=
BaseParamHookMgr
(
list
(
self
.
module
.
parameters
()))
self
.
param_hook_mgr
.
register_backward_hooks
(
self
.
_grad_post_backward_hook
)
self
.
param_hook_mgr
.
register_backward_hooks
(
self
.
_grad_post_backward_hook
)
...
@@ -84,6 +94,9 @@ class ShardedModelV2(nn.Module):
...
@@ -84,6 +94,9 @@ class ShardedModelV2(nn.Module):
return
self
.
_cpu_offload
return
self
.
_cpu_offload
def
forward
(
self
,
*
args
:
Any
,
**
kwargs
:
Any
)
->
torch
.
Tensor
:
def
forward
(
self
,
*
args
:
Any
,
**
kwargs
:
Any
)
->
torch
.
Tensor
:
if
self
.
_iter_cnter
==
0
and
self
.
_memstats_collector
:
# the opeartion will affect the flag in ZeroHook
self
.
_memstats_collector
.
start_collection
()
args
,
kwargs
=
cast_float_arguments
(
cast_tensor_to_fp16
,
*
args
,
**
kwargs
)
args
,
kwargs
=
cast_float_arguments
(
cast_tensor_to_fp16
,
*
args
,
**
kwargs
)
outputs
=
self
.
module
(
*
args
,
**
kwargs
)
outputs
=
self
.
module
(
*
args
,
**
kwargs
)
return
outputs
return
outputs
...
@@ -98,6 +111,12 @@ class ShardedModelV2(nn.Module):
...
@@ -98,6 +111,12 @@ class ShardedModelV2(nn.Module):
@
torch
.
no_grad
()
@
torch
.
no_grad
()
def
_final_backward_hook
(
self
)
->
None
:
def
_final_backward_hook
(
self
)
->
None
:
if
self
.
_iter_cnter
==
0
and
self
.
_memstats_collector
:
self
.
_memstats_collector
.
finish_collection
()
if
self
.
_memstats_collector
:
self
.
_memstats_collector
.
reset_sampling_cnter
()
self
.
_iter_cnter
+=
1
if
self
.
_require_backward_grad_sync
:
if
self
.
_require_backward_grad_sync
:
# Flush any unreduced buckets in the post_backward stream.
# Flush any unreduced buckets in the post_backward stream.
with
torch
.
cuda
.
stream
(
self
.
comm_stream
):
with
torch
.
cuda
.
stream
(
self
.
comm_stream
):
...
@@ -185,8 +204,10 @@ class ShardedModelV2(nn.Module):
...
@@ -185,8 +204,10 @@ class ShardedModelV2(nn.Module):
reduced_grad
.
data
=
cast_tensor_to_fp32
(
reduced_grad
.
data
)
reduced_grad
.
data
=
cast_tensor_to_fp32
(
reduced_grad
.
data
)
# Maybe offload
# Maybe offload
# TODO() optimize GPU->CPU bandwidth utilization
if
self
.
_cpu_offload
:
if
self
.
_cpu_offload
:
reduced_grad
.
data
=
reduced_grad
.
data
.
cpu
()
col_move_to_cpu
(
reduced_grad
)
# reduced_grad.data = reduced_grad.data.cpu()
if
param
.
col_attr
.
grad
is
None
:
if
param
.
col_attr
.
grad
is
None
:
param
.
col_attr
.
grad
=
reduced_grad
.
data
param
.
col_attr
.
grad
=
reduced_grad
.
data
...
...
colossalai/zero/sharded_optim/sharded_optim_v2.py
View file @
21dc54e0
...
@@ -143,7 +143,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
...
@@ -143,7 +143,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
# We have to use `copy_payload` instead of `reset_payload`
# We have to use `copy_payload` instead of `reset_payload`
# Since p.data is fp32 and p.col_attr.data is fp16
# Since p.data is fp32 and p.col_attr.data is fp16
# TODO() optimize this line
# TODO() optimize this line
CPU (fp32) -> GPU (fp16)
p
.
col_attr
.
data
.
copy_payload
(
p
.
data
)
p
.
col_attr
.
data
.
copy_payload
(
p
.
data
)
if
not
is_param_sharded
:
if
not
is_param_sharded
:
...
...
tests/test_utils/test_activation_checkpointing.py
View file @
21dc54e0
...
@@ -56,6 +56,7 @@ def test_activation_checkpointing(cpu_offload):
...
@@ -56,6 +56,7 @@ def test_activation_checkpointing(cpu_offload):
assert
torch
.
all
(
data
.
grad
==
data_
.
grad
),
'Gradient of the input does not match'
assert
torch
.
all
(
data
.
grad
==
data_
.
grad
),
'Gradient of the input does not match'
torch
.
cuda
.
empty_cache
()
torch
.
cuda
.
empty_cache
()
# as seed manager is singleton
# as seed manager is singleton
# if we don't reset seeds here,
# if we don't reset seeds here,
# other tests will fail if running together with this test
# other tests will fail if running together with this test
...
...
tests/test_zero_data_parallel/test_init_context.py
View file @
21dc54e0
...
@@ -9,12 +9,12 @@ import torch
...
@@ -9,12 +9,12 @@ import torch
import
torch.multiprocessing
as
mp
import
torch.multiprocessing
as
mp
from
colossalai.utils
import
free_port
from
colossalai.utils
import
free_port
from
colossalai.utils.cuda
import
get_current_device
from
colossalai.utils.cuda
import
get_current_device
from
colossalai.utils.memory_tracer.allocator
import
GLOBAL_MODEL_DATA_TRACER
from
colossalai.zero.init_ctx
import
ZeroInitContext
from
colossalai.zero.init_ctx
import
ZeroInitContext
from
colossalai.zero.shard_utils
import
(
BucketTensorShardStrategy
,
TensorShardStrategy
)
from
colossalai.zero.shard_utils
import
(
BucketTensorShardStrategy
,
TensorShardStrategy
)
from
tests.components_to_test.registry
import
non_distributed_component_funcs
from
tests.components_to_test.registry
import
non_distributed_component_funcs
from
common
import
CONFIG
from
common
import
CONFIG
from
colossalai.utils.memory_tracer.model_data_memtracer
import
ModelDataTracer
def
run_dist
(
rank
,
world_size
,
port
,
init_device
,
shard_strategy
):
def
run_dist
(
rank
,
world_size
,
port
,
init_device
,
shard_strategy
):
...
@@ -37,13 +37,10 @@ def run_dist(rank, world_size, port, init_device, shard_strategy):
...
@@ -37,13 +37,10 @@ def run_dist(rank, world_size, port, init_device, shard_strategy):
assert
param
.
col_attr
.
data
.
payload
.
device
.
type
==
init_device
.
type
,
\
assert
param
.
col_attr
.
data
.
payload
.
device
.
type
==
init_device
.
type
,
\
f
'
{
param
.
col_attr
.
data
.
payload
.
device
.
type
}
vs.
{
init_device
.
type
}
'
f
'
{
param
.
col_attr
.
data
.
payload
.
device
.
type
}
vs.
{
init_device
.
type
}
'
print
(
f
'cpu usgae
{
GLOBAL_MODEL_DATA_TRACER
.
cpu_usage
}
'
)
print
(
f
'cuda usgae
{
ModelDataTracer
().
cuda_usage
}
'
)
print
(
f
'cuda usgae
{
GLOBAL_MODEL_DATA_TRACER
.
cuda_usage
}
'
)
print
(
f
'numel
{
model_numel_tensor
}
'
)
print
(
f
'numel
{
model_numel_tensor
}
'
)
if
init_device
.
type
==
'cuda'
:
if
init_device
.
type
==
'cuda'
:
assert
(
GLOBAL_MODEL_DATA_TRACER
.
cuda_usage
>
0
)
assert
(
ModelDataTracer
().
cuda_usage
>
0
)
elif
init_device
.
type
==
'cpu'
:
assert
(
GLOBAL_MODEL_DATA_TRACER
.
cpu_usage
>
0
)
@
pytest
.
mark
.
dist
@
pytest
.
mark
.
dist
...
@@ -60,5 +57,5 @@ def test_zero_init_context(world_size, init_device, shard_strategy):
...
@@ -60,5 +57,5 @@ def test_zero_init_context(world_size, init_device, shard_strategy):
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
test_zero_init_context
(
2
,
torch
.
device
(
'cpu'
),
TensorShardStrategy
)
#
test_zero_init_context(2, torch.device('cpu'), TensorShardStrategy)
test_zero_init_context
(
2
,
torch
.
device
(
f
'c
uda:
{
get_current_device
()
}
'
),
TensorShardStrategy
)
test_zero_init_context
(
4
,
torch
.
device
(
'c
pu'
),
Bucket
TensorShardStrategy
)
tests/test_zero_data_parallel/test_shard_model_v2.py
View file @
21dc54e0
...
@@ -18,6 +18,7 @@ from tests.components_to_test.registry import non_distributed_component_funcs
...
@@ -18,6 +18,7 @@ from tests.components_to_test.registry import non_distributed_component_funcs
from
torch.nn.parallel
import
DistributedDataParallel
as
DDP
from
torch.nn.parallel
import
DistributedDataParallel
as
DDP
from
common
import
CONFIG
,
check_grads_padding
,
run_fwd_bwd
from
common
import
CONFIG
,
check_grads_padding
,
run_fwd_bwd
from
colossalai.zero.sharded_model.utils
import
col_model_deepcopy
def
run_dist
(
rank
,
world_size
,
port
,
use_zero_init_ctx
,
enable_autocast
,
shard_strategy
):
def
run_dist
(
rank
,
world_size
,
port
,
use_zero_init_ctx
,
enable_autocast
,
shard_strategy
):
...
@@ -33,12 +34,12 @@ def run_dist(rank, world_size, port, use_zero_init_ctx, enable_autocast, shard_s
...
@@ -33,12 +34,12 @@ def run_dist(rank, world_size, port, use_zero_init_ctx, enable_autocast, shard_s
if
use_zero_init_ctx
:
if
use_zero_init_ctx
:
with
ZeroInitContext
(
convert_fp16
=
True
,
with
ZeroInitContext
(
convert_fp16
=
True
,
target_device
=
torch
.
device
(
'cpu'
),
target_device
=
torch
.
device
(
f
'cpu
:0
'
),
shard_strategy
=
shard_strategy
,
shard_strategy
=
shard_strategy
,
shard_param
=
True
,
shard_param
=
True
,
rm_torch_payload_on_the_fly
=
rm_torch_payload_on_the_fly
):
rm_torch_payload_on_the_fly
=
rm_torch_payload_on_the_fly
):
zero_model
=
model_builder
(
checkpoint
=
True
)
zero_model
=
model_builder
(
checkpoint
=
True
)
zero_model
=
ShardedModelV2
(
zero_model
,
shard_strategy
)
zero_model
=
ShardedModelV2
(
zero_model
,
shard_strategy
,
use_memory_tracer
=
True
)
model
=
model_builder
(
checkpoint
=
True
).
half
()
model
=
model_builder
(
checkpoint
=
True
).
half
()
col_model_deepcopy
(
zero_model
,
model
)
col_model_deepcopy
(
zero_model
,
model
)
...
@@ -59,6 +60,9 @@ def run_dist(rank, world_size, port, use_zero_init_ctx, enable_autocast, shard_s
...
@@ -59,6 +60,9 @@ def run_dist(rank, world_size, port, use_zero_init_ctx, enable_autocast, shard_s
check_grads_padding
(
model
,
zero_model
,
loose
=
True
)
check_grads_padding
(
model
,
zero_model
,
loose
=
True
)
print
(
'overall cuda '
,
zero_model
.
_memstats_collector
.
_overall_cuda
)
print
(
'model cuda '
,
zero_model
.
_memstats_collector
.
_model_data_cuda
)
@
pytest
.
mark
.
dist
@
pytest
.
mark
.
dist
@
pytest
.
mark
.
parametrize
(
"world_size"
,
[
1
,
2
])
@
pytest
.
mark
.
parametrize
(
"world_size"
,
[
1
,
2
])
...
...
tests/test_zero_data_parallel/test_sharded_optim_v2.py
View file @
21dc54e0
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import
copy
import
copy
from
functools
import
partial
from
functools
import
partial
...
@@ -82,4 +79,4 @@ def test_sharded_optim_v2(world_size, cpu_offload, shard_strategy):
...
@@ -82,4 +79,4 @@ def test_sharded_optim_v2(world_size, cpu_offload, shard_strategy):
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
test_sharded_optim_v2
(
world_size
=
2
,
cpu_offload
=
True
,
shard_strategy
=
TensorShardStrategy
)
test_sharded_optim_v2
(
world_size
=
2
,
cpu_offload
=
True
,
shard_strategy
=
TensorShardStrategy
)
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment