Project: norm/vllm
Commit: cfae35b8 (unverified)
Title: Add miscellaneous updates (#8)
Authored: Mar 13, 2023 by Woosuk Kwon; committed via GitHub, Mar 13, 2023
Parent: e9d3f2ff
Showing 7 changed files with 44 additions and 22 deletions (+44, -22).
cacheflow/master/scheduler.py        +8   -7
cacheflow/models/attention.py        +5   -6
cacheflow/models/memory_analyzer.py  +14  -4
cacheflow/models/sample.py           +1   -1
cacheflow/worker/worker.py           +7   -0
csrc/cache_kernels.cu                +5   -2
server.py                            +4   -2
cacheflow/master/scheduler.py
@@ -158,8 +158,8 @@ class Scheduler:
         # 3. Join new sequences if possible.
         # NOTE: Here we implicitly assume FCFS scheduling.
         # TODO(woosuk): Add a batching policy to control the batch size.
-        self._fetch_inputs()
         if not self.swapped:
+            self._fetch_inputs()
             for i, seq_group in enumerate(self.pending):
                 num_prompt_tokens = seq_group.seqs[0].get_len()
                 if self.block_manager.can_allocate(seq_group):
@@ -211,12 +211,13 @@ class Scheduler:
             input_seq_groups.append(input_seq_group)

         # 5. Execute the first stage of the pipeline.
-        self.controllers[0].execute_stage(
-            input_seq_groups,
-            blocks_to_swap_in,
-            blocks_to_swap_out,
-            blocks_to_copy,
-        )
+        if (input_seq_groups or blocks_to_swap_in or blocks_to_swap_out):
+            self.controllers[0].execute_stage(
+                input_seq_groups,
+                blocks_to_swap_in,
+                blocks_to_swap_out,
+                blocks_to_copy,
+            )

     def post_step(
         self,
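
The second hunk's guard is worth spelling out: the scheduler now skips dispatching to the first pipeline stage entirely when there are no sequence groups to run and no blocks to move. A minimal sketch of the pattern, using stand-in names rather than the actual cacheflow classes:

    # A minimal sketch of the guarded dispatch introduced above; `Controller`
    # and `step` are illustrative stand-ins, not the cacheflow classes.
    class Controller:
        def execute_stage(self, input_seq_groups, swap_in, swap_out, copy):
            print(f'executing stage with {len(input_seq_groups)} groups')

    def step(controller, input_seq_groups, blocks_to_swap_in,
             blocks_to_swap_out, blocks_to_copy):
        # Dispatch to the first pipeline stage only when there is work:
        # sequences to run, or cache blocks to swap in or out.
        if input_seq_groups or blocks_to_swap_in or blocks_to_swap_out:
            controller.execute_stage(input_seq_groups, blocks_to_swap_in,
                                     blocks_to_swap_out, blocks_to_copy)

    step(Controller(), [], [], [], [])          # no-op: nothing printed
    step(Controller(), [object()], [], [], [])  # executes the stage
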
cacheflow/models/attention.py
@@ -12,7 +12,7 @@ from cacheflow.models import InputMetadata
 class OPTCacheFlowAttention(nn.Module):

     def __init__(self, scale: float) -> None:
-        super().__init__()
+        super(OPTCacheFlowAttention, self).__init__()
         self.scale = float(scale)

         self.flash_attn = FlashAttention(softmax_scale=self.scale)
@@ -106,8 +106,8 @@ class OPTCacheFlowAttention(nn.Module):
         output = output.view(-1, num_heads, head_size)

         # Compute the attention op for prompts.
-        if input_metadata.num_prompts > 0:
-            num_prompt_tokens = input_metadata.num_prompt_tokens
+        num_prompt_tokens = sum(input_metadata.prompt_lens)
+        if num_prompt_tokens > 0:
             self.multi_query_kv_attention(
                 output[:num_prompt_tokens],
                 query[:num_prompt_tokens],
@@ -126,10 +126,9 @@ class OPTCacheFlowAttention(nn.Module):
         if input_metadata.num_generation_tokens > 0:
             # Compute the attention op for generation tokens.
-            start_idx = sum(input_metadata.prompt_lens)
             self.single_query_cached_kv_attention(
-                output[start_idx:],
-                query[start_idx:],
+                output[num_prompt_tokens:],
+                query[num_prompt_tokens:],
                 key_cache,
                 value_cache,
                 input_metadata)
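
The change replaces per-metadata prompt bookkeeping (`num_prompts`, `num_prompt_tokens`) with a single count derived from `prompt_lens`, then reuses that count to slice off the generation tokens. A small sketch of the flattened-batch layout this indexing assumes, with plain lists standing in for the real tensors:

    # Sketch of the token layout: all prompt tokens come first in the
    # flattened batch, generation (decode) tokens follow.
    prompt_lens = [3, 5]              # two prompts of 3 and 5 tokens
    num_generation_tokens = 4         # four sequences in decode phase

    num_prompt_tokens = sum(prompt_lens)                     # 8
    tokens = list(range(num_prompt_tokens + num_generation_tokens))

    prompt_part = tokens[:num_prompt_tokens]      # -> multi_query_kv_attention
    generation_part = tokens[num_prompt_tokens:]  # -> single_query_cached_kv_attention
    assert len(generation_part) == num_generation_tokens
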
cacheflow/models/memory_analyzer.py
@@ -5,7 +5,7 @@ from cacheflow.models.utils import get_cpu_memory
 from cacheflow.models.utils import get_dtype_size
 from cacheflow.models.utils import get_gpu_memory

 _GiB = 1 << 30


 class CacheFlowMemoryAnalyzer:
@@ -117,9 +117,19 @@ class OPTMemoryAnalyzer(CacheFlowMemoryAnalyzer):
     def get_max_num_cpu_blocks(
         self,
-        memory_utilization: float = 0.25,
+        swap_space: int,
     ) -> int:
+        swap_space = swap_space * _GiB
         cpu_memory = get_cpu_memory()
-        usable_memory = int(memory_utilization * cpu_memory)
-        max_num_blocks = usable_memory // self._get_cache_block_size()
+        if swap_space > 0.8 * cpu_memory:
+            raise ValueError(
+                f'The swap space ({swap_space / _GiB:.2f} GiB) '
+                'takes more than 80% of the available memory '
+                f'({cpu_memory / _GiB:.2f} GiB).'
+                'Please check the swap space size.')
+        if swap_space > 0.5 * cpu_memory:
+            print(f'WARNING: The swap space ({swap_space / _GiB:.2f} GiB) '
+                  'takes more than 50% of the available memory '
+                  f'({cpu_memory / _GiB:.2f} GiB).'
+                  'This may slow the system performance.')
+        max_num_blocks = swap_space // self._get_cache_block_size()
         return max_num_blocks
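
With the new signature, the CPU block budget is driven directly by the user-specified swap space rather than a fixed 25% of host memory, with a hard limit at 80% of CPU memory and a warning at 50%. A back-of-the-envelope sketch; the host memory and cache block size below are made-up figures, not values the analyzer actually computes:

    _GiB = 1 << 30

    swap_space = 20 * _GiB         # --swap-space 20
    cpu_memory = 64 * _GiB         # assume a 64 GiB host
    cache_block_size = 131072      # hypothetical bytes per KV-cache block

    assert swap_space <= 0.8 * cpu_memory   # beyond 80%: ValueError
    if swap_space > 0.5 * cpu_memory:       # beyond 50%: warning only
        print('WARNING: swap space takes more than half of CPU memory')

    max_num_blocks = swap_space // cache_block_size
    print(max_num_blocks)          # 163840 blocks for these figures
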
cacheflow/models/sample.py
@@ -11,7 +11,7 @@ from cacheflow.sequence import SequenceOutputs
 class Sampler(nn.Module):

     def __init__(self) -> None:
-        super().__init__()
+        super(Sampler, self).__init__()

     def forward(
         self,
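
For context, the two spellings are equivalent in Python 3; the explicit form merely names the class and instance:

    import torch.nn as nn

    class Sampler(nn.Module):
        def __init__(self) -> None:
            # Same effect as the zero-argument super().__init__()
            # it replaces.
            super(Sampler, self).__init__()

    sampler = Sampler()   # nn.Module state is initialized as before
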
cacheflow/worker/worker.py
@@ -191,6 +191,13 @@ class Worker:
         else:
             cache_events = None

+        # If there is no input, we don't need to execute the model.
+        if not input_seq_groups:
+            if cache_events is not None:
+                for event in cache_events:
+                    event.wait()
+            return {}
+
         # Prepare input tensors.
         input_tokens, input_positions, input_metadata = self.prepare_inputs(
             input_seq_groups)
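
Note that the early exit still waits on any in-flight cache operations before returning, so a later step cannot observe half-moved blocks. A self-contained sketch of the control flow, where `FakeEvent` stands in for the CUDA events the worker actually waits on:

    class FakeEvent:
        def wait(self):
            print('waiting for a cache operation to finish')

    def execute_step(input_seq_groups, cache_events):
        # With no input, skip model execution, but still drain pending
        # swap/copy events before returning an empty result.
        if not input_seq_groups:
            if cache_events is not None:
                for event in cache_events:
                    event.wait()
            return {}
        # ... prepare inputs and run the model ...

    print(execute_step([], [FakeEvent()]))   # waits, then prints {}
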
csrc/cache_kernels.cu
 #include <torch/extension.h>
 #include <ATen/cuda/CUDAContext.h>
 #include <algorithm>
@@ -73,6 +72,8 @@ void copy_blocks(
   }
 }

+namespace cacheflow {
+
 template<typename scalar_t>
 __global__ void reshape_and_cache_kernel(
   const scalar_t* __restrict__ key,     // [num_tokens, num_heads, head_size]
@@ -112,6 +113,8 @@ __global__ void reshape_and_cache_kernel(
   }
 }

+} // namespace cacheflow
+
 void reshape_and_cache(
   torch::Tensor& key,
   torch::Tensor& value,
@@ -131,7 +134,7 @@ void reshape_and_cache(
     key.scalar_type(),
     "reshape_and_cache_kernel",
     [&] {
-      reshape_and_cache_kernel<scalar_t><<<grid, block, 0, stream>>>(
+      cacheflow::reshape_and_cache_kernel<scalar_t><<<grid, block, 0, stream>>>(
         key.data_ptr<scalar_t>(),
         value.data_ptr<scalar_t>(),
         key_cache.data_ptr<scalar_t>(),
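
Wrapping the kernel in `namespace cacheflow` keeps its symbol out of the global namespace, at the cost of qualifying it at the launch site. A minimal standalone illustration of the same pattern in CUDA (a toy kernel, not the cacheflow one):

    // Define the kernel inside a namespace, qualify it at launch.
    #include <cstdio>

    namespace cacheflow {

    __global__ void demo_kernel(int* out) {
      out[threadIdx.x] = threadIdx.x;
    }

    }  // namespace cacheflow

    int main() {
      int* out;
      cudaMallocManaged(&out, 4 * sizeof(int));
      cacheflow::demo_kernel<<<1, 4>>>(out);   // namespace-qualified launch
      cudaDeviceSynchronize();
      printf("%d %d %d %d\n", out[0], out[1], out[2], out[3]);
      cudaFree(out);
      return 0;
    }
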
server.py
@@ -15,7 +15,8 @@ parser.add_argument('--block-size', type=int, default=8, choices=[8, 16], help='
 parser.add_argument('--dtype', type=str, default='half', choices=['half', 'float'], help='data type')
 # TODO(woosuk): Support fine-grained seeds (e.g., seed per request).
 parser.add_argument('--seed', type=int, default=0, help='random seed')
-parser.add_argument('--max-batch-size', type=int, default=2048, help='maximum number of batched tokens')
+parser.add_argument('--swap-space', type=int, default=20, help='CPU swap space size (GiB) per GPU')
+parser.add_argument('--max-batch-size', type=int, default=2560, help='maximum number of batched tokens')
 args = parser.parse_args()
@@ -27,7 +28,8 @@ def main():
     )
     num_gpu_blocks = memory_analyzer.get_max_num_gpu_blocks(
         max_num_batched_tokens=args.max_batch_size)
-    num_cpu_blocks = memory_analyzer.get_max_num_cpu_blocks()
+    num_cpu_blocks = memory_analyzer.get_max_num_cpu_blocks(
+        swap_space=args.swap_space)
     print(f'# GPU blocks: {num_gpu_blocks}, # CPU blocks: {num_cpu_blocks}')

     # Create a controller for each node.
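
The new `--swap-space` flag is what feeds `get_max_num_cpu_blocks` above, and the default batch-token budget grows from 2048 to 2560. A standalone sketch of just the two updated flags:

    # Defaults mirror the diff: swap space in GiB per GPU, and the
    # raised batch-token budget.
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--swap-space', type=int, default=20,
                        help='CPU swap space size (GiB) per GPU')
    parser.add_argument('--max-batch-size', type=int, default=2560,
                        help='maximum number of batched tokens')
    args = parser.parse_args([])   # pass e.g. ['--swap-space', '8'] to override

    print(args.swap_space, args.max_batch_size)   # 20 2560
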