Add docstrings for LLM (#137)

Unverified commit 8274ca23 in norm/vllm, authored Jun 04, 2023 by Woosuk Kwon, committed by GitHub on Jun 04, 2023. Parent: 62ec38ea.

Showing 4 changed files with 66 additions and 10 deletions (+66, -10):
benchmarks/benchmark_latency.py     +2  -2
benchmarks/benchmark_throughput.py  +5  -3
cacheflow/entrypoints/llm.py        +57 -4
cacheflow/server/llm_server.py      +2  -1
benchmarks/benchmark_latency.py

@@ -30,7 +30,6 @@ def main(args: argparse.Namespace):
         max_tokens=args.output_len,
     )
     print(sampling_params)
-    dummy_prompts = [""] * args.batch_size
     dummy_prompt_token_ids = [[0] * args.input_len] * args.batch_size
 
     def run_to_completion(profile: bool = False):
@@ -38,7 +37,8 @@ def main(args: argparse.Namespace):
             torch.cuda.cudart().cudaProfilerStart()
         start_time = time.time()
 
-        llm.generate(dummy_prompts, sampling_params, dummy_prompt_token_ids,
+        llm.generate(prompt_token_ids=dummy_prompt_token_ids,
+                     sampling_params=sampling_params,
                      use_tqdm=False)
 
         end_time = time.time()
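For context, a minimal sketch of the keyword-argument call style the benchmark now uses. The import path and model name are assumptions for illustration and are not part of this commit; the generate() keywords mirror the updated signature in cacheflow/entrypoints/llm.py further below.

```python
import time

# Import path assumed from this repository layout; adjust if LLM is exported elsewhere.
from cacheflow import LLM, SamplingParams

# Placeholder sizes standing in for args.batch_size / args.input_len / args.output_len.
batch_size, input_len, output_len = 8, 32, 128

llm = LLM(model="facebook/opt-125m")  # hypothetical model choice
sampling_params = SamplingParams(max_tokens=output_len)
dummy_prompt_token_ids = [[0] * input_len] * batch_size

start_time = time.time()
# No dummy prompt strings needed anymore: pass pre-tokenized inputs by keyword.
llm.generate(prompt_token_ids=dummy_prompt_token_ids,
             sampling_params=sampling_params,
             use_tqdm=False)
end_time = time.time()
print(f"Latency: {end_time - start_time:.3f} s")
```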
benchmarks/benchmark_throughput.py

@@ -72,9 +72,9 @@ def main(args: argparse.Namespace):
         )
         # FIXME(woosuk): Do not use internal method.
         llm._add_request(
-            prompt="",
-            sampling_params=sampling_params,
+            prompt=None,
             prompt_token_ids=prompt_token_ids,
+            sampling_params=sampling_params,
         )
 
     start = time.time()
@@ -85,7 +85,9 @@ def main(args: argparse.Namespace):
         len(prompt_token_ids) + output_len
         for prompt_token_ids, output_len in requests
     )
-    print(f"Throughput: {total_num_tokens / (end - start):.2f} tokens/s")
+    elapsed_time = end - start
+    print(f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, "
+          f"{total_num_tokens / elapsed_time:.2f} tokens/s")
 
 
 if __name__ == "__main__":
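A small, self-contained sketch of the reworked throughput report, which now prints requests/s alongside tokens/s. The request list here is fabricated and a sleep stands in for the actual generation; only the bookkeeping mirrors the benchmark.

```python
import time

# Fabricated stand-in for the benchmark's requests: (prompt_token_ids, output_len) pairs.
requests = [([0] * 32, 128) for _ in range(16)]

start = time.time()
time.sleep(0.1)  # placeholder for running generation on the requests
end = time.time()

total_num_tokens = sum(
    len(prompt_token_ids) + output_len
    for prompt_token_ids, output_len in requests
)
elapsed_time = end - start
print(f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, "
      f"{total_num_tokens / elapsed_time:.2f} tokens/s")
```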
cacheflow/entrypoints/llm.py

@@ -11,6 +11,28 @@ from cacheflow.utils import Counter
 
 
 class LLM:
+    """An LLM for generating texts from given prompts and sampling parameters.
+
+    This class includes a tokenizer, a language model (possibly distributed
+    across multiple GPUs), and GPU memory space allocated for intermediate
+    states (aka KV cache). Given a batch of prompts and sampling parameters,
+    this class generates texts from the model, using an intelligent batching
+    mechanism and efficient memory management.
+
+    NOTE: This class is intended to be used for offline inference. For online
+    serving, use the `AsyncLLMServer` class instead.
+    NOTE: For the comprehensive list of arguments, see `ServerArgs`.
+
+    Args:
+        model: The name or path of a HuggingFace Transformers model.
+        tensor_parallel_size: The number of GPUs to use for distributed
+            execution with tensor parallelism.
+        dtype: The data type for the model weights and activations. Currently,
+            we support `float16` and `bfloat16`. If `default`, we use the
+            `torch_dtype` attribute of the model config. If the `torch_dtype`
+            is `float32`, we use `float16` instead.
+        seed: The seed to initialize the random number generator for sampling.
+    """
+
     def __init__(
         self,
@@ -39,19 +61,50 @@ class LLM:
     def generate(
         self,
-        prompts: Union[str, List[str]],
+        prompts: Optional[Union[str, List[str]]] = None,
         sampling_params: Optional[SamplingParams] = None,
         prompt_token_ids: Optional[List[List[int]]] = None,
         use_tqdm: bool = True,
     ) -> List[RequestOutput]:
+        """Generates the completions for the input prompts.
+
+        NOTE: This class automatically batches the given prompts, considering
+        the memory constraint. For the best performance, put all of your prompts
+        into a single list and pass it to this method.
+
+        Args:
+            prompts: A list of prompts to generate completions for.
+            sampling_params: The sampling parameters for text generation. If
+                None, we use the default sampling parameters.
+            prompt_token_ids: A list of token IDs for the prompts. If None, we
+                use the tokenizer to convert the prompts to token IDs.
+            use_tqdm: Whether to use tqdm to display the progress bar.
+
+        Returns:
+            A list of `RequestOutput` objects containing the generated
+            completions in the same order as the input prompts.
+        """
+        if prompts is None and prompt_token_ids is None:
+            raise ValueError("Either prompts or prompt_token_ids must be "
+                             "provided.")
         if isinstance(prompts, str):
+            # Convert a single prompt to a list.
             prompts = [prompts]
+        if prompts is not None and prompt_token_ids is not None:
+            if len(prompts) != len(prompt_token_ids):
+                raise ValueError("The lengths of prompts and prompt_token_ids "
+                                 "must be the same.")
         if sampling_params is None:
             # Use default sampling params.
             sampling_params = SamplingParams()
 
         # Add requests to the server.
-        for i in range(len(prompts)):
-            prompt = prompts[i]
+        if prompts is not None:
+            num_requests = len(prompts)
+        else:
+            num_requests = len(prompt_token_ids)
+        for i in range(num_requests):
+            prompt = prompts[i] if prompts is not None else None
             if prompt_token_ids is None:
                 token_ids = None
             else:
@@ -61,7 +114,7 @@ class LLM:
     def _add_request(
         self,
-        prompt: str,
+        prompt: Optional[str],
         sampling_params: SamplingParams,
         prompt_token_ids: Optional[List[int]],
     ) -> None:
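To make the documented workflow concrete, here is a hedged usage sketch of the API the new docstrings describe. The import path, model name, and token IDs are illustrative assumptions; the keyword arguments follow the signatures shown in the diff above.

```python
# Import path assumed for this version of the repository.
from cacheflow import LLM, SamplingParams

# Constructor arguments documented in the new class docstring; the model name is a placeholder.
llm = LLM(model="facebook/opt-125m", tensor_parallel_size=1, seed=0)

# Passing sampling_params=None would fall back to SamplingParams() defaults.
params = SamplingParams()

# Either supply prompts...
outputs = llm.generate(prompts=["Hello, my name is", "The capital of France is"],
                       sampling_params=params)

# ...or supply pre-tokenized inputs (token IDs here are made up). At least one of
# prompts / prompt_token_ids must be given; if both are, their lengths must match.
outputs = llm.generate(prompt_token_ids=[[0, 1, 2], [3, 4, 5]], sampling_params=params)

# RequestOutput objects come back in the same order as the inputs.
for output in outputs:
    print(output)
```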
cacheflow/server/llm_server.py

@@ -126,7 +126,7 @@ class LLMServer:
     def add_request(
         self,
         request_id: str,
-        prompt: str,
+        prompt: Optional[str],
         sampling_params: SamplingParams,
         prompt_token_ids: Optional[List[int]] = None,
         arrival_time: Optional[float] = None,
@@ -134,6 +134,7 @@ class LLMServer:
         if arrival_time is None:
             arrival_time = time.time()
         if prompt_token_ids is None:
+            assert prompt is not None
            prompt_token_ids = self.tokenizer.encode(prompt)
 
         # Create the sequences.
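The added assertion guards the tokenization fallback: when a caller omits prompt_token_ids, the prompt must be present so it can be encoded. Below is a standalone sketch of that logic using a hypothetical helper name; it is not part of the server code.

```python
from typing import List, Optional


def resolve_prompt_token_ids(
    prompt: Optional[str],
    prompt_token_ids: Optional[List[int]],
    tokenizer,
) -> List[int]:
    """Mirror of the guarded fallback in LLMServer.add_request (hypothetical helper)."""
    if prompt_token_ids is None:
        # Tokenize only when token IDs were not supplied; a prompt is required then.
        assert prompt is not None, "prompt must be provided when prompt_token_ids is None"
        prompt_token_ids = tokenizer.encode(prompt)
    return prompt_token_ids
```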