Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
d1557e66
Unverified
Commit
d1557e66
authored
Nov 17, 2024
by
wchen61
Committed by
GitHub
Nov 17, 2024
Browse files
[Misc] Enhance offline_inference to support user-configurable paramet… (#10392)
Signed-off-by:
wchen61
<
wchen61@foxmail.com
>
parent
80d85c5d
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
78 additions
and
20 deletions
+78
-20
examples/offline_inference.py
examples/offline_inference.py
+78
-20
No files found.
examples/offline_inference.py
View file @
d1557e66
from
dataclasses
import
asdict
from
vllm
import
LLM
,
SamplingParams
from
vllm
import
LLM
,
SamplingParams
from
vllm.engine.arg_utils
import
EngineArgs
from
vllm.utils
import
FlexibleArgumentParser
# Sample prompts.
def
get_prompts
(
num_prompts
:
int
):
prompts
=
[
# The default sample prompts.
prompts
=
[
"Hello, my name is"
,
"Hello, my name is"
,
"The president of the United States is"
,
"The president of the United States is"
,
"The capital of France is"
,
"The capital of France is"
,
"The future of AI is"
,
"The future of AI is"
,
]
]
# Create a sampling params object.
sampling_params
=
SamplingParams
(
temperature
=
0.8
,
top_p
=
0.95
)
if
num_prompts
!=
len
(
prompts
):
prompts
=
(
prompts
*
((
num_prompts
//
len
(
prompts
))
+
1
))[:
num_prompts
]
# Create an LLM.
llm
=
LLM
(
model
=
"facebook/opt-125m"
)
return
prompts
# Generate texts from the prompts. The output is a list of RequestOutput objects
# that contain the prompt, generated text, and other information.
outputs
=
llm
.
generate
(
prompts
,
sampling_params
)
def
main
(
args
):
# Print the outputs.
# Create prompts
for
output
in
outputs
:
prompts
=
get_prompts
(
args
.
num_prompts
)
# Create a sampling params object.
sampling_params
=
SamplingParams
(
n
=
args
.
n
,
temperature
=
args
.
temperature
,
top_p
=
args
.
top_p
,
top_k
=
args
.
top_k
,
max_tokens
=
args
.
max_tokens
)
# Create an LLM.
# The default model is 'facebook/opt-125m'
engine_args
=
EngineArgs
.
from_cli_args
(
args
)
llm
=
LLM
(
**
asdict
(
engine_args
))
# Generate texts from the prompts.
# The output is a list of RequestOutput objects
# that contain the prompt, generated text, and other information.
outputs
=
llm
.
generate
(
prompts
,
sampling_params
)
# Print the outputs.
for
output
in
outputs
:
prompt
=
output
.
prompt
prompt
=
output
.
prompt
generated_text
=
output
.
outputs
[
0
].
text
generated_text
=
output
.
outputs
[
0
].
text
print
(
f
"Prompt:
{
prompt
!
r
}
, Generated text:
{
generated_text
!
r
}
"
)
print
(
f
"Prompt:
{
prompt
!
r
}
, Generated text:
{
generated_text
!
r
}
"
)
if
__name__
==
'__main__'
:
parser
=
FlexibleArgumentParser
()
parser
=
EngineArgs
.
add_cli_args
(
parser
)
group
=
parser
.
add_argument_group
(
"SamplingParams options"
)
group
.
add_argument
(
"--num-prompts"
,
type
=
int
,
default
=
4
,
help
=
"Number of prompts used for inference"
)
group
.
add_argument
(
"--max-tokens"
,
type
=
int
,
default
=
16
,
help
=
"Generated output length for sampling"
)
group
.
add_argument
(
'--n'
,
type
=
int
,
default
=
1
,
help
=
'Number of generated sequences per prompt'
)
group
.
add_argument
(
'--temperature'
,
type
=
float
,
default
=
0.8
,
help
=
'Temperature for text generation'
)
group
.
add_argument
(
'--top-p'
,
type
=
float
,
default
=
0.95
,
help
=
'top_p for text generation'
)
group
.
add_argument
(
'--top-k'
,
type
=
int
,
default
=-
1
,
help
=
'top_k for text generation'
)
args
=
parser
.
parse_args
()
main
(
args
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment