Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
84c6700b
Unverified
Commit
84c6700b
authored
Apr 14, 2022
by
HELSON
Committed by
GitHub
Apr 14, 2022
Browse files
[zero] refactor memstats_collector (#746)
parent
b8899e09
Changes
10
Hide whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
178 additions
and
110 deletions
+178
-110
colossalai/engine/ophooks/__init__.py
colossalai/engine/ophooks/__init__.py
+1
-0
colossalai/utils/memory_tracer/memstats_collector.py
colossalai/utils/memory_tracer/memstats_collector.py
+47
-43
colossalai/utils/memory_tracer/model_data_memtracer.py
colossalai/utils/memory_tracer/model_data_memtracer.py
+4
-0
colossalai/zero/sharded_param/sharded_param.py
colossalai/zero/sharded_param/sharded_param.py
+0
-1
colossalai/zero/sharded_param/tensor_utils.py
colossalai/zero/sharded_param/tensor_utils.py
+1
-1
colossalai/zero/utils/stateful_tensor_mgr.py
colossalai/zero/utils/stateful_tensor_mgr.py
+0
-4
colossalai/zero/utils/tensor_placement_policy.py
colossalai/zero/utils/tensor_placement_policy.py
+2
-2
colossalai/zero/utils/zero_hook.py
colossalai/zero/utils/zero_hook.py
+30
-49
tests/test_zero/test_mem_collector.py
tests/test_zero/test_mem_collector.py
+74
-0
tests/test_zero/test_stateful_tensor_mgr.py
tests/test_zero/test_stateful_tensor_mgr.py
+19
-10
No files found.
colossalai/engine/ophooks/__init__.py
View file @
84c6700b
from
.utils
import
register_ophooks_recursively
,
BaseOpHook
from
._memtracer_ophook
import
MemTracerOpHook
__all__
=
[
"BaseOpHook"
,
"MemTracerOpHook"
,
"register_ophooks_recursively"
]
colossalai/utils/memory_tracer/memstats_collector.py
View file @
84c6700b
from
colossalai.utils.memory_tracer.model_data_memtracer
import
GLOBAL_MODEL_DATA_TRACER
from
colossalai.utils.memory
import
colo_device_memory_used
from
colossalai.utils.memory_tracer
import
As
yncMemoryMonitor
from
colossalai.utils.memory_tracer
import
S
ync
Cuda
MemoryMonitor
import
torch
import
time
from
typing
import
List
...
...
@@ -19,7 +19,7 @@ class MemStatsCollector:
"""
def
__init__
(
self
)
->
None
:
self
.
_mem_monitor
=
As
yncMemoryMonitor
()
self
.
_mem_monitor
=
S
ync
Cuda
MemoryMonitor
()
self
.
_model_data_cuda_list
=
[]
self
.
_overall_cuda_list
=
[]
...
...
@@ -31,9 +31,10 @@ class MemStatsCollector:
self
.
_sampling_time
=
[]
self
.
_start_flag
=
False
self
.
_period_idx
=
0
self
.
_step_idx
=
0
self
.
_step_total
=
0
def
overall_mem_stats
(
self
,
device_type
:
str
):
def
overall_mem_stats
(
self
,
device_type
:
str
)
->
List
[
int
]
:
if
device_type
==
'cuda'
:
return
self
.
_overall_cuda_list
elif
device_type
==
'cpu'
:
...
...
@@ -41,47 +42,23 @@ class MemStatsCollector:
else
:
raise
TypeError
def
model_data_list
(
self
,
device_type
:
str
,
unit
:
str
=
'B'
)
->
List
[
int
]:
if
unit
==
'GB'
:
scale
=
1e9
elif
unit
==
'MB'
:
scale
=
1e6
elif
unit
==
'KB'
:
scale
=
1e3
elif
unit
==
'B'
:
scale
=
1
else
:
raise
TypeError
def
model_data_list
(
self
,
device_type
:
str
)
->
List
[
int
]:
if
device_type
==
'cuda'
:
return
[
elem
/
scale
for
elem
in
self
.
_model_data_cuda_list
]
return
self
.
_model_data_cuda_list
elif
device_type
==
'cpu'
:
return
[
elem
/
scale
for
elem
in
self
.
_model_data_cpu_list
]
else
:
raise
TypeError
def
non_model_data_list
(
self
,
device_type
:
str
,
unit
:
str
=
'B'
)
->
List
[
int
]:
"""Non model data stats
"""
if
unit
==
'GB'
:
scale
=
1e9
elif
unit
==
'MB'
:
scale
=
1e6
elif
unit
==
'KB'
:
scale
=
1e3
elif
unit
==
'B'
:
scale
=
1
return
self
.
_model_data_cpu_list
else
:
raise
TypeError
def
non_model_data_list
(
self
,
device_type
:
str
)
->
List
[
int
]:
if
device_type
==
'cuda'
:
return
[
elem
/
scale
for
elem
in
self
.
_non_model_data_cuda_list
]
return
self
.
_non_model_data_cuda_list
elif
device_type
==
'cpu'
:
return
[
elem
/
scale
for
elem
in
self
.
_non_model_data_cpu_list
]
return
self
.
_non_model_data_cpu_list
else
:
raise
TypeError
def
max
_non_model_data
(
self
,
device_type
:
str
)
->
int
:
def
next_period
_non_model_data
_usage
(
self
,
device_type
:
str
)
->
int
:
"""Get max non model data memory usage of current sampling period
Args:
...
...
@@ -91,12 +68,10 @@ class MemStatsCollector:
int: max non model data memory usage of current sampling period
"""
assert
not
self
.
_start_flag
,
'Cannot get mem stats info during collection phase.'
assert
len
(
self
.
_sampling_time
)
>
0
,
'Cannot get mem stats info before collection phase.'
next_period_idx
=
(
self
.
_period_idx
+
1
)
%
len
(
self
.
_sampling_time
)
current_non_model_data
=
self
.
non_model_data_list
(
device_type
)[
self
.
_period_idx
]
next_non_model_data
=
self
.
non_model_data_list
(
device_type
)[
next_period_idx
]
self
.
_period_idx
=
next_period_idx
return
max
(
current_non_model_data
,
next_non_model_data
)
assert
self
.
_step_total
>
0
,
'Cannot get mem stats info before collection phase.'
next_non_model_data
=
self
.
non_model_data_list
(
device_type
)[
self
.
_step_idx
]
self
.
_step_idx
=
(
self
.
_step_idx
+
1
)
%
self
.
_step_total
return
next_non_model_data
@
property
def
sampling_time
(
self
):
...
...
@@ -107,9 +82,37 @@ class MemStatsCollector:
self
.
_mem_monitor
.
start
()
def
finish_collection
(
self
):
self
.
sample_overall_data
()
self
.
_step_total
=
len
(
self
.
_sampling_time
)
self
.
_start_flag
=
False
self
.
_mem_monitor
.
finish
()
def
sample_model_data
(
self
)
->
None
:
"""Sampling model data statistics.
"""
if
self
.
_start_flag
:
cuda_mem
,
cpu_mem
=
GLOBAL_MODEL_DATA_TRACER
.
both_mem_usage
self
.
_model_data_cuda_list
.
append
(
cuda_mem
)
self
.
_model_data_cpu_list
.
append
(
cpu_mem
)
def
sample_overall_data
(
self
)
->
None
:
"""Sampling non model data statistics.
"""
if
self
.
_start_flag
:
# overall data recording is after model data recording
if
len
(
self
.
_model_data_cuda_list
)
==
0
:
return
self
.
_overall_cuda_list
.
append
(
self
.
_mem_monitor
.
finish
())
self
.
_overall_cpu_list
.
append
(
colo_device_memory_used
(
torch
.
device
(
'cpu'
)))
assert
len
(
self
.
_model_data_cuda_list
)
==
len
(
self
.
_overall_cuda_list
)
self
.
_non_model_data_cuda_list
.
append
(
self
.
_overall_cuda_list
[
-
1
]
-
self
.
_model_data_cuda_list
[
-
1
])
self
.
_non_model_data_cpu_list
.
append
(
self
.
_overall_cpu_list
[
-
1
]
-
self
.
_model_data_cpu_list
[
-
1
])
self
.
_sampling_time
.
append
(
time
.
time
())
self
.
_mem_monitor
.
start
()
def
sample_memstats
(
self
)
->
None
:
"""
Sampling memory statistics.
...
...
@@ -119,7 +122,7 @@ class MemStatsCollector:
if
self
.
_start_flag
:
self
.
_model_data_cuda_list
.
append
(
GLOBAL_MODEL_DATA_TRACER
.
cuda_usage
)
self
.
_overall_cuda_list
.
append
(
self
.
_mem_monitor
.
finish
())
self
.
_non_model_data_cuda_list
.
append
(
self
.
_
model_data
_cuda_list
[
-
1
]
-
self
.
_
overall
_cuda_list
[
-
1
])
self
.
_non_model_data_cuda_list
.
append
(
self
.
_
overall
_cuda_list
[
-
1
]
-
self
.
_
model_data
_cuda_list
[
-
1
])
self
.
_model_data_cpu_list
.
append
(
GLOBAL_MODEL_DATA_TRACER
.
cpu_usage
)
# FIXME(jiaruifang) cpu sys used should also return from self._mem_monitor()
...
...
@@ -136,4 +139,5 @@ class MemStatsCollector:
self
.
_overall_cpu_list
=
[]
self
.
_start_flag
=
False
self
.
_period_idx
=
0
self
.
_step_idx
=
0
self
.
_step_total
=
0
colossalai/utils/memory_tracer/model_data_memtracer.py
View file @
84c6700b
...
...
@@ -101,5 +101,9 @@ class ModelDataTracer(metaclass=SingletonMeta):
cuda_usage
,
_
=
self
.
_get_mem_usage
()
return
cuda_usage
@
property
def
both_mem_usage
(
self
):
return
self
.
_get_mem_usage
()
GLOBAL_MODEL_DATA_TRACER
=
ModelDataTracer
()
colossalai/zero/sharded_param/sharded_param.py
View file @
84c6700b
...
...
@@ -109,6 +109,5 @@ class ShardedParamV2(object):
if
self
.
param
.
grad
is
not
None
and
self
.
param
.
grad
.
data_ptr
()
not
in
address_set
:
_update_mem_use
(
self
.
param
.
grad
)
address_set
.
add
(
self
.
param
.
grad
.
data_ptr
())
return
cuda_mem_use
,
cpu_mem_use
colossalai/zero/sharded_param/tensor_utils.py
View file @
84c6700b
...
...
@@ -13,7 +13,7 @@ def colo_tensor_mem_usage(tensor: Union[torch.Tensor, StatefulTensor]) -> Tuple[
cuda_use
,
cpu_use
=
0
,
0
mem_use
=
t
.
numel
()
*
t
.
element_size
()
mem_use
=
t
.
storage
().
size
()
*
t
.
element_size
()
if
t
.
device
.
type
==
'cuda'
:
cuda_use
+=
mem_use
elif
t
.
device
.
type
==
'cpu'
:
...
...
colossalai/zero/utils/stateful_tensor_mgr.py
View file @
84c6700b
...
...
@@ -38,10 +38,6 @@ class StatefulTensorMgr(object):
def
adjust_layout
(
self
)
->
None
:
""" Adjust the layout of statefuil tensor according to the information provided
by mem_stats_collector, which should belongs to a Sharded Model.
Args:
mem_stats_collector (MemStatsCollector): a collector, usually owned by a Sharded Model.
It contains non-model footprint of a DNN model.
"""
# find stateful tensor in state COMPUTE
cuda_demand
=
0
...
...
colossalai/zero/utils/tensor_placement_policy.py
View file @
84c6700b
...
...
@@ -61,7 +61,7 @@ class AutoTensorPlacementPolicy(TensorPlacementPolicy):
max_cuda_non_model_data_per_period
=
cuda_capacity
*
self
.
_warmup_non_model_data_ratio
else
:
# max non-model-data cuda memory consumption of this sampling moment and the next sampling moment.
max_cuda_non_model_data_per_period
=
self
.
mem_stats_collector
.
max
_non_model_data
(
'cuda'
)
max_cuda_non_model_data_per_period
=
self
.
mem_stats_collector
.
next_period
_non_model_data
_usage
(
'cuda'
)
total_cuda_model_data
=
cuda_capacity
-
max_cuda_non_model_data_per_period
avail_cuda_model_data
=
total_cuda_model_data
-
used_cuda_model_data
if
avail_cuda_model_data
<
cuda_demand
:
...
...
@@ -71,7 +71,7 @@ class AutoTensorPlacementPolicy(TensorPlacementPolicy):
freed_cuda_model_data
=
0
to_free_tensor_list
=
hold_cuda_tensor_list
if
not
warmup
:
next_compute_idx
:
Dict
[
StatefulTensor
,
int
]
=
{
t
:
len
(
compute_list
)
for
t
in
hold_cuda_tensor_list
}
next_compute_idx
=
{
t
:
len
(
compute_list
)
for
t
in
hold_cuda_tensor_list
}
for
i
in
range
(
len
(
compute_list
)
-
1
,
compute_idx
,
-
1
):
if
compute_list
[
i
]
in
next_compute_idx
:
next_compute_idx
[
compute_list
[
i
]]
=
i
...
...
colossalai/zero/utils/zero_hook.py
View file @
84c6700b
...
...
@@ -36,17 +36,7 @@ class ZeroHook(BaseOpHook):
self
.
_memstarts_collector
=
memstarts_collector
self
.
_stateful_tensor_mgr
=
stateful_tensor_mgr
def
pre_fwd_exec
(
self
,
module
:
torch
.
nn
.
Module
,
*
args
):
for
param
in
module
.
parameters
(
recurse
=
False
):
param
.
colo_attr
.
sharded_data_tensor
.
trans_state
(
TensorState
.
COMPUTE
)
if
self
.
_stateful_tensor_mgr
:
self
.
_stateful_tensor_mgr
.
adjust_layout
()
else
:
for
param
in
module
.
parameters
(
recurse
=
False
):
colo_model_data_tensor_move_inline
(
param
.
colo_attr
.
sharded_data_tensor
,
self
.
computing_device
)
def
gather_parameters
(
self
,
module
:
torch
.
nn
.
Module
):
# gather sharded parameters
if
module
.
param_is_sharded
:
tensor_list
=
[]
...
...
@@ -55,10 +45,33 @@ class ZeroHook(BaseOpHook):
tensor_list
.
append
(
param
.
colo_attr
.
sharded_data_tensor
)
self
.
shard_strategy
.
gather
(
tensor_list
,
self
.
process_group
)
# record memory statistics
def
shard_parameters
(
self
,
module
:
torch
.
nn
.
Module
):
# shard gathered parameters
if
module
.
param_is_sharded
:
tensor_list
=
[]
for
param
in
module
.
parameters
(
recurse
=
False
):
assert
hasattr
(
param
,
'colo_attr'
)
tensor_list
.
append
(
param
.
colo_attr
.
sharded_data_tensor
)
self
.
shard_strategy
.
shard
(
tensor_list
,
self
.
process_group
)
def
adjust_module_data
(
self
,
module
:
torch
.
nn
.
Module
):
# record overall data statistics
if
self
.
_memstarts_collector
:
self
.
_memstarts_collector
.
sample_
memst
at
s
()
self
.
_memstarts_collector
.
sample_
overall_d
at
a
()
for
param
in
module
.
parameters
(
recurse
=
False
):
param
.
colo_attr
.
sharded_data_tensor
.
trans_state
(
TensorState
.
COMPUTE
)
# adjust stateful tensor to get enough CUDA memory
self
.
_stateful_tensor_mgr
.
adjust_layout
()
# record model data statistics
if
self
.
_memstarts_collector
:
self
.
_memstarts_collector
.
sample_model_data
()
def
pre_fwd_exec
(
self
,
module
:
torch
.
nn
.
Module
,
*
args
):
self
.
adjust_module_data
(
module
)
self
.
gather_parameters
(
module
)
for
param
in
module
.
parameters
(
recurse
=
False
):
param
.
data
=
param
.
colo_attr
.
data_payload
assert
param
.
data
.
device
.
type
==
'cuda'
,
f
"PRE FWD param.data must be on CUDA"
...
...
@@ -69,41 +82,15 @@ class ZeroHook(BaseOpHook):
for
param
in
module
.
parameters
(
recurse
=
False
):
param
.
colo_attr
.
sharded_data_tensor
.
trans_state
(
TensorState
.
HOLD_AFTER_FWD
)
# shard gathered parameters
if
module
.
param_is_sharded
:
tensor_list
=
[]
for
param
in
module
.
parameters
(
recurse
=
False
):
assert
hasattr
(
param
,
'colo_attr'
)
tensor_list
.
append
(
param
.
colo_attr
.
sharded_data_tensor
)
self
.
shard_strategy
.
shard
(
tensor_list
,
self
.
process_group
)
self
.
shard_parameters
(
module
)
# remove torch payload
for
param
in
module
.
parameters
(
recurse
=
False
):
param
.
colo_attr
.
set_data_none
()
def
pre_bwd_exec
(
self
,
module
:
torch
.
nn
.
Module
,
input
,
output
):
for
param
in
module
.
parameters
(
recurse
=
False
):
param
.
colo_attr
.
sharded_data_tensor
.
trans_state
(
TensorState
.
COMPUTE
)
if
self
.
_stateful_tensor_mgr
:
self
.
_stateful_tensor_mgr
.
adjust_layout
()
else
:
for
param
in
module
.
parameters
(
recurse
=
False
):
colo_model_data_tensor_move_inline
(
param
.
colo_attr
.
sharded_data_tensor
,
self
.
computing_device
)
# gather sharded parameters
if
module
.
param_is_sharded
:
tensor_list
=
[]
for
param
in
module
.
parameters
(
recurse
=
False
):
assert
hasattr
(
param
,
'colo_attr'
)
tensor_list
.
append
(
param
.
colo_attr
.
sharded_data_tensor
)
self
.
shard_strategy
.
gather
(
tensor_list
,
self
.
process_group
)
# record memory statistics
if
self
.
_memstarts_collector
:
self
.
_memstarts_collector
.
sample_memstats
()
self
.
adjust_module_data
(
module
)
self
.
gather_parameters
(
module
)
for
param
in
module
.
parameters
(
recurse
=
False
):
param
.
data
=
param
.
colo_attr
.
data_payload
assert
param
.
data
.
device
.
type
==
'cuda'
,
f
"PRE BWD param.data must be on CUDA"
...
...
@@ -114,13 +101,7 @@ class ZeroHook(BaseOpHook):
for
param
in
module
.
parameters
(
recurse
=
False
):
param
.
colo_attr
.
sharded_data_tensor
.
trans_state
(
TensorState
.
HOLD_AFTER_BWD
)
# shard gathered parameters
if
module
.
param_is_sharded
:
tensor_list
=
[]
for
param
in
module
.
parameters
(
recurse
=
False
):
assert
hasattr
(
param
,
'colo_attr'
)
tensor_list
.
append
(
param
.
colo_attr
.
sharded_data_tensor
)
self
.
shard_strategy
.
shard
(
tensor_list
,
self
.
process_group
)
self
.
shard_parameters
(
module
)
# remove torch payload
for
param
in
module
.
parameters
(
recurse
=
False
):
...
...
tests/test_zero/test_mem_collector.py
0 → 100644
View file @
84c6700b
import
torch
import
colossalai
import
pytest
import
torch.multiprocessing
as
mp
import
torch.nn
as
nn
import
torch.nn.functional
as
F
from
colossalai.utils.cuda
import
get_current_device
from
colossalai.utils.memory
import
colo_device_memory_capacity
,
colo_set_process_memory_fraction
from
colossalai.zero.init_ctx
import
ZeroInitContext
from
colossalai.zero.sharded_model
import
ShardedModelV2
from
colossalai.zero.shard_utils
import
BucketTensorShardStrategy
from
colossalai.utils
import
free_port
from
colossalai.testing
import
rerun_on_exception
from
functools
import
partial
class
TestModel
(
torch
.
nn
.
Module
):
def
__init__
(
self
)
->
None
:
super
().
__init__
()
self
.
proj1
=
nn
.
Linear
(
512
,
512
)
self
.
weight
=
nn
.
Parameter
(
torch
.
randn
(
1024
,
512
))
self
.
proj2
=
nn
.
Linear
(
1024
,
512
)
def
forward
(
self
,
x
):
x
=
self
.
proj1
(
x
)
x
=
F
.
linear
(
x
,
self
.
weight
)
x
=
self
.
proj2
(
x
)
return
x
def
run_mem_collector_testing
():
cuda_capacity
=
colo_device_memory_capacity
(
get_current_device
())
fraction
=
(
50
*
1024
**
2
)
/
cuda_capacity
# limit max memory to 50MB
colo_set_process_memory_fraction
(
fraction
)
shard_strategy
=
BucketTensorShardStrategy
()
with
ZeroInitContext
(
target_device
=
get_current_device
(),
shard_strategy
=
shard_strategy
,
shard_param
=
True
):
model
=
TestModel
()
model
=
ShardedModelV2
(
module
=
model
,
shard_strategy
=
shard_strategy
,
reduce_scatter_bucket_size_mb
=
1
,
tensor_placement_policy
=
'auto'
)
data
=
torch
.
randn
(
2
,
512
,
device
=
get_current_device
())
output
=
model
(
data
)
loss
=
torch
.
mean
(
output
)
model
.
backward
(
loss
)
cuda_model_data_list
=
model
.
_memstats_collector
.
model_data_list
(
'cuda'
)
assert
cuda_model_data_list
==
[
1311744
,
1836032
,
1836032
,
1311744
,
1836032
,
1836032
]
cuda_non_model_data_list
=
model
.
_memstats_collector
.
non_model_data_list
(
'cuda'
)
assert
cuda_non_model_data_list
[
0
]
>
cuda_non_model_data_list
[
1
]
assert
cuda_non_model_data_list
[
-
2
]
>
cuda_non_model_data_list
[
-
1
]
def
run_dist
(
rank
,
world_size
,
port
):
colossalai
.
launch
(
config
=
{},
rank
=
rank
,
world_size
=
world_size
,
host
=
'localhost'
,
port
=
port
,
backend
=
'nccl'
)
run_mem_collector_testing
()
@
pytest
.
mark
.
dist
@
rerun_on_exception
(
exception_type
=
mp
.
ProcessRaisedException
,
pattern
=
".*Address already in use.*"
)
def
test_mem_collector
(
world_size
=
2
):
run_func
=
partial
(
run_dist
,
world_size
=
world_size
,
port
=
free_port
())
mp
.
spawn
(
run_func
,
nprocs
=
world_size
)
if
__name__
==
'__main__'
:
test_mem_collector
()
tests/test_zero/test_stateful_tensor_mgr.py
View file @
84c6700b
...
...
@@ -48,30 +48,39 @@ def run_stm():
# warmup
# use naive eviction strategy
apply_adjust
(
model
,
model
.
p0
,
[
model
.
p0
],
stateful_tensor_mgr
)
mem_collector
.
sample_memstats
()
mem_collector
.
sample_model_data
()
mem_collector
.
sample_overall_data
()
apply_adjust
(
model
,
model
.
p1
,
[
model
.
p0
,
model
.
p1
],
stateful_tensor_mgr
)
mem_collector
.
sample_memstats
()
mem_collector
.
sample_model_data
()
mem_collector
.
sample_overall_data
()
apply_adjust
(
model
,
model
.
p2
,
[
model
.
p1
,
model
.
p2
],
stateful_tensor_mgr
)
mem_collector
.
sample_memstats
()
mem_collector
.
sample_model_data
()
mem_collector
.
sample_overall_data
()
apply_adjust
(
model
,
model
.
p0
,
[
model
.
p0
,
model
.
p2
],
stateful_tensor_mgr
)
mem_collector
.
sample_memstats
()
mem_collector
.
sample_model_data
()
mem_collector
.
sample_overall_data
()
apply_adjust
(
model
,
model
.
p1
,
[
model
.
p1
,
model
.
p2
],
stateful_tensor_mgr
)
mem_collector
.
sample_m
emst
at
s
()
mem_collector
.
sample_m
odel_d
at
a
()
mem_collector
.
finish_collection
()
stateful_tensor_mgr
.
reset
()
# warmup done
# use OPT-like eviction strategy
apply_adjust
(
model
,
model
.
p0
,
[
model
.
p0
,
model
.
p1
],
stateful_tensor_mgr
)
mem_collector
.
sample_memstats
()
mem_collector
.
sample_model_data
()
mem_collector
.
sample_overall_data
()
apply_adjust
(
model
,
model
.
p1
,
[
model
.
p0
,
model
.
p1
],
stateful_tensor_mgr
)
mem_collector
.
sample_memstats
()
mem_collector
.
sample_model_data
()
mem_collector
.
sample_overall_data
()
apply_adjust
(
model
,
model
.
p2
,
[
model
.
p0
,
model
.
p2
],
stateful_tensor_mgr
)
mem_collector
.
sample_memstats
()
mem_collector
.
sample_model_data
()
mem_collector
.
sample_overall_data
()
apply_adjust
(
model
,
model
.
p0
,
[
model
.
p0
,
model
.
p2
],
stateful_tensor_mgr
)
mem_collector
.
sample_memstats
()
mem_collector
.
sample_model_data
()
mem_collector
.
sample_overall_data
()
apply_adjust
(
model
,
model
.
p1
,
[
model
.
p1
,
model
.
p2
],
stateful_tensor_mgr
)
mem_collector
.
sample_memstats
()
mem_collector
.
sample_model_data
()
mem_collector
.
finish_collection
()
def
apply_adjust
(
model
:
torch
.
nn
.
Module
,
compute_param
:
Parameter
,
cuda_param_after_adjust
:
List
[
Parameter
],
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment