norm / vllm · Commits · 0f4b3219

Unverified commit 0f4b3219, authored Apr 15, 2023 by Woosuk Kwon, committed by GitHub on Apr 15, 2023.
Support various block sizes & Change default block size to 16 (#38)
Parent: 84eee24e
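Context for the change: the block size is the number of tokens whose key/value cache entries live in a single block, so it determines how many blocks a sequence occupies. A minimal sketch of that relationship (hypothetical helper, not code from this commit):

# Sketch only: blocks needed for a sequence at a given token block size.
def num_required_blocks(context_len: int, block_size: int = 16) -> int:
    # A partially filled block still occupies a whole block.
    return (context_len + block_size - 1) // block_size

# A 1000-token context needs 125 blocks at the old default of 8,
# but only 63 larger blocks at the new default of 16.
assert num_required_blocks(1000, block_size=8) == 125
assert num_required_blocks(1000, block_size=16) == 63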
Showing 7 changed files with 602 additions and 619 deletions.
benchmark/benchmark_text_completion.py   +1    -0
cacheflow/master/block_manager.py        +0    -3
cacheflow/master/scheduler.py            +2    -1
cacheflow/master/server.py               +2    -2
csrc/attention.cpp                       +0    -16
csrc/attention_kernels.cu                +557  -579
csrc/cuda_primitives.h                   +40   -18
benchmark/benchmark_text_completion.py
@@ -268,6 +268,7 @@ if __name__ == '__main__':
         f'{model_name}-tp{args.tensor_parallel_size}',
         sample_dir,
         'cacheflow',
+        f'block{args.block_size}',
         f'req-rate-{args.request_rate}',
         f'seed{args.seed}',
         f'duration-{args.duration}',
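The added path component records the block size in the benchmark's output directory. As a rough illustration, assuming the surrounding code joins these components with os.path.join (not visible in this hunk), the resulting layout would look like this, with all concrete values hypothetical:

import os

model_name = 'opt-13b'        # hypothetical
sample_dir = 'sample-dir'     # hypothetical

class args:                   # stand-in for the parsed argparse namespace
    tensor_parallel_size = 1
    block_size = 16
    request_rate = 1.0
    seed = 0
    duration = 3600

path = os.path.join(
    f'{model_name}-tp{args.tensor_parallel_size}',
    sample_dir,
    'cacheflow',
    f'block{args.block_size}',  # component added by this commit
    f'req-rate-{args.request_rate}',
    f'seed{args.seed}',
    f'duration-{args.duration}',
)
print(path)  # opt-13b-tp1/sample-dir/cacheflow/block16/req-rate-1.0/seed0/duration-3600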
cacheflow/master/block_manager.py
@@ -15,9 +15,6 @@ class BlockAllocator:
         block_size: int,
         num_blocks: int,
     ) -> None:
-        if block_size not in [8, 16, 32]:
-            raise ValueError(f'Unsupported block size: {block_size}'
-                             'The block size must be one of {8, 16, 32}.')
         self.device = device
         self.block_size = block_size
         self.num_blocks = num_blocks
cacheflow/master/scheduler.py
@@ -125,7 +125,8 @@ class Scheduler:
         # Swap in the sequence groups in the SWAPPED state if possible.
         self.swapped = self.policy.sort_by_priority(now, self.swapped)
-        while self.swapped:
+        # FCFS
+        while self.swapped and not blocks_to_swap_out:
             seq_group = self.swapped[0]
             # If the sequence group has been preempted in this step, stop.
             if seq_group in preempted:
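The tightened loop condition means that once any swap-out has been scheduled in this step, no further sequence groups are swapped in. A small self-contained sketch of that FCFS swap-in guard (hypothetical helper, not the actual Scheduler API):

from collections import deque

def plan_swap_in(swapped: deque, preempted: set, blocks_to_swap_out: dict) -> list:
    # Swap groups back in FCFS order, but only while no swap-out is pending
    # for this scheduling step, mirroring the guard added above.
    swapped_in = []
    while swapped and not blocks_to_swap_out:
        seq_group = swapped[0]
        # If the sequence group has been preempted in this step, stop.
        if seq_group in preempted:
            break
        swapped_in.append(swapped.popleft())
    return swapped_in

# Nothing is swapped in when a swap-out is already pending.
assert plan_swap_in(deque(['g1', 'g2']), set(), {'gpu_block': 0}) == []
# Otherwise, groups are taken from the front until a just-preempted group is hit.
assert plan_swap_in(deque(['g1', 'g2']), {'g2'}, {}) == ['g1']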
cacheflow/master/server.py
@@ -180,9 +180,9 @@ def add_server_arguments(parser: argparse.ArgumentParser):
     parser.add_argument('--pipeline-parallel-size', '-pp', type=int, default=1, help='number of pipeline stages')
     parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1, help='number of tensor parallel replicas')
     # KV cache arguments
-    parser.add_argument('--block-size', type=int, default=8, choices=[8, 16, 32], help='token block size')
+    parser.add_argument('--block-size', type=int, default=16, choices=[1, 2, 4, 8, 16, 32, 64, 128, 256], help='token block size')
     # NOTE(woosuk): If FlashAttention is used, the float data type is not supported.
-    parser.add_argument('--dtype', type=str, default='half', choices=['half', 'float'], help='data type')
+    parser.add_argument('--dtype', type=str, default='half', choices=['half'], help='data type')
     # TODO(woosuk): Support fine-grained seeds (e.g., seed per request).
     parser.add_argument('--seed', type=int, default=0, help='random seed')
     parser.add_argument('--swap-space', type=int, default=20, help='CPU swap space size (GiB) per GPU')
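With this change the set of valid block sizes is enforced at argument-parsing time through choices, which is why the explicit range check could be dropped from BlockAllocator above. A hedged sketch of that behavior, using a standalone parser rather than the real add_server_arguments:

import argparse

# Standalone reconstruction of the new --block-size option, for illustration only.
parser = argparse.ArgumentParser()
parser.add_argument('--block-size', type=int, default=16,
                    choices=[1, 2, 4, 8, 16, 32, 64, 128, 256],
                    help='token block size')

print(parser.parse_args([]).block_size)                       # 16 (new default)
print(parser.parse_args(['--block-size', '32']).block_size)   # 32

try:
    parser.parse_args(['--block-size', '24'])                 # not a valid choice
except SystemExit:
    print('rejected at parse time; no runtime ValueError needed')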
csrc/attention.cpp
@@ -11,25 +11,9 @@ void single_query_cached_kv_attention(
   int block_size,
   int max_context_len);

-void multi_query_cached_kv_attention(
-  torch::Tensor& cu_query_lens,
-  torch::Tensor& out,
-  torch::Tensor& query,
-  torch::Tensor& key_cache,
-  torch::Tensor& value_cache,
-  float scale,
-  torch::Tensor& block_tables,
-  torch::Tensor& context_lens,
-  int block_size,
-  int max_context_len);
-
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def(
     "single_query_cached_kv_attention",
     &single_query_cached_kv_attention,
     "Compute the attention between an input query and the cached key/value tensors");
-  m.def(
-    "multi_query_cached_kv_attention",
-    &multi_query_cached_kv_attention,
-    "Compute the attention between multiple input queries and the cached key/value tensors");
 }
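After this hunk only the single-query kernel is declared and bound; the multi-query entry point is removed. A quick hedged smoke check of the resulting Python surface (the module name attention_ops is an assumption here; the real name is whatever TORCH_EXTENSION_NAME resolves to at build time):

import attention_ops  # hypothetical name for the built extension

assert hasattr(attention_ops, 'single_query_cached_kv_attention')
assert not hasattr(attention_ops, 'multi_query_cached_kv_attention')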
csrc/attention_kernels.cu
(This diff is collapsed and not shown here.)
csrc/cuda_primitives.h
@@ -1074,6 +1074,21 @@ inline __device__ float sum(Float8_ v)
 ////////////////////////////////////////////////////////////////////////////////////////////////////
+inline __device__ float dot(float a, float b)
+{
+  return a * b;
+}
+////////////////////////////////////////////////////////////////////////////////////////////////////
+inline __device__ float dot(float2 a, float2 b)
+{
+  float2 c = mul<float2, float2, float2>(a, b);
+  return c.x + c.y;
+}
+////////////////////////////////////////////////////////////////////////////////////////////////////
+inline __device__ float dot(Float4_ a, Float4_ b)
+{
+  float2 acc = mul<float2, float2, float2>(a.x, b.x);
...
@@ -1253,37 +1268,44 @@ inline __device__ float convert_to_float(uint4 u)
 ////////////////////////////////////////////////////////////////////////////////////////////////////
-inline __device__ float cast_to_float(float u)
-{
-  return u;
-}
+// inline __device__ float cast_to_float(float u)
+// {
+//   return u;
+// }
 ////////////////////////////////////////////////////////////////////////////////////////////////////
-inline __device__ float2 cast_to_float(float2 u)
-{
-  return u;
-}
+// inline __device__ float2 cast_to_float(float2 u)
+// {
+//   return u;
+// }
 ////////////////////////////////////////////////////////////////////////////////////////////////////
-inline __device__ float4 cast_to_float(float4 u)
-{
-  return u;
-}
+// inline __device__ float4 cast_to_float(float4 u)
+// {
+//   return u;
+// }
 ////////////////////////////////////////////////////////////////////////////////////////////////////
-inline __device__ Float4_ cast_to_float(Float4_ u)
-{
-  return u;
-}
+// inline __device__ Float4_ cast_to_float(Float4_ u)
+// {
+//   return u;
+// }
 ////////////////////////////////////////////////////////////////////////////////////////////////////
+// inline __device__ Float8_ cast_to_float(Float8_ u)
+// {
+//   return u;
+// }
 ////////////////////////////////////////////////////////////////////////////////////////////////////
-inline __device__ Float8_ cast_to_float(Float8_ u)
+inline __device__ float cast_to_float(uint16_t u)
 {
-  return u;
+  return half_to_float(u);
 }
 ////////////////////////////////////////////////////////////////////////////////////////////////////
...
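The new cast_to_float(uint16_t) overload interprets the 16-bit value as a half-precision bit pattern and widens it to fp32 via half_to_float. A host-side illustration of that conversion (NumPy stand-in for the device intrinsic, not the CUDA code itself):

import numpy as np

bits = np.array([0x3C00], dtype=np.uint16)        # fp16 bit pattern for 1.0
value = bits.view(np.float16).astype(np.float32)[0]
print(value)                                      # 1.0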