norm / vllm · Commits · 211318d4

Add throughput benchmarking script (#133)

Commit 211318d4 (unverified), authored May 28, 2023 by Woosuk Kwon; committed via GitHub on May 28, 2023.
Parent: 337871c6

12 changed files with 145 additions and 257 deletions (+145, -257).
benchmark/benchmark_attention.py            +0    -165
benchmark/benchmark_cache.py                +0    -81
benchmarks/README.md                        +8    -0
benchmarks/benchmark_latency.py             +0    -0
benchmarks/benchmark_text_completion.py     +0    -0
benchmarks/benchmark_throughput.py          +104  -0
benchmarks/trace.py                         +0    -0
cacheflow/__init__.py                       +2    -1
cacheflow/core/scheduler.py                 +3    -0
cacheflow/entrypoints/llm.py                +24   -8
cacheflow/server/llm_server.py              +3    -0
examples/simple_server.py                   +1    -2
benchmark/benchmark_attention.py (deleted, 100644 → 0)
```python
import functools
import random
import time
from typing import List

from flash_attn.flash_attn_interface import _flash_attn_forward
import torch

from cacheflow import attention_ops


def benchmark(name, f, num_warmup=10, num_iters=100):
    for _ in range(num_warmup):
        f()
    torch.cuda.synchronize()
    start = time.time()
    for _ in range(num_iters):
        f()
    torch.cuda.synchronize()
    end = time.time()
    print(f'{name}: {(end - start) / num_iters * 1000:.3f} ms')


@torch.inference_mode()
def benchmark_multi_query_cached_kv_attention(
    query_lens: List[int],
    context_lens: List[int],
    num_heads: int,
    head_size: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
) -> None:
    print(f'query_lens: {query_lens}, context_lens: {context_lens}, '
          f'num_heads: {num_heads}, head_size: {head_size}, block_size: '
          f'{block_size}, num_blocks: {num_blocks}, dtype: {dtype}')

    # Create query tensor.
    num_queries = len(query_lens)
    cu_query_lens = [0]
    for query_len in query_lens:
        cu_query_lens.append(cu_query_lens[-1] + query_len)
    num_total_tokens = cu_query_lens[-1]
    qkv = torch.randn(
        num_total_tokens, 3, num_heads, head_size, dtype=dtype, device='cuda')
    query, _, _ = qkv.unbind(dim=1)

    # Create key and value cache.
    x = 16 // torch.tensor([], dtype=dtype).element_size()
    key_block_shape = (num_heads, head_size // x, block_size, x)
    key_cache = torch.randn(
        size=(num_blocks, *key_block_shape), dtype=dtype, device='cuda')
    value_block_shape = (num_heads, head_size, block_size)
    value_cache = torch.randn(
        size=(num_blocks, *value_block_shape), dtype=dtype, device='cuda')

    # Create block tables.
    max_context_len = max(context_lens)
    max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size
    block_tables = []
    for _ in range(num_queries):
        block_table = [
            random.randint(0, num_blocks - 1)
            for _ in range(max_num_blocks_per_seq)
        ]
        block_tables.append(block_table)
    block_tables = torch.tensor(block_tables, dtype=torch.int, device='cuda')

    # Create input and output data structures.
    cu_query_lens = torch.tensor(cu_query_lens, dtype=torch.int, device='cuda')
    context_len_tensor = torch.tensor(
        context_lens, dtype=torch.int, device='cuda')
    scale = float(1.0 / (head_size ** 0.5))
    output = torch.empty(
        num_total_tokens, num_heads, head_size, dtype=dtype, device='cuda')

    # Run our implementation.
    def run_ours():
        attention_ops.multi_query_cached_kv_attention(
            cu_query_lens,
            output,
            query,
            key_cache,
            value_cache,
            scale,
            block_tables,
            context_len_tensor,
            block_size,
            max_context_len,
        )
    benchmark('Ours', run_ours)

    # Upper bound: Flash attention.
    # Because Flash attention cannot read our own cache,
    # we make key and value tensors contiguous.
    num_kv_tokens = sum(context_lens)
    cu_context_lens = [0]
    for context_len in context_lens:
        cu_context_lens.append(cu_context_lens[-1] + context_len)
    cu_context_lens = torch.tensor(
        cu_context_lens, dtype=torch.int, device='cuda')
    qkv = torch.randn(
        num_kv_tokens, 3, num_heads, head_size, dtype=dtype, device='cuda')
    _, key, value = qkv.unbind(dim=1)
    ref_output = torch.empty_like(output)

    # Run Flash attention.
    def run_flash_attn():
        _flash_attn_forward(
            query,
            key,
            value,
            ref_output,
            cu_query_lens,
            cu_context_lens,
            max(query_lens),
            max_context_len,
            dropout_p=0.0,
            softmax_scale=scale,
            causal=True,
            return_softmax=False,
        )
    benchmark('Flash attention', run_flash_attn)


if __name__ == '__main__':
    BLOCK_SIZE = 8
    NUM_BLOCKS = 1024
    DTYPE = torch.half

    # LLaMA-13B and OPT-13B
    NUM_HEADS = 40
    HEAD_SIZE = 128

    run_benchmark = functools.partial(
        benchmark_multi_query_cached_kv_attention,
        num_heads=NUM_HEADS,
        head_size=HEAD_SIZE,
        block_size=BLOCK_SIZE,
        num_blocks=NUM_BLOCKS,
        dtype=DTYPE,
    )
    run_benchmark(
        query_lens=[64] * 1,
        context_lens=[64] * 1,
    )
    run_benchmark(
        query_lens=[128] * 1,
        context_lens=[128] * 1,
    )
    run_benchmark(
        query_lens=[64] * 8,
        context_lens=[64] * 8,
    )
    run_benchmark(
        query_lens=[128] * 8,
        context_lens=[128] * 8,
    )
    run_benchmark(
        query_lens=[64, 32, 16],
        context_lens=[128, 256, 64],
    )
    run_benchmark(
        query_lens=[1024],
        context_lens=[1024],
    )
```
benchmark/benchmark_cache.py (deleted, 100644 → 0)
```python
import functools
import random
import time

import torch

from cacheflow import cache_ops


def benchmark(name, f, size: int, num_warmup=10, num_iters=100):
    for _ in range(num_warmup):
        f()
    torch.cuda.synchronize()
    start = time.time()
    for _ in range(num_iters):
        f()
    torch.cuda.synchronize()
    end = time.time()
    avg_time = (end - start) / num_iters
    print(f'[Latency] {name}: {avg_time * 1000:.3f} ms')
    print(f'[Throughput] {name}: {size / avg_time / 2 ** 30:.3f} GB/s')


@torch.inference_mode()
def test_gather_cached_kv(
    num_tokens: int,
    num_heads: int,
    head_size: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
) -> None:
    print(f'num_tokens: {num_tokens}, num_heads: {num_heads}, '
          f'head_size: {head_size}, block_size: {block_size}, '
          f'num_blocks: {num_blocks}, dtype: {dtype}')

    num_slots = block_size * num_blocks
    slot_mapping = random.sample(range(num_slots), num_tokens)
    slot_mapping = torch.tensor(slot_mapping, dtype=torch.int, device='cuda')
    qkv = torch.randn(
        num_tokens, 3, num_heads, head_size, dtype=dtype, device='cuda')
    _, key, value = qkv.unbind(dim=1)

    x = 16 // torch.tensor([], dtype=dtype).element_size()
    key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x)
    key_cache = torch.randn(size=key_cache_shape, dtype=dtype, device='cuda')
    value_cache_shape = (num_blocks, num_heads, head_size, block_size)
    value_cache = torch.randn(
        size=value_cache_shape, dtype=dtype, device='cuda')

    # Run the gather_cached_kv op.
    def run():
        cache_ops.gather_cached_kv(key, value, key_cache, value_cache,
                                   slot_mapping)
    benchmark('gather_cached_kv', run,
              size=num_tokens * num_heads * head_size * 2 * qkv.element_size())


if __name__ == '__main__':
    BLOCK_SIZE = 8
    NUM_BLOCKS = 1024
    DTYPE = torch.half

    # LLaMA-13B and OPT-13B
    NUM_HEADS = 40
    HEAD_SIZE = 128

    run_benchmark = functools.partial(
        test_gather_cached_kv,
        num_heads=NUM_HEADS,
        head_size=HEAD_SIZE,
        block_size=BLOCK_SIZE,
        num_blocks=NUM_BLOCKS,
        dtype=DTYPE,
    )
    for i in range(6, 12):
        run_benchmark(num_tokens=2 ** i)
```
benchmarks/README.md (new file, 0 → 100644)

# Benchmarking CacheFlow

## Downloading the ShareGPT dataset

You can download the dataset by running:

```bash
wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json
```
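With the dataset downloaded, the new throughput script (added below in this commit) can be pointed at the JSON file. A sketch of one possible invocation, using the flags defined in `benchmarks/benchmark_throughput.py`; the model here is only an example:

```bash
python benchmarks/benchmark_throughput.py \
    --dataset ShareGPT_V3_unfiltered_cleaned_split.json \
    --model facebook/opt-125m \
    --num-prompts 1000
```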
benchmark/benchmark_latency.py → benchmarks/benchmark_latency.py (file moved)
benchmark/benchmark_text_completion.py → benchmarks/benchmark_text_completion.py (file moved)
benchmarks/benchmark_throughput.py (new file, 0 → 100644)
```python
import argparse
import json
import random
import time
from typing import List, Tuple

from cacheflow import LLM, SamplingParams
from transformers import PreTrainedTokenizerBase


def sample_requests(
    dataset_path: str,
    num_requests: int,
    tokenizer: PreTrainedTokenizerBase,
) -> List[Tuple[List[int], int]]:
    # Load the dataset.
    with open(dataset_path) as f:
        dataset = json.load(f)
    # Filter out the conversations with fewer than 2 turns.
    dataset = [
        data for data in dataset
        if len(data["conversations"]) >= 2
    ]
    # Only keep the first two turns of each conversation.
    dataset = [
        (data["conversations"][0]["value"], data["conversations"][1]["value"])
        for data in dataset
    ]

    # Tokenize the prompts and completions.
    prompts = [prompt for prompt, _ in dataset]
    prompt_token_ids = tokenizer(prompts).input_ids
    completions = [completion for _, completion in dataset]
    completion_token_ids = tokenizer(completions).input_ids
    tokenized_dataset = []
    for i in range(len(dataset)):
        output_len = len(completion_token_ids[i])
        tokenized_dataset.append((prompt_token_ids[i], output_len))
    # Filter out requests whose prompt length + output length exceeds 2048.
    tokenized_dataset = [
        (prompt_token_ids, output_len)
        for prompt_token_ids, output_len in tokenized_dataset
        if len(prompt_token_ids) + output_len <= 2048
    ]

    # Sample the requests.
    sampled_requests = random.sample(tokenized_dataset, num_requests)
    return sampled_requests


def main(args: argparse.Namespace):
    print(args)
    random.seed(args.seed)

    llm = LLM(
        model=args.model,
        tensor_parallel_size=args.tensor_parallel_size,
        seed=args.seed,
    )
    tokenizer = llm.get_tokenizer()
    requests = sample_requests(args.dataset, args.num_prompts, tokenizer)

    # Add the requests to the server.
    for prompt_token_ids, output_len in requests:
        sampling_params = SamplingParams(
            n=args.n,
            temperature=0.0 if args.use_beam_search else 1.0,
            top_p=1.0,
            use_beam_search=args.use_beam_search,
            ignore_eos=True,
            max_tokens=output_len,
        )
        # FIXME(woosuk): Do not use internal method.
        llm._add_request(
            prompt="",
            sampling_params=sampling_params,
            prompt_token_ids=prompt_token_ids,
        )

    start = time.time()
    # FIXME(woosuk): Do not use internal method.
    llm._run_server(use_tqdm=True)
    end = time.time()
    total_num_tokens = sum(
        len(prompt_token_ids) + output_len
        for prompt_token_ids, output_len in requests
    )
    print(f"Throughput: {total_num_tokens / (end - start):.2f} tokens/s")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Benchmark the throughput.")
    parser.add_argument("--dataset", type=str, required=True,
                        help="Path to the dataset.")
    parser.add_argument("--model", type=str, default="facebook/opt-125m")
    parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1)
    parser.add_argument("--n", type=int, default=1,
                        help="Number of generated sequences per prompt.")
    parser.add_argument("--use-beam-search", action="store_true")
    parser.add_argument("--num-prompts", type=int, default=1000,
                        help="Number of prompts to process.")
    parser.add_argument("--seed", type=int, default=0)
    args = parser.parse_args()
    main(args)
```
benchmark/trace.py → benchmarks/trace.py (file moved)
cacheflow/__init__.py (modified)

```diff
 from cacheflow.entrypoints.llm import LLM
-from cacheflow.outputs import RequestOutput
+from cacheflow.outputs import RequestOutput, CompletionOutput
 from cacheflow.sampling_params import SamplingParams
 from cacheflow.server.arg_utils import ServerArgs
 from cacheflow.server.llm_server import LLMServer
...
@@ -9,6 +9,7 @@ __all__ = [
     "LLM",
     "SamplingParams",
     "RequestOutput",
+    "CompletionOutput",
     "LLMServer",
     "ServerArgs",
     "initialize_cluster",
...
```
cacheflow/core/scheduler.py (modified)

```diff
...
@@ -87,6 +87,9 @@ class Scheduler:
     def has_unfinished_seqs(self) -> bool:
         return self.waiting or self.running or self.swapped
 
+    def get_num_unfinished_seq_groups(self) -> int:
+        return len(self.waiting) + len(self.running) + len(self.swapped)
+
     def _schedule(self) -> Tuple[SchedulerOutputs, List[str]]:
         # Blocks that need to be swaped or copied before model execution.
         blocks_to_swap_in: Dict[int, int] = {}
...
```
cacheflow/entrypoints/llm.py (modified)

```diff
-from typing import List, Optional
+from typing import List, Optional, Union
 
+from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
 from tqdm import tqdm
 
 from cacheflow.outputs import RequestOutput
...
@@ -31,6 +32,11 @@ class LLM:
         self.llm_server = LLMServer.from_server_args(server_args)
         self.request_counter = Counter()
 
+    def get_tokenizer(
+        self,
+    ) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
+        return self.llm_server.tokenizer
+
     def generate(
         self,
         prompts: List[str],
...
@@ -41,10 +47,6 @@ class LLM:
         if sampling_params is None:
             # Use default sampling params.
             sampling_params = SamplingParams()
-        # Initialize tqdm.
-        if use_tqdm:
-            pbar = tqdm(total=len(prompts), desc="Processed prompts")
-
         # Add requests to the server.
         for i in range(len(prompts)):
             prompt = prompts[i]
...
@@ -52,10 +54,24 @@ class LLM:
                 token_ids = None
             else:
                 token_ids = prompt_token_ids[i]
-            request_id = str(next(self.request_counter))
-            self.llm_server.add_request(
-                request_id, prompt, sampling_params, token_ids)
+            self._add_request(prompt, sampling_params, token_ids)
+        return self._run_server(use_tqdm)
+
+    def _add_request(
+        self,
+        prompt: str,
+        sampling_params: SamplingParams,
+        prompt_token_ids: Optional[List[int]],
+    ) -> None:
+        request_id = str(next(self.request_counter))
+        self.llm_server.add_request(
+            request_id, prompt, sampling_params, prompt_token_ids)
+
+    def _run_server(self, use_tqdm: bool) -> List[RequestOutput]:
+        # Initialize tqdm.
+        if use_tqdm:
+            num_requests = self.llm_server.get_num_unfinished_requests()
+            pbar = tqdm(total=num_requests, desc="Processed prompts")
         # Run the server.
         outputs: List[RequestOutput] = []
         while self.llm_server.has_unfinished_requests():
...
```
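The refactor above keeps `generate()` as the public entry point while the benchmark-facing `_add_request()`/`_run_server()` helpers stay internal, and `get_tokenizer()` exposes the underlying tokenizer. A minimal usage sketch of the entrypoint as it stands after this commit; the model name and prompt below are placeholders, not part of the diff:

```python
from cacheflow import LLM, SamplingParams

# Placeholder model and prompt, chosen only for illustration.
llm = LLM(model="facebook/opt-125m")
tokenizer = llm.get_tokenizer()  # new accessor added in this commit

# generate() falls back to default SamplingParams() when none are given;
# here we pass one explicitly, mirroring benchmark_throughput.py.
sampling_params = SamplingParams(n=1, temperature=1.0, top_p=1.0)
outputs = llm.generate(["Hello, my name is"], sampling_params)
for output in outputs:
    print(output)
```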
cacheflow/server/llm_server.py (modified)

```diff
...
@@ -151,6 +151,9 @@ class LLMServer:
         # Add the sequence group to the scheduler.
         self.scheduler.add_seq_group(seq_group)
 
+    def get_num_unfinished_requests(self) -> int:
+        return self.scheduler.get_num_unfinished_seq_groups()
+
     def has_unfinished_requests(self) -> bool:
         return self.scheduler.has_unfinished_seqs()
...
```
examples/simple_server.py (modified)

```diff
...
@@ -19,9 +19,8 @@ def main(args: argparse.Namespace):
          SamplingParams(n=3, best_of=3, use_beam_search=True, temperature=0.0)),
     ]
-    request_id = 0
 
     # Run the server.
+    request_id = 0
     while True:
         # To test iteration-level scheduling, we add one request at each step.
         if test_prompts:
...
```