Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
c2d6d2f9
Unverified
Commit
c2d6d2f9
authored
Jun 02, 2024
by
Daniil Arapov
Committed by
GitHub
Jun 01, 2024
Browse files
[Bugfix]: Fix issues related to prefix caching example (#5177) (#5180)
parent
8279078e
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
37 additions
and
10 deletions
+37
-10
examples/offline_inference_with_prefix.py
examples/offline_inference_with_prefix.py
+37
-10
No files found.
examples/offline_inference_with_prefix.py
View file @
c2d6d2f9
from
time
import
time
from
vllm
import
LLM
,
SamplingParams
from
vllm
import
LLM
,
SamplingParams
# Common prefix.
prefix
=
(
prefix
=
(
"You are an expert school principal, skilled in effectively managing "
"You are an expert school principal, skilled in effectively managing "
"faculty and staff. Draft 10-15 questions for a potential first grade "
"faculty and staff. Draft 10-15 questions for a potential first grade "
...
@@ -18,36 +21,60 @@ prompts = [
...
@@ -18,36 +21,60 @@ prompts = [
"The capital of France is"
,
"The capital of France is"
,
"The future of AI is"
,
"The future of AI is"
,
]
]
generating_prompts
=
[
prefix
+
prompt
for
prompt
in
prompts
]
# Create a sampling params object.
# Create a sampling params object.
sampling_params
=
SamplingParams
(
temperature
=
0.0
)
sampling_params
=
SamplingParams
(
temperature
=
0.0
)
# Create an LLM.
# Create an LLM.
llm
=
LLM
(
model
=
"facebook/opt-125m"
,
enable_prefix_caching
=
True
)
regular_
llm
=
LLM
(
model
=
"facebook/opt-125m"
,
gpu_memory_utilization
=
0.4
)
generating_prompts
=
[
prefix
+
prompt
for
prompt
in
prompts
]
prefix_cached_llm
=
LLM
(
model
=
"facebook/opt-125m"
,
enable_prefix_caching
=
True
,
gpu_memory_utilization
=
0.4
)
print
(
"Results without `enable_prefix_caching`"
)
# Generate texts from the prompts. The output is a list of RequestOutput objects
# Generate texts from the prompts. The output is a list of RequestOutput objects
# that contain the prompt, generated text, and other information.
# that contain the prompt, generated text, and other information.
outputs
=
llm
.
generate
(
generating_prompts
,
sampling_params
)
start_time_regular
=
time
()
outputs
=
regular_llm
.
generate
(
generating_prompts
,
sampling_params
)
duration_regular
=
time
()
-
start_time_regular
regular_generated_texts
=
[]
# Print the outputs.
# Print the outputs.
for
output
in
outputs
:
for
output
in
outputs
:
prompt
=
output
.
prompt
prompt
=
output
.
prompt
generated_text
=
output
.
outputs
[
0
].
text
generated_text
=
output
.
outputs
[
0
].
text
regular_generated_texts
.
append
(
generated_text
)
print
(
f
"Prompt:
{
prompt
!
r
}
, Generated text:
{
generated_text
!
r
}
"
)
print
(
f
"Prompt:
{
prompt
!
r
}
, Generated text:
{
generated_text
!
r
}
"
)
print
(
"-"
*
80
)
print
(
"-"
*
80
)
# The llm.generate call will batch all prompts and send the batch at once
# The llm.generate call will batch all prompts and send the batch at once
# if resources allow.
The prefix will only be cached after the first batch
# if resources allow.
# is processed, so we need to call generate once to calculate the prefix
start_time_cached
=
time
()
# and cache it.
outputs
=
prefix_cached_llm
.
generate
(
generating_prompts
,
sampling_params
)
outputs
=
llm
.
generate
(
generating_prompts
[
0
],
sampling_params
)
duration_cached
=
time
()
-
start_time_cached
# Subsequent batches can leverage the cached prefix
print
(
"Results with `enable_prefix_caching`"
)
outputs
=
llm
.
generate
(
generating_prompts
,
sampling_params
)
# Print the outputs. You should see the same outputs as before
cached_generated_texts
=
[]
# Print the outputs. You should see the same outputs as before.
for
output
in
outputs
:
for
output
in
outputs
:
prompt
=
output
.
prompt
prompt
=
output
.
prompt
generated_text
=
output
.
outputs
[
0
].
text
generated_text
=
output
.
outputs
[
0
].
text
cached_generated_texts
.
append
(
generated_text
)
print
(
f
"Prompt:
{
prompt
!
r
}
, Generated text:
{
generated_text
!
r
}
"
)
print
(
f
"Prompt:
{
prompt
!
r
}
, Generated text:
{
generated_text
!
r
}
"
)
print
(
"-"
*
80
)
# Compare the results and display the speedup
generated_same
=
all
([
regular_generated_texts
[
i
]
==
cached_generated_texts
[
i
]
for
i
in
range
(
len
(
prompts
))
])
print
(
f
"Generated answers are the same:
{
generated_same
}
"
)
speedup
=
round
(
duration_regular
/
duration_cached
,
2
)
print
(
f
"Speed up of cached generation compared to the regular is:
{
speedup
}
"
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment