change / sglang
Commit fd7926e4 (unverified), parent 399cad91
Authored Aug 05, 2024 by yichuan~, committed by GitHub on Aug 05, 2024

Fix prompt len in parallel sampling (#928)
Showing 2 changed files with 11 additions and 15 deletions:
- python/sglang/srt/openai_api/adapter.py (+11, -10)
- test/srt/test_openai_server.py (+0, -5)
python/sglang/srt/openai_api/adapter.py (+11, -10)

```diff
@@ -500,7 +500,9 @@ def v1_generate_response(request, ret, tokenizer_manager, to_file=False):
             responses.append(response)
         return responses
     else:
-        prompt_tokens = sum(item["meta_info"]["prompt_tokens"] for item in ret)
+        prompt_tokens = sum(
+            ret[i]["meta_info"]["prompt_tokens"] for i in range(0, len(ret), request.n)
+        )
         completion_tokens = sum(item["meta_info"]["completion_tokens"] for item in ret)
         response = CompletionResponse(
             id=ret[0]["meta_info"]["id"],
```
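Why the stride helps: with parallel sampling, `ret` holds `request.n` result items per prompt, and each item's `meta_info` repeats the same `prompt_tokens` value, so summing over every item counts the prompt `n` times. The snippet below is a minimal sketch of that arithmetic, not part of the commit; the mock `ret` and the concrete numbers are invented for illustration.

```python
# Hypothetical data: one prompt of 5 tokens sampled with n = 3 parallel
# completions, so the backend returns 3 items that all repeat prompt_tokens=5.
request_n = 3
ret = [
    {"meta_info": {"prompt_tokens": 5, "completion_tokens": 7}},
    {"meta_info": {"prompt_tokens": 5, "completion_tokens": 9}},
    {"meta_info": {"prompt_tokens": 5, "completion_tokens": 8}},
]

# Old aggregation: every parallel sample re-counts the prompt -> 15.
old_prompt_tokens = sum(item["meta_info"]["prompt_tokens"] for item in ret)

# New aggregation: take one item per request by striding with n -> 5.
new_prompt_tokens = sum(
    ret[i]["meta_info"]["prompt_tokens"] for i in range(0, len(ret), request_n)
)

# Completion tokens are still summed over all samples, since each sample
# generates its own completion -> 24.
completion_tokens = sum(item["meta_info"]["completion_tokens"] for item in ret)

assert (old_prompt_tokens, new_prompt_tokens, completion_tokens) == (15, 5, 24)
```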
```diff
@@ -707,8 +709,6 @@ def v1_chat_generate_request(all_requests, tokenizer_manager):
 
 def v1_chat_generate_response(request, ret, to_file=False):
     choices = []
-    total_prompt_tokens = 0
-    total_completion_tokens = 0
 
     for idx, ret_item in enumerate(ret):
         logprobs = False
```
```diff
@@ -747,8 +747,6 @@ def v1_chat_generate_response(request, ret, to_file=False):
             choice_logprobs = ChoiceLogprobs(content=token_logprobs)
         else:
             choice_logprobs = None
-        prompt_tokens = ret_item["meta_info"]["prompt_tokens"]
-        completion_tokens = ret_item["meta_info"]["completion_tokens"]
 
         if to_file:
             # to make the choice data json serializable
```
```diff
@@ -767,8 +765,7 @@ def v1_chat_generate_response(request, ret, to_file=False):
             )
 
         choices.append(choice_data)
-        total_prompt_tokens += prompt_tokens
-        total_completion_tokens += completion_tokens
+
     if to_file:
         responses = []
```
```diff
@@ -795,14 +792,18 @@ def v1_chat_generate_response(request, ret, to_file=False):
             responses.append(response)
         return responses
     else:
+        prompt_tokens = sum(
+            ret[i]["meta_info"]["prompt_tokens"] for i in range(0, len(ret), request.n)
+        )
+        completion_tokens = sum(item["meta_info"]["completion_tokens"] for item in ret)
         response = ChatCompletionResponse(
             id=ret[0]["meta_info"]["id"],
             model=request.model,
             choices=choices,
             usage=UsageInfo(
-                prompt_tokens=total_prompt_tokens,
-                completion_tokens=total_completion_tokens,
-                total_tokens=total_prompt_tokens + total_completion_tokens,
+                prompt_tokens=prompt_tokens,
+                completion_tokens=completion_tokens,
+                total_tokens=prompt_tokens + completion_tokens,
             ),
         )
         return response
```
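`v1_chat_generate_response` gets the same treatment: the per-choice accumulators are dropped and usage is computed once when the response is assembled. From a client's point of view the visible effect is the reported usage; below is a rough sketch of checking it with the `openai` client, assuming an sglang server is already running behind an OpenAI-compatible endpoint (the base_url, api_key, and model name are placeholders, not from the commit).

```python
# Placeholder endpoint and model; adjust for an actual running server.
from openai import OpenAI

client = OpenAI(base_url="http://127.0.0.1:30000/v1", api_key="EMPTY")

resp = client.chat.completions.create(
    model="default",
    messages=[{"role": "user", "content": "Say hi."}],
    n=2,            # two parallel samples of the same prompt
    max_tokens=16,
)

# Two choices come back, but after this fix usage.prompt_tokens should
# reflect the prompt once rather than once per choice.
assert len(resp.choices) == 2
print(resp.usage.prompt_tokens, resp.usage.completion_tokens)
```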
test/srt/test_openai_server.py (+0, -5)

```diff
@@ -45,11 +45,6 @@ class TestOpenAIServer(unittest.TestCase):
         prompt_arg = prompt_input
         num_choices = 1
 
-        if parallel_sample_num:
-            # FIXME: This is wrong. We should not count the prompt tokens multiple times for
-            # parallel sampling.
-            num_prompt_tokens *= parallel_sample_num
-
         response = client.completions.create(
             model=self.model,
             prompt=prompt_arg,
```
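The test change is the flip side of the fix: the expected prompt-token count no longer gets multiplied by `parallel_sample_num`. A standalone version of that expectation is sketched below; it assumes a running OpenAI-compatible sglang server, and the base_url and model name are placeholders rather than values from the test.

```python
# Sketch of the corrected expectation: reported prompt tokens do not grow with n.
from openai import OpenAI

client = OpenAI(base_url="http://127.0.0.1:30000/v1", api_key="EMPTY")
prompt = "The capital of France is"

# Prompt-token count for a single sample serves as the baseline.
baseline = client.completions.create(
    model="default", prompt=prompt, n=1, max_tokens=8
).usage.prompt_tokens

for n in (2, 4):
    usage = client.completions.create(
        model="default", prompt=prompt, n=n, max_tokens=8
    ).usage
    # Before this commit the server reported roughly baseline * n here.
    assert usage.prompt_tokens == baseline
```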