change / sglang · Commits · 49c5e0ec

Commit 49c5e0ec (unverified), authored Jul 20, 2024 by yichuan~, committed by GitHub on Jul 19, 2024
Parent: ec2150b2

Add support for OpenAI API parallel sampling (#640)

Showing 5 changed files with 381 additions and 168 deletions (+381, -168):
  examples/usage/openai_parallel_sample.py          +75   -0
  python/sglang/srt/managers/io_struct.py           +22   -5
  python/sglang/srt/managers/tokenizer_manager.py   +214  -115
  python/sglang/srt/openai_api_adapter.py           +68   -48
  python/sglang/srt/sampling_params.py              +2    -0
examples/usage/openai_parallel_sample.py (new file, mode 100644)

import openai

client = openai.Client(base_url="http://127.0.0.1:30000/v1", api_key="EMPTY")

# Text completion
response = client.completions.create(
    model="default",
    prompt="I am a robot and I want to study like humans. Now let's tell a story. Once upon a time, there was a little",
    n=1,
    temperature=0.8,
    max_tokens=32,
)
print(response)

# Text completion
response = client.completions.create(
    model="default",
    prompt="I am a robot and I want to study like humans. Now let's tell a story. Once upon a time, there was a little",
    n=3,
    temperature=0.8,
    max_tokens=32,
)
print(response)

# Text completion
response = client.completions.create(
    model="default",
    prompt=["The name of the famous soccer player is ", "The capital of US is"],
    n=1,
    temperature=0.8,
    max_tokens=32,
)
print(response)

# Text completion
response = client.completions.create(
    model="default",
    prompt=["The name of the famous soccer player is ", "The capital of US is"],
    n=3,
    temperature=0.8,
    max_tokens=32,
)
print(response)

# Text completion
response = client.completions.create(
    model="default",
    prompt=[
        "The capital of France is",
        "The capital of Germany is",
        "The capital of US is",
    ],
    n=3,
    temperature=0.8,
    max_tokens=32,
)
print(response)

# Chat completion
response = client.chat.completions.create(
    model="default",
    messages=[
        {"role": "system", "content": "You are a helpful AI assistant"},
        {"role": "user", "content": "List 3 countries and their capitals."},
    ],
    temperature=0.8,
    max_tokens=64,
    logprobs=True,
    n=4,
)
print(response)
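Reviewer note: with n > 1 the server returns one entry per sample in response.choices. Below is a minimal sketch of reading them back, assuming the standard openai Python client and the server started as in the example above; it is not part of the commit.

# Two prompts, three samples each, so the response carries 2 * 3 = 6 choices.
response = client.completions.create(
    model="default",
    prompt=["The capital of France is", "The capital of Germany is"],
    n=3,
    temperature=0.8,
    max_tokens=16,
)
for choice in response.choices:
    # choice.index identifies the sample; choice.text holds the generated text.
    print(choice.index, repr(choice.text))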
python/sglang/srt/managers/io_struct.py

@@ -40,11 +40,13 @@ class GenerateReqInput:

post_init now forces batch handling whenever parallel sampling is requested; previously is_single was derived from the type of text / input_ids alone. After the change (the existing text/input_ids validation above is unchanged):

            raise ValueError("Either text or input_ids should be provided.")

        if "n" in self.sampling_params and self.sampling_params["n"] != 1:
            is_single = False
        else:
            if self.text is not None:
                is_single = isinstance(self.text, str)
            else:
                is_single = isinstance(self.input_ids[0], int)
        self.is_single = is_single
        if is_single:

@@ -59,7 +61,22 @@ class GenerateReqInput:

The batch branch now sizes its per-request lists for parallel sampling; previously num was simply len(self.text) or len(self.input_ids). After the change:

            if self.top_logprobs_num is None:
                self.top_logprobs_num = 0
        else:
            parallel_sample_num = self.sampling_params.get("n", 1)

            if parallel_sample_num != 1:
                # parallel sampling +1 represents the original prefill stage
                num = parallel_sample_num + 1
                if isinstance(self.text, List):
                    ## support batch operation
                    self.batch_size = len(self.text)
                    num = num * len(self.text)
                else:
                    self.batch_size = 1
            else:
                ## support select operation
                num = len(self.text) if self.text is not None else len(self.input_ids)
                self.batch_size = num

            if self.image_data is None:
                self.image_data = [None] * num
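Reviewer note: a small worked example of the sizing logic above. Each prompt gets one prefill request plus n sampling requests, so 2 prompts with n = 3 expand to (3 + 1) * 2 = 8 entries. The helper below is hypothetical and not part of the commit; it only mirrors the arithmetic in post_init.

# Hypothetical helper mirroring GenerateReqInput.post_init's sizing of per-request lists.
def expanded_request_count(num_prompts: int, n: int) -> int:
    if n != 1:
        # One prefill slot plus n sampling slots per prompt.
        return (n + 1) * num_prompts
    # select operation: one request per prompt
    return num_prompts

assert expanded_request_count(2, 3) == 8
assert expanded_request_count(3, 1) == 3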
python/sglang/srt/managers/tokenizer_manager.py

The single-request and batch-request handling that was previously inlined in generate_request (tokenization, the context-length check, SamplingParams normalization, image preprocessing, and the send/wait loops) is split into helper coroutines. The affected region reads as follows after the change.

@@ -122,125 +122,150 @@ class TokenizerManager:

        obj.post_init()
        is_single = obj.is_single

        if is_single:
            async for response in self._handle_single_request(obj, request):
                yield response
        else:
            if obj.stream:
                raise ValueError("Do not support stream for batch mode.")

            async for response in self._handle_batch_request(obj, request):
                yield response

    async def _handle_single_request(self, obj, request, index=None, is_prefill=False):
        if is_prefill:
            if isinstance(obj.text, list):
                input_text = obj.text[index]
                rid = obj.rid[index]
            else:
                input_text = obj.text
                rid = obj.rid[0]
            input_ids = self.tokenizer.encode(input_text)
            sampling_params = SamplingParams(**obj.sampling_params[0])
            sampling_params.max_new_tokens = 0
            pixel_values, image_hash, image_size = await self._get_pixel_values(
                obj.image_data[0]
            )
            return_logprob = obj.return_logprob[0]
            logprob_start_len = obj.logprob_start_len[0]
            top_logprobs_num = obj.top_logprobs_num[0]
        else:
            rid = obj.rid if index is None else obj.rid[index]
            input_text = obj.text if index is None else obj.text[index]
            input_ids = (
                self.tokenizer.encode(input_text)
                if obj.input_ids is None
                else obj.input_ids
            )
            if index is not None and obj.input_ids:
                input_ids = obj.input_ids[index]

            self._validate_input_length(input_ids)

            sampling_params = self._get_sampling_params(
                obj.sampling_params if index is None else obj.sampling_params[index]
            )
            pixel_values, image_hash, image_size = await self._get_pixel_values(
                obj.image_data if index is None else obj.image_data[index]
            )
            return_logprob = (
                obj.return_logprob if index is None else obj.return_logprob[index]
            )
            logprob_start_len = (
                obj.logprob_start_len if index is None else obj.logprob_start_len[index]
            )
            top_logprobs_num = (
                obj.top_logprobs_num if index is None else obj.top_logprobs_num[index]
            )

        tokenized_obj = TokenizedGenerateReqInput(
            rid,
            input_text,
            input_ids,
            pixel_values,
            image_hash,
            image_size,
            sampling_params,
            return_logprob,
            logprob_start_len,
            top_logprobs_num,
            obj.stream,
        )
        self.send_to_router.send_pyobj(tokenized_obj)

        event = asyncio.Event()
        state = ReqState([], False, event)
        self.rid_to_state[rid] = state

        if is_prefill == False:
            async for response in self._wait_for_response(
                event, state, obj, rid, request
            ):
                yield response
        else:
            await self._wait_for_prefill_response(event, state, obj, request, rid)
            yield input_ids

    async def _handle_batch_request(self, obj, request):
        batch_size = obj.batch_size
        parallel_sample_num = obj.sampling_params[0].get("n", 1)

        if parallel_sample_num != 1:
            ## send prefill requests
            parallel_sample_num += 1
            input_id_result = [] if obj.input_ids is None else None
            for i in range(batch_size):
                async for input_id in self._handle_single_request(
                    obj, request, index=i, is_prefill=True
                ):
                    if input_id_result is not None:
                        input_id_result.append(input_id)
                    pass
            if len(input_id_result) > 1 and input_id_result is not None:
                obj.input_ids = input_id_result
            elif input_id_result is not None:
                obj.input_ids = input_id_result[0]

        # First send out all requests
        for i in range(batch_size):
            for j in range(parallel_sample_num):
                if j == 0 and parallel_sample_num != 1:
                    continue
                index = i * parallel_sample_num + j
                if parallel_sample_num != 1:
                    # With parallel sampling we also account for the prefill stage,
                    # so the index is: j + i * (parallel_sample_num - 1) + batch_size - 1
                    index += batch_size - 1 - i
                rid = obj.rid[index]
                if parallel_sample_num == 1:
                    ## select operation
                    if obj.input_ids is None:
                        input_text = obj.text[i]
                        input_ids = self.tokenizer.encode(obj.text[i])
                    else:
                        input_text = None
                        input_ids = obj.input_ids[i]
                else:
                    if batch_size == 1:
                        input_text = obj.text
                        input_ids = obj.input_ids
                    else:
                        input_text = obj.text[i]
                        input_ids = obj.input_ids[i]
                sampling_params = self._get_sampling_params(obj.sampling_params[index])
                pixel_values, image_hash, image_size = await self._get_pixel_values(
                    obj.image_data[index]
                )

                tokenized_obj = TokenizedGenerateReqInput(
                    rid,
                    input_text,
                    input_ids,
                    pixel_values,
                    image_hash,
                    image_size,
                    sampling_params,
                    obj.return_logprob[index],
                    obj.logprob_start_len[index],
                    obj.top_logprobs_num[index],
                    obj.stream,
                )
                self.send_to_router.send_pyobj(tokenized_obj)

@@ -248,9 +273,16 @@ class TokenizerManager:

                event = asyncio.Event()
                state = ReqState([], False, event)
                self.rid_to_state[rid] = state

        # Then wait for all responses
        output_list = []
        for i in range(batch_size):
            for j in range(parallel_sample_num):
                if j == 0 and parallel_sample_num != 1:
                    continue
                index = i * parallel_sample_num + j
                if parallel_sample_num != 1:
                    index += batch_size - 1 - i
                rid = obj.rid[index]
                state = self.rid_to_state[rid]

                while True:
                    ...

@@ -263,19 +295,86 @@ class TokenizerManager:

                            self.abort_request(rid)
                            raise ValueError(f"Abort request {rid}")
                        continue

                output_list.append(
                    self.convert_logprob_style(
                        state.out_list[-1],
                        obj.return_logprob[index],
                        obj.top_logprobs_num[index],
                        obj.return_text_in_logprobs,
                    )
                )
                assert state.finished
                del self.rid_to_state[rid]

        yield output_list

    def _validate_input_length(self, input_ids):
        if len(input_ids) >= self.context_len:
            raise ValueError(
                f"The input ({len(input_ids)} tokens) is longer than the "
                f"model's context length ({self.context_len} tokens)."
            )

    def _get_sampling_params(self, sampling_params_data, max_new_tokens=None):
        sampling_params = SamplingParams(**sampling_params_data)
        if max_new_tokens is not None:
            sampling_params.max_new_tokens = max_new_tokens
        if sampling_params.max_new_tokens != 0:
            sampling_params.normalize(self.tokenizer)
            sampling_params.verify()
        return sampling_params

    async def _get_pixel_values(self, image_data):
        if isinstance(image_data, list) and len(image_data) > 0:
            return await self.get_pixel_values(image_data[0])
        elif isinstance(image_data, str):
            return await self.get_pixel_values(image_data)
        else:
            return None, None, None

    async def _wait_for_response(self, event, state, obj, rid, request):
        while True:
            try:
                await asyncio.wait_for(event.wait(), timeout=4)
            except asyncio.TimeoutError:
                if request is not None and await request.is_disconnected():
                    self.abort_request(rid)
                    raise ValueError(f"Abort request {rid}")
                continue

            out = self.convert_logprob_style(
                state.out_list[-1],
                obj.return_logprob,
                obj.top_logprobs_num,
                obj.return_text_in_logprobs,
            )

            if self.server_args.log_requests and state.finished:
                logger.info(f"in={obj.text}, out={out}")

            state.out_list = []
            if state.finished:
                del self.rid_to_state[rid]
                yield out
                break

            event.clear()
            yield out

    async def _wait_for_prefill_response(self, event, state, obj, request, rid):
        while True:
            try:
                await asyncio.wait_for(state.event.wait(), timeout=4)
                break
            except asyncio.TimeoutError:
                if request is not None and await request.is_disconnected():
                    for rid in obj.rid:
                        self.abort_request(rid)
                    raise ValueError(f"Abort request {rid}")
                continue

        assert state.finished
        del self.rid_to_state[rid]

    def flush_cache(self):
        req = FlushCacheReq()
        ...
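Reviewer note: the index arithmetic in _handle_batch_request assumes obj.rid lays out the batch_size prefill rids first, followed by n sampling rids per prompt. A small standalone sketch of how the flat indices work out, with hypothetical values and not part of the commit:

# A batch of 2 prompts with n = 3 parallel samples each.
batch_size = 2
parallel_sample_num = 3 + 1  # +1 accounts for the prefill stage

indices = []
for i in range(batch_size):
    for j in range(parallel_sample_num):
        if j == 0:
            continue  # rids 0..batch_size-1 belong to the prefill requests
        index = i * parallel_sample_num + j
        index += batch_size - 1 - i
        indices.append(index)

# Prompt 0 uses rids 2, 3, 4 and prompt 1 uses rids 5, 6, 7;
# rids 0 and 1 were consumed by the prefill pass.
assert indices == [2, 3, 4, 5, 6, 7]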
python/sglang/srt/openai_api_adapter.py

@@ -95,9 +95,6 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
     request_json = await raw_request.json()
     request = CompletionRequest(**request_json)

-    if request.n != 1:
-        return create_error_response("n != 1 is not supported")
-
     adapted_request = GenerateReqInput(
         text=request.prompt,
         sampling_params={

@@ -108,6 +105,7 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
             "presence_penalty": request.presence_penalty,
             "frequency_penalty": request.frequency_penalty,
             "regex": request.regex,
+            "n": request.n,
         },
         return_logprob=request.logprobs is not None and request.logprobs > 0,
         top_logprobs_num=request.logprobs if request.logprobs is not None else 0,

@@ -202,46 +200,56 @@ async def v1_completions(tokenizer_manager, raw_request: Request):

The non-streaming response path previously unwrapped a single result (ret = ret[0] if isinstance(ret, list) else ret) and built one CompletionResponseChoice; it now builds one choice per returned sample and aggregates usage. After the change:

    except ValueError as e:
        return create_error_response(str(e))

    if not isinstance(ret, list):
        ret = [ret]
    choices = []

    for idx, ret_item in enumerate(ret):
        text = ret_item["text"]

        if request.echo:
            text = request.prompt + text

        if request.logprobs:
            if request.echo:
                prefill_token_logprobs = ret_item["meta_info"]["prefill_token_logprobs"]
                prefill_top_logprobs = ret_item["meta_info"]["prefill_top_logprobs"]
            else:
                prefill_token_logprobs = None
                prefill_top_logprobs = None

            logprobs = to_openai_style_logprobs(
                prefill_token_logprobs=prefill_token_logprobs,
                prefill_top_logprobs=prefill_top_logprobs,
                decode_token_logprobs=ret_item["meta_info"]["decode_token_logprobs"],
                decode_top_logprobs=ret_item["meta_info"]["decode_top_logprobs"],
            )
        else:
            logprobs = None

        choice_data = CompletionResponseChoice(
            index=idx,
            text=text,
            logprobs=logprobs,
            finish_reason=ret_item["meta_info"]["finish_reason"],
        )

        choices.append(choice_data)

    response = CompletionResponse(
        id=ret[0]["meta_info"]["id"],
        model=request.model,
        choices=choices,
        usage=UsageInfo(
            prompt_tokens=ret[0]["meta_info"]["prompt_tokens"],
            completion_tokens=sum(
                item["meta_info"]["completion_tokens"] for item in ret
            ),
            total_tokens=ret[0]["meta_info"]["prompt_tokens"]
            + sum(item["meta_info"]["completion_tokens"] for item in ret),
        ),
    )
    return response

@@ -249,9 +257,6 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
     request_json = await raw_request.json()
     request = ChatCompletionRequest(**request_json)

-    if request.n != 1:
-        return create_error_response("n != 1 is not supported")
-
     # Prep the data needed for the underlying GenerateReqInput:
     #  - prompt: The full prompt string.
     #  - stop: Custom stop tokens.

@@ -292,6 +297,7 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
             "presence_penalty": request.presence_penalty,
             "frequency_penalty": request.frequency_penalty,
             "regex": request.regex,
+            "n": request.n,
         },
         stream=request.stream,
     )

@@ -354,23 +360,37 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):

The chat response construction is reworked the same way, building one ChatCompletionResponseChoice per returned sample and accumulating token usage. After the change:

    except ValueError as e:
        return create_error_response(str(e))

    if not isinstance(ret, list):
        ret = [ret]
    choices = []
    total_prompt_tokens = 0
    total_completion_tokens = 0

    for idx, ret_item in enumerate(ret):
        prompt_tokens = ret_item["meta_info"]["prompt_tokens"]
        completion_tokens = ret_item["meta_info"]["completion_tokens"]

        choice_data = ChatCompletionResponseChoice(
            index=idx,
            message=ChatMessage(role="assistant", content=ret_item["text"]),
            finish_reason=ret_item["meta_info"]["finish_reason"],
        )

        choices.append(choice_data)
        total_prompt_tokens = prompt_tokens
        total_completion_tokens += completion_tokens

    response = ChatCompletionResponse(
        id=ret[0]["meta_info"]["id"],
        model=request.model,
        choices=choices,
        usage=UsageInfo(
            prompt_tokens=total_prompt_tokens,
            completion_tokens=total_completion_tokens,
            total_tokens=total_prompt_tokens + total_completion_tokens,
        ),
    )
    return response
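Reviewer note: a short sketch of the usage accounting used in v1_completions above, with made-up meta_info values and not part of the commit. The prompt is counted once while completion tokens are summed over the n returned items.

# Hypothetical per-sample results for a single prompt with n = 3.
ret = [
    {"text": " Paris.", "meta_info": {"prompt_tokens": 5, "completion_tokens": 3}},
    {"text": " Paris, of course.", "meta_info": {"prompt_tokens": 5, "completion_tokens": 6}},
    {"text": " Paris!", "meta_info": {"prompt_tokens": 5, "completion_tokens": 2}},
]

prompt_tokens = ret[0]["meta_info"]["prompt_tokens"]                             # 5, counted once
completion_tokens = sum(item["meta_info"]["completion_tokens"] for item in ret)  # 3 + 6 + 2 = 11
total_tokens = prompt_tokens + completion_tokens                                 # 16

assert (prompt_tokens, completion_tokens, total_tokens) == (5, 11, 16)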
python/sglang/srt/sampling_params.py

@@ -20,6 +20,7 @@ class SamplingParams:
         spaces_between_special_tokens: bool = True,
         dtype: Optional[str] = None,
         regex: Optional[str] = None,
+        n: int = 1,
     ) -> None:
         self.temperature = temperature
         self.top_p = top_p

@@ -33,6 +34,7 @@ class SamplingParams:
         self.spaces_between_special_tokens = spaces_between_special_tokens
         self.dtype = dtype
         self.regex = regex
+        self.n = n

         # Process some special cases
         if self.temperature < _SAMPLING_EPS:
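Reviewer note: a minimal sketch of the new field, assuming sglang is importable and that all constructor arguments other than those shown in the diff keep their defaults; not part of the commit.

from sglang.srt.sampling_params import SamplingParams

# n defaults to 1 and is stored alongside the other sampling options.
params = SamplingParams(temperature=0.8, n=3)
assert params.n == 3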