sglang · Commit b912de11 (Unverified)

Authored Sep 12, 2024 by Lianmin Zheng; committed by GitHub on Sep 12, 2024

Make stop reason a dict instead of str (#1407)

Parent: eb02c161
Showing 3 changed files with 59 additions and 43 deletions (+59 −43):

  python/sglang/srt/managers/schedule_batch.py  +26 −15
  python/sglang/srt/managers/tp_worker.py        +5  −1
  python/sglang/srt/openai_api/adapter.py       +28 −27
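Summary: req.finished_reason used to cross the manager/adapter boundary as the
str() of a finish-reason object, and the OpenAI adapter recovered the API-level
category by prefix-matching that string. This commit replaces __str__ with a
to_json() method returning a structured dict, and the adapter reads its "type"
field directly. A minimal before/after sketch (the matched token id is an
arbitrary illustrative value, not from the commit):

    # Before: the finish reason is flattened to a string, then parsed back.
    old = "FINISH_MATCHED_TOKEN: 2"        # str(req.finished_reason)
    category = "stop" if old.startswith("FINISH_MATCHED") else "unknown"

    # After: the finish reason stays structured end to end.
    new = {"type": "stop", "matched": 2}   # req.finished_reason.to_json()
    category = new["type"]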
python/sglang/srt/managers/schedule_batch.py

@@ -56,7 +56,7 @@ class BaseFinishReason:
     def __init__(self, is_error: bool = False):
         self.is_error = is_error
 
-    def __str__(self):
+    def to_json(self):
         raise NotImplementedError("Subclasses must implement this method")
 
 
@@ -65,34 +65,45 @@ class FINISH_MATCHED_TOKEN(BaseFinishReason):
         super().__init__()
         self.matched = matched
 
-    def __str__(self) -> str:
-        return f"FINISH_MATCHED_TOKEN: {self.matched}"
+    def to_json(self):
+        return {
+            "type": "stop",  # to match OpenAI API's return value
+            "matched": self.matched,
+        }
 
 
-class FINISH_LENGTH(BaseFinishReason):
-    def __init__(self, length: int):
+class FINISH_MATCHED_STR(BaseFinishReason):
+    def __init__(self, matched: str):
         super().__init__()
-        self.length = length
+        self.matched = matched
 
-    def __str__(self) -> str:
-        return f"FINISH_LENGTH: {self.length}"
+    def to_json(self):
+        return {
+            "type": "stop",  # to match OpenAI API's return value
+            "matched": self.matched,
+        }
 
 
-class FINISH_MATCHED_STR(BaseFinishReason):
-    def __init__(self, matched: str):
+class FINISH_LENGTH(BaseFinishReason):
+    def __init__(self, length: int):
         super().__init__()
-        self.matched = matched
+        self.length = length
 
-    def __str__(self) -> str:
-        return f"FINISH_MATCHED_STR: {self.matched}"
+    def to_json(self):
+        return {
+            "type": "length",  # to match OpenAI API's return value
+            "length": self.length,
+        }
 
 
 class FINISH_ABORT(BaseFinishReason):
     def __init__(self):
         super().__init__(is_error=True)
 
-    def __str__(self) -> str:
-        return "FINISH_ABORT"
+    def to_json(self):
+        return {
+            "type": "abort",
+        }
 
 
 class Req:
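Taken together, the four subclasses now serialize to dicts whose "type" field
carries the category the adapter previously reconstructed by string matching
("stop", "length", "abort"). A quick sanity sketch, assuming the classes are
importable from the module above (the literal values are arbitrary examples,
not from the commit):

    from sglang.srt.managers.schedule_batch import (
        FINISH_ABORT,
        FINISH_LENGTH,
        FINISH_MATCHED_STR,
        FINISH_MATCHED_TOKEN,
    )

    assert FINISH_MATCHED_TOKEN(2).to_json() == {"type": "stop", "matched": 2}
    assert FINISH_MATCHED_STR("###").to_json() == {"type": "stop", "matched": "###"}
    assert FINISH_LENGTH(16).to_json() == {"type": "length", "length": 16}
    assert FINISH_ABORT().to_json() == {"type": "abort"}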
python/sglang/srt/managers/tp_worker.py

@@ -813,7 +813,11 @@ class ModelTpServer:
                     "prompt_tokens": len(req.origin_input_ids),
                     "completion_tokens": len(req.output_ids),
                     "completion_tokens_wo_jump_forward": req.completion_tokens_wo_jump_forward,
-                    "finish_reason": str(req.finished_reason),
+                    "finish_reason": (
+                        req.finished_reason.to_json()
+                        if req.finished_reason is not None
+                        else None
+                    ),
                 }
                 if req.return_logprob:
                     (
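With this change, an unfinished request reports finish_reason as a real None
(JSON null) rather than the literal string "None" that str(req.finished_reason)
used to produce, which is what the removed startswith("None") branch in the
adapter had to paper over. A sketch of the meta_info payload now emitted (field
values are illustrative, not from the commit):

    import json

    meta_info = {
        "prompt_tokens": 32,
        "completion_tokens": 7,
        "completion_tokens_wo_jump_forward": 7,
        # A dict such as {"type": "stop", "matched": 2} once the request
        # finishes, or None while it is still streaming.
        "finish_reason": None,
    }
    print(json.dumps(meta_info["finish_reason"]))  # -> null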
python/sglang/srt/openai_api/adapter.py

@@ -95,19 +95,6 @@ file_id_storage: Dict[str, str] = {}
 storage_dir = None
 
 
-def format_finish_reason(finish_reason) -> Optional[str]:
-    if finish_reason.startswith("None"):
-        return None
-    elif finish_reason.startswith("FINISH_MATCHED"):
-        return "stop"
-    elif finish_reason.startswith("FINISH_LENGTH"):
-        return "length"
-    elif finish_reason.startswith("FINISH_ABORT"):
-        return "abort"
-    else:
-        return "unknown"
-
-
 def create_error_response(
     message: str,
     err_type: str = "BadRequestError",
@@ -618,8 +605,10 @@ def v1_generate_response(request, ret, tokenizer_manager, to_file=False):
                 "index": 0,
                 "text": text,
                 "logprobs": logprobs,
-                "finish_reason": format_finish_reason(
-                    ret_item["meta_info"]["finish_reason"]
-                ),
+                "finish_reason": (
+                    ret_item["meta_info"]["finish_reason"]["type"]
+                    if ret_item["meta_info"]["finish_reason"]
+                    else ""
+                ),
             }
         else:
@@ -627,8 +616,10 @@ def v1_generate_response(request, ret, tokenizer_manager, to_file=False):
                 index=idx,
                 text=text,
                 logprobs=logprobs,
-                finish_reason=format_finish_reason(
-                    ret_item["meta_info"]["finish_reason"]
-                ),
+                finish_reason=(
+                    ret_item["meta_info"]["finish_reason"]["type"]
+                    if ret_item["meta_info"]["finish_reason"]
+                    else ""
+                ),
             )
@@ -762,8 +753,10 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
                     index=index,
                     text=delta,
                     logprobs=logprobs,
-                    finish_reason=format_finish_reason(
-                        content["meta_info"]["finish_reason"]
-                    ),
+                    finish_reason=(
+                        content["meta_info"]["finish_reason"]["type"]
+                        if content["meta_info"]["finish_reason"]
+                        else ""
+                    ),
                 )
                 chunk = CompletionStreamResponse(
@@ -999,8 +992,10 @@ def v1_chat_generate_response(request, ret, to_file=False):
                 "index": 0,
                 "message": {"role": "assistant", "content": ret_item["text"]},
                 "logprobs": choice_logprobs,
-                "finish_reason": format_finish_reason(
-                    ret_item["meta_info"]["finish_reason"]
-                ),
+                "finish_reason": (
+                    ret_item["meta_info"]["finish_reason"]["type"]
+                    if ret_item["meta_info"]["finish_reason"]
+                    else ""
+                ),
             }
         else:
@@ -1008,8 +1003,10 @@ def v1_chat_generate_response(request, ret, to_file=False):
                 index=idx,
                 message=ChatMessage(role="assistant", content=ret_item["text"]),
                 logprobs=choice_logprobs,
-                finish_reason=format_finish_reason(
-                    ret_item["meta_info"]["finish_reason"]
-                ),
+                finish_reason=(
+                    ret_item["meta_info"]["finish_reason"]["type"]
+                    if ret_item["meta_info"]["finish_reason"]
+                    else ""
+                ),
             )
@@ -1134,8 +1131,10 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
                 choice_data = ChatCompletionResponseStreamChoice(
                     index=index,
                     delta=DeltaMessage(role="assistant"),
-                    finish_reason=format_finish_reason(
-                        content["meta_info"]["finish_reason"]
-                    ),
+                    finish_reason=(
+                        content["meta_info"]["finish_reason"]["type"]
+                        if content["meta_info"]["finish_reason"]
+                        else ""
+                    ),
                     logprobs=choice_logprobs,
                 )
@@ -1152,8 +1151,10 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
                 choice_data = ChatCompletionResponseStreamChoice(
                     index=index,
                     delta=DeltaMessage(content=delta),
-                    finish_reason=format_finish_reason(
-                        content["meta_info"]["finish_reason"]
-                    ),
+                    finish_reason=(
+                        content["meta_info"]["finish_reason"]["type"]
+                        if content["meta_info"]["finish_reason"]
+                        else ""
+                    ),
                     logprobs=choice_logprobs,
                 )
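All seven adapter call sites now inline the same guarded lookup:
finish_reason["type"] when the dict is present, "" otherwise. A hypothetical
helper capturing that pattern (not part of this commit; the adapter repeats
the expression instead):

    from typing import Any, Dict

    def finish_reason_type(meta_info: Dict[str, Any]) -> str:
        # Hypothetical convenience wrapper around the inlined expression.
        finish_reason = meta_info["finish_reason"]
        return finish_reason["type"] if finish_reason else ""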