Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
norm
vllm
Commits
b9926f7f
Unverified
Commit
b9926f7f
authored
Apr 09, 2023
by
Woosuk Kwon
Committed by
GitHub
Apr 09, 2023
Browse files
Support block size 32 (#35)
parent
ee88a7e5
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
49 additions
and
5 deletions
+49
-5
cacheflow/master/block_manager.py
cacheflow/master/block_manager.py
+2
-2
cacheflow/master/server.py
cacheflow/master/server.py
+1
-1
csrc/attention_kernels.cu
csrc/attention_kernels.cu
+44
-0
tests/kernels/attention.py
tests/kernels/attention.py
+2
-2
No files found.
cacheflow/master/block_manager.py
View file @
b9926f7f
...
...
@@ -15,9 +15,9 @@ class BlockAllocator:
block_size
:
int
,
num_blocks
:
int
,
)
->
None
:
if
block_size
not
in
[
8
,
16
]:
if
block_size
not
in
[
8
,
16
,
32
]:
raise
ValueError
(
f
'Unsupported block size:
{
block_size
}
'
'The block size must be
either 8 or 16
.'
)
'The block size must be
one of {8, 16, 32}
.'
)
self
.
device
=
device
self
.
block_size
=
block_size
self
.
num_blocks
=
num_blocks
...
...
cacheflow/master/server.py
View file @
b9926f7f
...
...
@@ -174,7 +174,7 @@ def add_server_arguments(parser: argparse.ArgumentParser):
parser
.
add_argument
(
'--pipeline-parallel-size'
,
'-pp'
,
type
=
int
,
default
=
1
,
help
=
'number of pipeline stages'
)
parser
.
add_argument
(
'--tensor-parallel-size'
,
'-tp'
,
type
=
int
,
default
=
1
,
help
=
'number of tensor parallel replicas'
)
# KV cache arguments
parser
.
add_argument
(
'--block-size'
,
type
=
int
,
default
=
8
,
choices
=
[
8
,
16
],
help
=
'token block size'
)
parser
.
add_argument
(
'--block-size'
,
type
=
int
,
default
=
8
,
choices
=
[
8
,
16
,
32
],
help
=
'token block size'
)
# NOTE(woosuk): If FlashAttention is used, the float data type is not supported.
parser
.
add_argument
(
'--dtype'
,
type
=
str
,
default
=
'half'
,
choices
=
[
'half'
,
'float'
],
help
=
'data type'
)
# TODO(woosuk): Support fine-grained seeds (e.g., seed per request).
...
...
csrc/attention_kernels.cu
View file @
b9926f7f
...
...
@@ -654,6 +654,16 @@ void single_query_cached_kv_attention(
block_tables
,
context_lens
,
max_context_len
);
}
else
if
(
block_size
==
32
)
{
single_query_cached_kv_attention_launcher
<
uint16_t
,
32
>
(
out
,
query
,
key_cache
,
value_cache
,
scale
,
block_tables
,
context_lens
,
max_context_len
);
}
else
{
assert
(
false
);
}
...
...
@@ -679,6 +689,16 @@ void single_query_cached_kv_attention(
block_tables
,
context_lens
,
max_context_len
);
}
else
if
(
block_size
==
32
)
{
single_query_cached_kv_attention_launcher
<
float
,
32
>
(
out
,
query
,
key_cache
,
value_cache
,
scale
,
block_tables
,
context_lens
,
max_context_len
);
}
else
{
assert
(
false
);
}
...
...
@@ -834,6 +854,18 @@ void multi_query_cached_kv_attention(
block_tables
,
context_lens
,
max_context_len
);
}
else
if
(
block_size
==
32
)
{
multi_query_cached_kv_attention_launcher
<
uint16_t
,
32
>
(
cu_query_lens
,
seq_prompt_mapping
,
out
,
query
,
key_cache
,
value_cache
,
scale
,
block_tables
,
context_lens
,
max_context_len
);
}
else
{
assert
(
false
);
}
...
...
@@ -863,6 +895,18 @@ void multi_query_cached_kv_attention(
block_tables
,
context_lens
,
max_context_len
);
}
else
if
(
block_size
==
32
)
{
multi_query_cached_kv_attention_launcher
<
float
,
32
>
(
cu_query_lens
,
seq_prompt_mapping
,
out
,
query
,
key_cache
,
value_cache
,
scale
,
block_tables
,
context_lens
,
max_context_len
);
}
else
{
assert
(
false
);
}
...
...
tests/kernels/attention.py
View file @
b9926f7f
...
...
@@ -350,7 +350,7 @@ def test_attention(seed: int) -> None:
torch
.
random
.
manual_seed
(
seed
)
torch
.
cuda
.
manual_seed
(
seed
)
for
dtype
in
[
torch
.
half
,
torch
.
float
]:
for
block_size
in
[
8
,
16
]:
for
block_size
in
[
8
,
16
,
32
]:
for
head_size
in
[
32
,
64
,
80
,
96
,
128
,
160
,
192
,
256
]:
print
(
f
'Testing single_query_cached_kv_attention with '
f
'dtype=
{
dtype
}
, block_size=
{
block_size
}
, '
...
...
@@ -368,7 +368,7 @@ def test_attention(seed: int) -> None:
# note that the test is also more likely to fail due to the much
# larger amount of tokens in the input may increase the variance.
for
dtype
in
[
torch
.
half
,
torch
.
float
]:
for
block_size
in
[
8
,
16
]:
for
block_size
in
[
8
,
16
,
32
]:
for
head_size
in
[
32
,
64
,
80
,
96
,
128
,
160
,
192
,
256
]:
print
(
f
'Testing multi_query_cached_kv_attention with '
f
'dtype=
{
dtype
}
, block_size=
{
block_size
}
, '
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment