Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
norm
vllm
Commits
b9926f7f
"tests/nn/pipe_process/test_inplace.py" did not exist on "0cd65242a0e43c60251abb3b631411e5ea5b6b86"
Unverified
Commit
b9926f7f
authored
Apr 09, 2023
by
Woosuk Kwon
Committed by
GitHub
Apr 09, 2023
Browse files
Support block size 32 (#35)
parent
ee88a7e5
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
49 additions
and
5 deletions
+49
-5
cacheflow/master/block_manager.py
cacheflow/master/block_manager.py
+2
-2
cacheflow/master/server.py
cacheflow/master/server.py
+1
-1
csrc/attention_kernels.cu
csrc/attention_kernels.cu
+44
-0
tests/kernels/attention.py
tests/kernels/attention.py
+2
-2
No files found.
cacheflow/master/block_manager.py
View file @
b9926f7f
...
...
@@ -15,9 +15,9 @@ class BlockAllocator:
block_size
:
int
,
num_blocks
:
int
,
)
->
None
:
if
block_size
not
in
[
8
,
16
]:
if
block_size
not
in
[
8
,
16
,
32
]:
raise
ValueError
(
f
'Unsupported block size:
{
block_size
}
'
'The block size must be
either 8 or 16
.'
)
'The block size must be
one of {8, 16, 32}
.'
)
self
.
device
=
device
self
.
block_size
=
block_size
self
.
num_blocks
=
num_blocks
...
...
cacheflow/master/server.py
View file @
b9926f7f
...
...
@@ -174,7 +174,7 @@ def add_server_arguments(parser: argparse.ArgumentParser):
parser
.
add_argument
(
'--pipeline-parallel-size'
,
'-pp'
,
type
=
int
,
default
=
1
,
help
=
'number of pipeline stages'
)
parser
.
add_argument
(
'--tensor-parallel-size'
,
'-tp'
,
type
=
int
,
default
=
1
,
help
=
'number of tensor parallel replicas'
)
# KV cache arguments
parser
.
add_argument
(
'--block-size'
,
type
=
int
,
default
=
8
,
choices
=
[
8
,
16
],
help
=
'token block size'
)
parser
.
add_argument
(
'--block-size'
,
type
=
int
,
default
=
8
,
choices
=
[
8
,
16
,
32
],
help
=
'token block size'
)
# NOTE(woosuk): If FlashAttention is used, the float data type is not supported.
parser
.
add_argument
(
'--dtype'
,
type
=
str
,
default
=
'half'
,
choices
=
[
'half'
,
'float'
],
help
=
'data type'
)
# TODO(woosuk): Support fine-grained seeds (e.g., seed per request).
...
...
csrc/attention_kernels.cu
View file @
b9926f7f
...
...
@@ -654,6 +654,16 @@ void single_query_cached_kv_attention(
block_tables
,
context_lens
,
max_context_len
);
}
else
if
(
block_size
==
32
)
{
single_query_cached_kv_attention_launcher
<
uint16_t
,
32
>
(
out
,
query
,
key_cache
,
value_cache
,
scale
,
block_tables
,
context_lens
,
max_context_len
);
}
else
{
assert
(
false
);
}
...
...
@@ -679,6 +689,16 @@ void single_query_cached_kv_attention(
block_tables
,
context_lens
,
max_context_len
);
}
else
if
(
block_size
==
32
)
{
single_query_cached_kv_attention_launcher
<
float
,
32
>
(
out
,
query
,
key_cache
,
value_cache
,
scale
,
block_tables
,
context_lens
,
max_context_len
);
}
else
{
assert
(
false
);
}
...
...
@@ -834,6 +854,18 @@ void multi_query_cached_kv_attention(
block_tables
,
context_lens
,
max_context_len
);
}
else
if
(
block_size
==
32
)
{
multi_query_cached_kv_attention_launcher
<
uint16_t
,
32
>
(
cu_query_lens
,
seq_prompt_mapping
,
out
,
query
,
key_cache
,
value_cache
,
scale
,
block_tables
,
context_lens
,
max_context_len
);
}
else
{
assert
(
false
);
}
...
...
@@ -863,6 +895,18 @@ void multi_query_cached_kv_attention(
block_tables
,
context_lens
,
max_context_len
);
}
else
if
(
block_size
==
32
)
{
multi_query_cached_kv_attention_launcher
<
float
,
32
>
(
cu_query_lens
,
seq_prompt_mapping
,
out
,
query
,
key_cache
,
value_cache
,
scale
,
block_tables
,
context_lens
,
max_context_len
);
}
else
{
assert
(
false
);
}
...
...
tests/kernels/attention.py
View file @
b9926f7f
...
...
@@ -350,7 +350,7 @@ def test_attention(seed: int) -> None:
torch
.
random
.
manual_seed
(
seed
)
torch
.
cuda
.
manual_seed
(
seed
)
for
dtype
in
[
torch
.
half
,
torch
.
float
]:
for
block_size
in
[
8
,
16
]:
for
block_size
in
[
8
,
16
,
32
]:
for
head_size
in
[
32
,
64
,
80
,
96
,
128
,
160
,
192
,
256
]:
print
(
f
'Testing single_query_cached_kv_attention with '
f
'dtype=
{
dtype
}
, block_size=
{
block_size
}
, '
...
...
@@ -368,7 +368,7 @@ def test_attention(seed: int) -> None:
# note that the test is also more likely to fail due to the much
# larger amount of tokens in the input may increase the variance.
for
dtype
in
[
torch
.
half
,
torch
.
float
]:
for
block_size
in
[
8
,
16
]:
for
block_size
in
[
8
,
16
,
32
]:
for
head_size
in
[
32
,
64
,
80
,
96
,
128
,
160
,
192
,
256
]:
print
(
f
'Testing multi_query_cached_kv_attention with '
f
'dtype=
{
dtype
}
, block_size=
{
block_size
}
, '
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment