Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
dcca614e
Unverified
Commit
dcca614e
authored
Apr 14, 2022
by
ver217
Committed by
GitHub
Apr 14, 2022
Browse files
[hotfix] fix test_stateful_tensor_mgr (#762)
parent
6978980f
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
20 additions
and
13 deletions
+20
-13
colossalai/zero/utils/tensor_placement_policy.py
colossalai/zero/utils/tensor_placement_policy.py
+2
-2
tests/test_zero/test_stateful_tensor_mgr.py
tests/test_zero/test_stateful_tensor_mgr.py
+18
-11
No files found.
colossalai/zero/utils/tensor_placement_policy.py
View file @
dcca614e
from
abc
import
ABC
,
abstractmethod
from
abc
import
ABC
,
abstractmethod
from
typing
import
List
,
Optional
,
Dict
from
typing
import
List
,
Optional
import
torch
import
torch
from
colossalai.utils
import
get_current_device
from
colossalai.utils
import
get_current_device
from
colossalai.zero.sharded_param.tensor_utils
import
colo_model_data_tensor_move_inline
,
colo_tensor_mem_usage
from
colossalai.zero.sharded_param.tensor_utils
import
colo_model_data_tensor_move_inline
,
colo_tensor_mem_usage
...
@@ -79,7 +79,7 @@ class AutoTensorPlacementPolicy(TensorPlacementPolicy):
...
@@ -79,7 +79,7 @@ class AutoTensorPlacementPolicy(TensorPlacementPolicy):
next_compute_idx
=
sorted
(
next_compute_idx
.
items
(),
key
=
lambda
pair
:
pair
[
1
],
reverse
=
True
)
next_compute_idx
=
sorted
(
next_compute_idx
.
items
(),
key
=
lambda
pair
:
pair
[
1
],
reverse
=
True
)
to_free_tensor_list
=
[
t
for
(
t
,
idx
)
in
next_compute_idx
]
to_free_tensor_list
=
[
t
for
(
t
,
idx
)
in
next_compute_idx
]
for
t
in
to_free_tensor_list
:
for
t
in
to_free_tensor_list
:
if
freed_cuda_model_data
>
to_free_cuda_model_data
:
if
freed_cuda_model_data
>
=
to_free_cuda_model_data
:
break
break
freed_cuda_model_data
+=
colo_tensor_mem_usage
(
t
)[
0
]
freed_cuda_model_data
+=
colo_tensor_mem_usage
(
t
)[
0
]
colo_model_data_tensor_move_inline
(
t
,
torch
.
device
(
'cpu'
))
colo_model_data_tensor_move_inline
(
t
,
torch
.
device
(
'cpu'
))
...
...
tests/test_zero/test_stateful_tensor_mgr.py
View file @
dcca614e
...
@@ -5,7 +5,7 @@ import torch.multiprocessing as mp
...
@@ -5,7 +5,7 @@ import torch.multiprocessing as mp
from
colossalai.utils.cuda
import
get_current_device
from
colossalai.utils.cuda
import
get_current_device
from
colossalai.utils.memory_tracer
import
MemStatsCollector
from
colossalai.utils.memory_tracer
import
MemStatsCollector
from
colossalai.utils.memory_tracer.model_data_memtracer
import
GLOBAL_MODEL_DATA_TRACER
from
colossalai.utils.memory_tracer.model_data_memtracer
import
GLOBAL_MODEL_DATA_TRACER
from
colossalai.utils.memory
import
colo_device_memory_capacity
,
colo_set_process_memory_fraction
from
colossalai.utils.memory
import
colo_set_process_memory_fraction
from
colossalai.zero.utils
import
StatefulTensorMgr
from
colossalai.zero.utils
import
StatefulTensorMgr
from
colossalai.zero.sharded_param.sharded_param
import
ShardedParamV2
from
colossalai.zero.sharded_param.sharded_param
import
ShardedParamV2
from
colossalai.zero.sharded_param.tensorful_state
import
TensorState
from
colossalai.zero.sharded_param.tensorful_state
import
TensorState
...
@@ -21,18 +21,22 @@ class Net(torch.nn.Module):
...
@@ -21,18 +21,22 @@ class Net(torch.nn.Module):
def
__init__
(
self
)
->
None
:
def
__init__
(
self
)
->
None
:
super
().
__init__
()
super
().
__init__
()
# each parameter is
5
12 MB
# each parameter is 12
8
MB
self
.
p0
=
Parameter
(
torch
.
empty
(
1024
,
1024
,
128
))
self
.
p0
=
Parameter
(
torch
.
empty
(
1024
,
1024
,
32
))
self
.
p1
=
Parameter
(
torch
.
empty
(
1024
,
1024
,
128
))
self
.
p1
=
Parameter
(
torch
.
empty
(
1024
,
1024
,
32
))
self
.
p2
=
Parameter
(
torch
.
empty
(
1024
,
1024
,
128
))
self
.
p2
=
Parameter
(
torch
.
empty
(
1024
,
1024
,
32
))
def
run_stm
():
def
limit_cuda_memory
(
memory_in_g
:
float
):
cuda_capacity
=
colo_device_memory_capacity
(
get_current_device
())
cuda_capacity
=
torch
.
cuda
.
get_device_properties
(
get_current_device
()).
total_memory
fraction
=
(
1.4
*
1024
**
3
)
/
cuda_capacity
fraction
=
(
memory_in_g
*
1024
**
3
)
/
cuda_capacity
# limit max memory to 1.4GB
# which means only 2 parameters can be on CUDA
colo_set_process_memory_fraction
(
fraction
)
colo_set_process_memory_fraction
(
fraction
)
def
run_stm
():
# warmup phase use 20% CUDA memory to store params
# only 2 params can be on CUDA
limit_cuda_memory
(
1.26
)
model
=
Net
()
model
=
Net
()
for
p
in
model
.
parameters
():
for
p
in
model
.
parameters
():
p
.
colo_attr
=
ShardedParamV2
(
p
,
set_data_none
=
True
)
p
.
colo_attr
=
ShardedParamV2
(
p
,
set_data_none
=
True
)
...
@@ -65,6 +69,8 @@ def run_stm():
...
@@ -65,6 +69,8 @@ def run_stm():
stateful_tensor_mgr
.
reset
()
stateful_tensor_mgr
.
reset
()
# warmup done
# warmup done
# only 2 params can be on CUDA
limit_cuda_memory
(
0.26
)
# use OPT-like eviction strategy
# use OPT-like eviction strategy
apply_adjust
(
model
,
model
.
p0
,
[
model
.
p0
,
model
.
p1
],
stateful_tensor_mgr
)
apply_adjust
(
model
,
model
.
p0
,
[
model
.
p0
,
model
.
p1
],
stateful_tensor_mgr
)
mem_collector
.
sample_model_data
()
mem_collector
.
sample_model_data
()
...
@@ -112,7 +118,7 @@ def run_dist(rank, world_size, port):
...
@@ -112,7 +118,7 @@ def run_dist(rank, world_size, port):
run_stm
()
run_stm
()
@
pytest
.
mark
.
skip
@
pytest
.
mark
.
gpu
@
rerun_on_exception
(
exception_type
=
mp
.
ProcessRaisedException
,
pattern
=
".*Address already in use.*"
)
@
rerun_on_exception
(
exception_type
=
mp
.
ProcessRaisedException
,
pattern
=
".*Address already in use.*"
)
def
test_stateful_tensor_manager
(
world_size
=
1
):
def
test_stateful_tensor_manager
(
world_size
=
1
):
run_func
=
partial
(
run_dist
,
world_size
=
world_size
,
port
=
free_port
())
run_func
=
partial
(
run_dist
,
world_size
=
world_size
,
port
=
free_port
())
...
@@ -120,4 +126,5 @@ def test_stateful_tensor_manager(world_size=1):
...
@@ -120,4 +126,5 @@ def test_stateful_tensor_manager(world_size=1):
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
# this unit test can pass if available CUDA memory >= 1.5G
test_stateful_tensor_manager
()
test_stateful_tensor_manager
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment