norm / vllm · Commits · 8274ca23
"docs/vscode:/vscode.git/clone" did not exist on "b4cc0f331f2e550344855b1666478ea5ac2cf118"
Unverified commit 8274ca23, authored Jun 04, 2023 by Woosuk Kwon; committed by GitHub on Jun 04, 2023.

Add docstrings for LLM (#137)

Parent: 62ec38ea
Showing 4 changed files with 66 additions and 10 deletions (+66, −10):

- benchmarks/benchmark_latency.py (+2, −2)
- benchmarks/benchmark_throughput.py (+5, −3)
- cacheflow/entrypoints/llm.py (+57, −4)
- cacheflow/server/llm_server.py (+2, −1)
benchmarks/benchmark_latency.py

```diff
@@ -30,7 +30,6 @@ def main(args: argparse.Namespace):
         max_tokens=args.output_len,
     )
     print(sampling_params)
-    dummy_prompts = [""] * args.batch_size
     dummy_prompt_token_ids = [[0] * args.input_len] * args.batch_size

     def run_to_completion(profile: bool = False):
@@ -38,7 +37,8 @@ def main(args: argparse.Namespace):
             torch.cuda.cudart().cudaProfilerStart()
         start_time = time.time()
-        llm.generate(dummy_prompts, sampling_params, dummy_prompt_token_ids,
+        llm.generate(prompt_token_ids=dummy_prompt_token_ids,
+                     sampling_params=sampling_params,
                      use_tqdm=False)
         end_time = time.time()
```
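For context, a minimal sketch of the benchmark's calling pattern after this change. The model name and batch sizes are illustrative, and it assumes the package exports `LLM` and `SamplingParams` at this commit; the point is that `generate` can now be driven purely by `prompt_token_ids`, so the placeholder prompt strings are no longer needed.

```python
from cacheflow import LLM, SamplingParams  # assumed exports at this commit

llm = LLM(model="facebook/opt-125m")  # illustrative model choice
sampling_params = SamplingParams(max_tokens=128)  # stand-in for args.output_len

# batch_size=8 prompts of input_len=32, all token ID 0, as in the benchmark.
dummy_prompt_token_ids = [[0] * 32] * 8
outputs = llm.generate(prompt_token_ids=dummy_prompt_token_ids,
                       sampling_params=sampling_params,
                       use_tqdm=False)
```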
benchmarks/benchmark_throughput.py

```diff
@@ -72,9 +72,9 @@ def main(args: argparse.Namespace):
         )
         # FIXME(woosuk): Do not use internal method.
         llm._add_request(
-            prompt="",
-            sampling_params=sampling_params,
+            prompt=None,
             prompt_token_ids=prompt_token_ids,
+            sampling_params=sampling_params,
         )
     start = time.time()
@@ -85,7 +85,9 @@ def main(args: argparse.Namespace):
         len(prompt_token_ids) + output_len
         for prompt_token_ids, output_len in requests
     )
-    print(f"Throughput: {total_num_tokens / (end - start):.2f} tokens/s")
+    elapsed_time = end - start
+    print(f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, "
+          f"{total_num_tokens / elapsed_time:.2f} tokens/s")

 if __name__ == "__main__":
```
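The second hunk's metric change, in isolation: token throughput alone hides how many requests were completed, so the new print derives both rates from the same `elapsed_time`. A self-contained sketch with illustrative numbers:

```python
# Each request is (prompt_token_ids, output_len), as in the benchmark.
requests = [([0] * 32, 128), ([0] * 16, 64), ([0] * 64, 256)]
elapsed_time = 4.2  # seconds; illustrative stand-in for end - start

# Count prompt tokens plus generated tokens for every request.
total_num_tokens = sum(len(prompt_token_ids) + output_len
                       for prompt_token_ids, output_len in requests)
print(f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, "
      f"{total_num_tokens / elapsed_time:.2f} tokens/s")
```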
cacheflow/entrypoints/llm.py

```diff
@@ -11,6 +11,28 @@ from cacheflow.utils import Counter


 class LLM:
+    """An LLM for generating texts from given prompts and sampling parameters.
+
+    This class includes a tokenizer, a language model (possibly distributed
+    across multiple GPUs), and GPU memory space allocated for intermediate
+    states (aka KV cache). Given a batch of prompts and sampling parameters,
+    this class generates texts from the model, using an intelligent batching
+    mechanism and efficient memory management.
+
+    NOTE: This class is intended to be used for offline inference. For online
+    serving, use the `AsyncLLMServer` class instead.
+    NOTE: For the comprehensive list of arguments, see `ServerArgs`.
+
+    Args:
+        model: The name or path of a HuggingFace Transformers model.
+        tensor_parallel_size: The number of GPUs to use for distributed
+            execution with tensor parallelism.
+        dtype: The data type for the model weights and activations. Currently,
+            we support `float16` and `bfloat16`. If `default`, we use the
+            `torch_dtype` attribute of the model config. If the `torch_dtype`
+            is `float32`, we use `float16` instead.
+        seed: The seed to initialize the random number generator for sampling.
+    """

     def __init__(
         self,
```
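As a usage sketch of the documented arguments (the model name is illustrative, the export of `LLM` from the top-level package is assumed, and `ServerArgs` covers the options not listed here):

```python
from cacheflow import LLM  # assumed export at this commit

# Only the Args named in the new docstring are shown.
llm = LLM(
    model="facebook/opt-125m",  # HuggingFace model name or path
    tensor_parallel_size=1,     # GPUs used for tensor parallelism
    dtype="default",            # fall back to the config's torch_dtype
    seed=0,                     # seed for the sampling RNG
)
```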
```diff
@@ -39,19 +61,50 @@ class LLM:
     def generate(
         self,
-        prompts: Union[str, List[str]],
+        prompts: Optional[Union[str, List[str]]] = None,
         sampling_params: Optional[SamplingParams] = None,
         prompt_token_ids: Optional[List[List[int]]] = None,
         use_tqdm: bool = True,
     ) -> List[RequestOutput]:
+        """Generates the completions for the input prompts.
+
+        NOTE: This class automatically batches the given prompts, considering
+        the memory constraint. For the best performance, put all of your prompts
+        into a single list and pass it to this method.
+
+        Args:
+            prompts: A list of prompts to generate completions for.
+            sampling_params: The sampling parameters for text generation. If
+                None, we use the default sampling parameters.
+            prompt_token_ids: A list of token IDs for the prompts. If None, we
+                use the tokenizer to convert the prompts to token IDs.
+            use_tqdm: Whether to use tqdm to display the progress bar.
+
+        Returns:
+            A list of `RequestOutput` objects containing the generated
+            completions in the same order as the input prompts.
+        """
+        if prompts is None and prompt_token_ids is None:
+            raise ValueError("Either prompts or prompt_token_ids must be "
+                             "provided.")
         if isinstance(prompts, str):
+            # Convert a single prompt to a list.
             prompts = [prompts]
+        if prompts is not None and prompt_token_ids is not None:
+            if len(prompts) != len(prompt_token_ids):
+                raise ValueError("The lengths of prompts and prompt_token_ids "
+                                 "must be the same.")
         if sampling_params is None:
             # Use default sampling params.
             sampling_params = SamplingParams()

         # Add requests to the server.
-        for i in range(len(prompts)):
-            prompt = prompts[i]
+        if prompts is not None:
+            num_requests = len(prompts)
+        else:
+            num_requests = len(prompt_token_ids)
+        for i in range(num_requests):
+            prompt = prompts[i] if prompts is not None else None
             if prompt_token_ids is None:
                 token_ids = None
             else:
```
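The reworked signature allows two calling modes, which the new validation enforces: at least one of `prompts` and `prompt_token_ids` must be given, and when both are given their lengths must match. A hedged sketch (model name illustrative, top-level exports assumed):

```python
from cacheflow import LLM, SamplingParams  # assumed exports at this commit

llm = LLM(model="facebook/opt-125m")  # illustrative model

# Mode 1: plain-text prompts; the tokenizer converts them to token IDs.
outputs = llm.generate(prompts=["Hello, my name is",
                                "The capital of France is"])

# Mode 2: pre-tokenized prompts only (e.g. for benchmarks); prompts stays None.
outputs = llm.generate(prompt_token_ids=[[0] * 32] * 4,
                       sampling_params=SamplingParams())

# Passing neither raises ValueError; mismatched lengths also raise ValueError.
```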
```diff
@@ -61,7 +114,7 @@ class LLM:
     def _add_request(
         self,
-        prompt: str,
+        prompt: Optional[str],
         sampling_params: SamplingParams,
         prompt_token_ids: Optional[List[int]],
     ) -> None:
```
cacheflow/server/llm_server.py

```diff
@@ -126,7 +126,7 @@ class LLMServer:
     def add_request(
         self,
         request_id: str,
-        prompt: str,
+        prompt: Optional[str],
         sampling_params: SamplingParams,
         prompt_token_ids: Optional[List[int]] = None,
         arrival_time: Optional[float] = None,
@@ -134,6 +134,7 @@ class LLMServer:
         if arrival_time is None:
             arrival_time = time.time()
         if prompt_token_ids is None:
+            assert prompt is not None
             prompt_token_ids = self.tokenizer.encode(prompt)

         # Create the sequences.
```
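The added assertion encodes the invariant that a request carries either a prompt string or pre-tokenized IDs. A minimal standalone sketch of that guard; the helper name and the `encode` callable are hypothetical stand-ins for `self.tokenizer.encode`:

```python
from typing import Callable, List, Optional

def resolve_prompt_token_ids(
    prompt: Optional[str],
    prompt_token_ids: Optional[List[int]],
    encode: Callable[[str], List[int]],
) -> List[int]:
    # Mirrors LLMServer.add_request: tokenize only when IDs are missing,
    # in which case the prompt string must be present.
    if prompt_token_ids is None:
        assert prompt is not None
        prompt_token_ids = encode(prompt)
    return prompt_token_ids
```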