Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
31078171
Unverified
Commit
31078171
authored
Apr 25, 2022
by
HELSON
Committed by
GitHub
Apr 25, 2022
Browse files
[gemini] add stateful tensor container (#867)
parent
d01d3b8c
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
205 additions
and
0 deletions
+205
-0
colossalai/gemini/stateful_tensor_container.py
colossalai/gemini/stateful_tensor_container.py
+131
-0
tests/test_gemini/test_stateful_tensor_container.py
tests/test_gemini/test_stateful_tensor_container.py
+74
-0
No files found.
colossalai/gemini/stateful_tensor_container.py
0 → 100644
View file @
31078171
import
queue
import
heapq
from
abc
import
ABC
,
abstractmethod
from
typing
import
Optional
,
List
,
Dict
from
colossalai.gemini.stateful_tensor
import
StatefulTensor
,
TensorState
def evict_check(st: StatefulTensor) -> bool:
    """Return True if ``st`` may be evicted from CUDA memory.

    A stateful tensor is evictable only when it still resides on a CUDA
    device and is not pinned there by being in the COMPUTE state.
    """
    resides_on_cuda = st.device.type == 'cuda'
    pinned_by_compute = st.state is TensorState.COMPUTE
    return resides_on_cuda and not pinned_by_compute
# Here ST means Stateful Tensor
class BaseSTContainer(ABC):
    """Abstract container for stateful tensors that are candidates for CUDA eviction.

    A tensor belongs here only while it satisfies two conditions: it has not
    been evicted yet (its device type is still CUDA), and it is not pinned in
    CUDA memory (its state is not COMPUTE).

    Tensors should be handed to the container when they transition from
    COMPUTE to a HOLD-like state; they are drained again through ``pop``.

    Subclasses that want an optimal eviction policy can use the per-tensor
    computation-step indices in ``compute_step_dict`` (e.g. keep a heap keyed
    on next use, so popping yields the tensor used furthest in the future).
    """

    def __init__(self, compute_step_dict: Dict[StatefulTensor, List[int]], total_step: int):
        # maps each stateful tensor to the (ascending) step indices at which it is used
        self.compute_step_dict = compute_step_dict
        # total number of computation steps in one iteration
        self.total_step = total_step

    @abstractmethod
    def empty(self) -> bool:
        """Return True when no candidate tensors remain in the container."""
        ...

    @abstractmethod
    def create(self, stateful_tensor_list: List[StatefulTensor]) -> None:
        """(Re)initialize the container from an initial list of stateful tensors."""
        ...

    @abstractmethod
    def push(self, stateful_tensor: StatefulTensor, cur_step: int) -> None:
        """Insert one tensor; ``cur_step`` lets policies rank it by its next use."""
        ...

    @abstractmethod
    def pop(self) -> Optional[StatefulTensor]:
        """Remove and return the next evictable tensor, or None if none qualifies."""
        ...
class QueueSTContainer(BaseSTContainer):
    """FIFO stateful tensor container, used by the 'cpu' tensor placement policy.

    Candidates are popped in the order they were inserted; ``cur_step`` is
    ignored since FIFO needs no notion of next use.
    """

    def __init__(self, compute_step_dict: Dict[StatefulTensor, List[int]], total_step: int):
        super().__init__(compute_step_dict, total_step)
        # underlying FIFO queue; stays None until `create` is called
        self.container = None

    def empty(self) -> bool:
        """Return True when the queue holds no tensors (requires a prior `create`)."""
        assert self.container is not None
        return self.container.empty()

    def create(self, stateful_tensor_list: List[StatefulTensor]) -> None:
        """Build a fresh queue seeded with every tensor in ``stateful_tensor_list``."""
        self.container = queue.SimpleQueue()
        for tensor in stateful_tensor_list:
            self.container.put(tensor)

    def push(self, stateful_tensor: StatefulTensor, cur_step: int) -> None:
        """Append one tensor; ``cur_step`` is irrelevant for FIFO ordering."""
        self.container.put(stateful_tensor)

    def pop(self) -> Optional[StatefulTensor]:
        """Drain the queue until an evictable tensor is found; None if the queue empties."""
        while not self.empty():
            candidate = self.container.get()
            if evict_check(candidate):
                return candidate
        return None
class HeapSTContainer(BaseSTContainer):
    """Heap type stateful tensor container. This is used in 'auto' tensor placement policy.

    It pops potential evictable stateful tensors ordered by the distance between the
    current step and each tensor's next used step — the tensor needed furthest in the
    future comes out first.
    """

    def __init__(self, compute_step_dict: Dict[StatefulTensor, List[int]], total_step: int):
        super().__init__(compute_step_dict, total_step)
        # min-heap of (weight, seq_id, stateful_tensor); stays None until `create`
        self.container = None
        # monotonically increasing insertion counter used as a heap tie-breaker;
        # without it, equal weights (e.g. every tensor with no future use gets
        # weight -total_step) would make heapq compare StatefulTensor objects
        # directly, raising TypeError since they define no ordering
        self._seq_id = 0

    def empty(self) -> bool:
        """Return True when the heap holds no tensors (requires a prior `create`)."""
        assert self.container is not None
        return self.container == []

    def create(self, stateful_tensor_list: List[StatefulTensor]) -> None:
        """Build a fresh heap from ``stateful_tensor_list``, keyed by next use."""
        self.container = []
        for stateful_tensor in stateful_tensor_list:
            # we want to pop the tensor which has the greatest next_step
            # so the weight is next_step multiplied by -1
            weight = -self.__get_next_compute_step(stateful_tensor, -1)
            self.container.append((weight, self._next_seq_id(), stateful_tensor))
        heapq.heapify(self.container)

    def push(self, stateful_tensor: StatefulTensor, cur_step: int) -> None:
        """Insert one tensor, ranked by its first use strictly after ``cur_step``."""
        # we want to pop the tensor which has the greatest next_step
        # so the weight is next_step multiplied by -1
        weight = -self.__get_next_compute_step(stateful_tensor, cur_step)
        # seq_id resolves equal weights FIFO, keeping the tensors themselves
        # out of the comparison
        heapq.heappush(self.container, (weight, self._next_seq_id(), stateful_tensor))

    def pop(self) -> Optional[StatefulTensor]:
        """Pop heap entries until an evictable tensor is found; None if the heap empties."""
        ret = None
        while not self.empty():
            _, _, out_tensor = heapq.heappop(self.container)
            if evict_check(out_tensor):
                ret = out_tensor
                break
        return ret

    def _next_seq_id(self) -> int:
        # hand out strictly increasing ids so heap ties break in insertion order
        self._seq_id += 1
        return self._seq_id

    def __get_next_compute_step(self, stateful_tensor: StatefulTensor, cur_step: int):
        # compute the id of the next step;
        # if the tensor is not used in the future,
        # next_step is set to the maximum (total_step)
        next_step = self.total_step
        step_list = self.compute_step_dict[stateful_tensor]
        for step in step_list:
            if step > cur_step:
                next_step = step
                break
        return next_step
tests/test_gemini/test_stateful_tensor_container.py
0 → 100644
View file @
31078171
import
pytest
import
torch
from
colossalai.gemini.stateful_tensor
import
TensorState
,
StatefulTensor
from
colossalai.gemini.stateful_tensor_container
import
QueueSTContainer
,
HeapSTContainer
@pytest.mark.dist
def test_stateful_tensor_container():
    """Exercise QueueSTContainer and HeapSTContainer over a fixed 6-step schedule.

    Three CUDA tensors of sizes 1, 2 and 3 are scheduled as
    st1, st2, st3, st3, st2, st1; each container must pop the expected
    eviction sequence (identified via payload size).
    """
    st1 = StatefulTensor(torch.randn(1, device='cuda'))
    st2 = StatefulTensor(torch.randn(2, device='cuda'))
    st3 = StatefulTensor(torch.randn(3, device='cuda'))
    stateful_tensor_list = [st1, st2, st3]
    # tensor used at computation step i is step_list[i]
    step_list = [st1, st2, st3, st3, st2, st1]

    # per-tensor list of step indices at which it is computed (ascending)
    compute_step_dict = dict()
    compute_step_dict[st1] = [0, 5]
    compute_step_dict[st2] = [1, 4]
    compute_step_dict[st3] = [2, 3]

    def run_queue_test():
        # test queue container
        queue_container = QueueSTContainer(compute_step_dict, 6)
        queue_container.create(stateful_tensor_list)

        res_list = []

        for i in range(6):
            stateful_tensor = step_list[i]
            # pin the current tensor so pop() cannot evict it
            stateful_tensor.trans_state(TensorState.COMPUTE)
            st_out = queue_container.pop()
            # FIFO policy always yields an evictable tensor here; move it off CUDA
            st_out.move_to(torch.device('cpu'))

            # record which tensor was evicted, identified by its size
            res_list.append(st_out.payload.size(0))

            # bring the computed tensor back to CUDA and re-enqueue it
            stateful_tensor.move_to(torch.device('cuda'))
            queue_container.push(stateful_tensor, i)
            stateful_tensor.trans_state(TensorState.HOLD)

        assert res_list == [2, 3, 1, 2, 3, 2]

    run_queue_test()

    def run_heap_test():
        # test heap container
        # restore all tensors to CUDA after the queue test moved some to CPU
        st1.move_to(torch.device('cuda'))
        st2.move_to(torch.device('cuda'))
        st3.move_to(torch.device('cuda'))
        heap_container = HeapSTContainer(compute_step_dict, 6)
        heap_container.create(stateful_tensor_list)

        res_list = []

        for i in range(6):
            stateful_tensor = step_list[i]
            # pin the current tensor so pop() cannot evict it
            stateful_tensor.trans_state(TensorState.COMPUTE)
            st_out = heap_container.pop()

            # heap pop may find nothing evictable (step 3 in this schedule)
            if st_out is not None:
                res_list.append(st_out.payload.size(0))
                st_out.move_to(torch.device('cpu'))

            # bring the computed tensor back to CUDA and re-push it with
            # the current step so its weight reflects its next use
            stateful_tensor.move_to(torch.device('cuda'))
            heap_container.push(stateful_tensor, i)
            stateful_tensor.trans_state(TensorState.HOLD)

        # only 5 entries: one pop returned None
        assert res_list == [3, 1, 2, 3, 2]

    run_heap_test()
if __name__ == '__main__':
    # allow running this test directly without pytest
    test_stateful_tensor_container()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment