Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
b912de11
Unverified
Commit
b912de11
authored
Sep 12, 2024
by
Lianmin Zheng
Committed by
GitHub
Sep 12, 2024
Browse files
Make stop reason a dict instead of str (#1407)
parent
eb02c161
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
59 additions
and
43 deletions
+59
-43
python/sglang/srt/managers/schedule_batch.py
python/sglang/srt/managers/schedule_batch.py
+26
-15
python/sglang/srt/managers/tp_worker.py
python/sglang/srt/managers/tp_worker.py
+5
-1
python/sglang/srt/openai_api/adapter.py
python/sglang/srt/openai_api/adapter.py
+28
-27
No files found.
python/sglang/srt/managers/schedule_batch.py
View file @
b912de11
...
class BaseFinishReason:
    """Base class for request finish reasons.

    Subclasses describe *why* generation stopped and must implement
    ``to_json`` so the reason can be returned as a structured dict
    (matching the OpenAI API shape) instead of an ad-hoc string.
    """

    def __init__(self, is_error: bool = False):
        # Whether this finish reason represents an error condition
        # (e.g. an aborted request) rather than a normal stop.
        self.is_error = is_error

    def to_json(self):
        """Return a JSON-serializable dict describing this finish reason."""
        raise NotImplementedError("Subclasses must implement this method")
...
@@ -65,34 +65,45 @@ class FINISH_MATCHED_TOKEN(BaseFinishReason):
...
@@ -65,34 +65,45 @@ class FINISH_MATCHED_TOKEN(BaseFinishReason):
super
().
__init__
()
super
().
__init__
()
self
.
matched
=
matched
self
.
matched
=
matched
def
__str__
(
self
)
->
str
:
def
to_json
(
self
):
return
f
"FINISH_MATCHED_TOKEN:
{
self
.
matched
}
"
return
{
"type"
:
"stop"
,
# to match OpenAI API's return value
"matched"
:
self
.
matched
,
}
class FINISH_MATCHED_STR(BaseFinishReason):
    """Finish reason: generation hit a configured stop string."""

    def __init__(self, matched: str):
        super().__init__()
        # The stop string that terminated generation.
        self.matched = matched

    def to_json(self):
        """Return the structured finish reason."""
        return {
            "type": "stop",  # to match OpenAI API's return value
            "matched": self.matched,
        }
class FINISH_LENGTH(BaseFinishReason):
    """Finish reason: generation stopped after reaching the length limit."""

    def __init__(self, length: int):
        super().__init__()
        # The number of tokens generated when the limit was hit.
        self.length = length

    def to_json(self):
        """Return the structured finish reason."""
        return {
            "type": "length",  # to match OpenAI API's return value
            "length": self.length,
        }
class FINISH_ABORT(BaseFinishReason):
    """Finish reason: the request was aborted; treated as an error."""

    def __init__(self):
        # Aborts are flagged as errors so downstream code can distinguish
        # them from normal stops.
        super().__init__(is_error=True)

    def to_json(self):
        """Return the structured finish reason."""
        return {
            "type": "abort",
        }
class
Req
:
class
Req
:
...
...
python/sglang/srt/managers/tp_worker.py
View file @
b912de11
...
@@ -813,7 +813,11 @@ class ModelTpServer:
...
@@ -813,7 +813,11 @@ class ModelTpServer:
"prompt_tokens"
:
len
(
req
.
origin_input_ids
),
"prompt_tokens"
:
len
(
req
.
origin_input_ids
),
"completion_tokens"
:
len
(
req
.
output_ids
),
"completion_tokens"
:
len
(
req
.
output_ids
),
"completion_tokens_wo_jump_forward"
:
req
.
completion_tokens_wo_jump_forward
,
"completion_tokens_wo_jump_forward"
:
req
.
completion_tokens_wo_jump_forward
,
"finish_reason"
:
str
(
req
.
finished_reason
),
"finish_reason"
:
(
req
.
finished_reason
.
to_json
()
if
req
.
finished_reason
is
not
None
else
None
),
}
}
if
req
.
return_logprob
:
if
req
.
return_logprob
:
(
(
...
...
python/sglang/srt/openai_api/adapter.py
View file @
b912de11
...
@@ -95,19 +95,6 @@ file_id_storage: Dict[str, str] = {}
...
@@ -95,19 +95,6 @@ file_id_storage: Dict[str, str] = {}
storage_dir
=
None
storage_dir
=
None
def format_finish_reason(finish_reason) -> Optional[str]:
    """Map an internal finish-reason string to the OpenAI-style label.

    Args:
        finish_reason: The ``str(...)`` form of a finish reason, e.g.
            ``"FINISH_MATCHED_STR: ..."`` or ``"None"``.

    Returns:
        ``None`` when the request has not finished, otherwise one of
        ``"stop"``, ``"length"``, ``"abort"``, or ``"unknown"`` for any
        unrecognized prefix.
    """
    if finish_reason.startswith("None"):
        return None
    # Ordered prefix -> OpenAI-style label mapping.
    for prefix, label in (
        ("FINISH_MATCHED", "stop"),
        ("FINISH_LENGTH", "length"),
        ("FINISH_ABORT", "abort"),
    ):
        if finish_reason.startswith(prefix):
            return label
    return "unknown"
def
create_error_response
(
def
create_error_response
(
message
:
str
,
message
:
str
,
err_type
:
str
=
"BadRequestError"
,
err_type
:
str
=
"BadRequestError"
,
...
@@ -618,8 +605,10 @@ def v1_generate_response(request, ret, tokenizer_manager, to_file=False):
...
@@ -618,8 +605,10 @@ def v1_generate_response(request, ret, tokenizer_manager, to_file=False):
"index"
:
0
,
"index"
:
0
,
"text"
:
text
,
"text"
:
text
,
"logprobs"
:
logprobs
,
"logprobs"
:
logprobs
,
"finish_reason"
:
format_finish_reason
(
"finish_reason"
:
(
ret_item
[
"meta_info"
][
"finish_reason"
]
ret_item
[
"meta_info"
][
"finish_reason"
][
"type"
]
if
ret_item
[
"meta_info"
][
"finish_reason"
]
else
""
),
),
}
}
else
:
else
:
...
@@ -627,8 +616,10 @@ def v1_generate_response(request, ret, tokenizer_manager, to_file=False):
...
@@ -627,8 +616,10 @@ def v1_generate_response(request, ret, tokenizer_manager, to_file=False):
index
=
idx
,
index
=
idx
,
text
=
text
,
text
=
text
,
logprobs
=
logprobs
,
logprobs
=
logprobs
,
finish_reason
=
format_finish_reason
(
finish_reason
=
(
ret_item
[
"meta_info"
][
"finish_reason"
]
ret_item
[
"meta_info"
][
"finish_reason"
][
"type"
]
if
ret_item
[
"meta_info"
][
"finish_reason"
]
else
""
),
),
)
)
...
@@ -762,8 +753,10 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
...
@@ -762,8 +753,10 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
index
=
index
,
index
=
index
,
text
=
delta
,
text
=
delta
,
logprobs
=
logprobs
,
logprobs
=
logprobs
,
finish_reason
=
format_finish_reason
(
finish_reason
=
(
content
[
"meta_info"
][
"finish_reason"
]
content
[
"meta_info"
][
"finish_reason"
][
"type"
]
if
content
[
"meta_info"
][
"finish_reason"
]
else
""
),
),
)
)
chunk
=
CompletionStreamResponse
(
chunk
=
CompletionStreamResponse
(
...
@@ -999,8 +992,10 @@ def v1_chat_generate_response(request, ret, to_file=False):
...
@@ -999,8 +992,10 @@ def v1_chat_generate_response(request, ret, to_file=False):
"index"
:
0
,
"index"
:
0
,
"message"
:
{
"role"
:
"assistant"
,
"content"
:
ret_item
[
"text"
]},
"message"
:
{
"role"
:
"assistant"
,
"content"
:
ret_item
[
"text"
]},
"logprobs"
:
choice_logprobs
,
"logprobs"
:
choice_logprobs
,
"finish_reason"
:
format_finish_reason
(
"finish_reason"
:
(
ret_item
[
"meta_info"
][
"finish_reason"
]
ret_item
[
"meta_info"
][
"finish_reason"
][
"type"
]
if
ret_item
[
"meta_info"
][
"finish_reason"
]
else
""
),
),
}
}
else
:
else
:
...
@@ -1008,8 +1003,10 @@ def v1_chat_generate_response(request, ret, to_file=False):
...
@@ -1008,8 +1003,10 @@ def v1_chat_generate_response(request, ret, to_file=False):
index
=
idx
,
index
=
idx
,
message
=
ChatMessage
(
role
=
"assistant"
,
content
=
ret_item
[
"text"
]),
message
=
ChatMessage
(
role
=
"assistant"
,
content
=
ret_item
[
"text"
]),
logprobs
=
choice_logprobs
,
logprobs
=
choice_logprobs
,
finish_reason
=
format_finish_reason
(
finish_reason
=
(
ret_item
[
"meta_info"
][
"finish_reason"
]
ret_item
[
"meta_info"
][
"finish_reason"
][
"type"
]
if
ret_item
[
"meta_info"
][
"finish_reason"
]
else
""
),
),
)
)
...
@@ -1134,8 +1131,10 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
...
@@ -1134,8 +1131,10 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
choice_data
=
ChatCompletionResponseStreamChoice
(
choice_data
=
ChatCompletionResponseStreamChoice
(
index
=
index
,
index
=
index
,
delta
=
DeltaMessage
(
role
=
"assistant"
),
delta
=
DeltaMessage
(
role
=
"assistant"
),
finish_reason
=
format_finish_reason
(
finish_reason
=
(
content
[
"meta_info"
][
"finish_reason"
]
content
[
"meta_info"
][
"finish_reason"
][
"type"
]
if
content
[
"meta_info"
][
"finish_reason"
]
else
""
),
),
logprobs
=
choice_logprobs
,
logprobs
=
choice_logprobs
,
)
)
...
@@ -1152,8 +1151,10 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
...
@@ -1152,8 +1151,10 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
choice_data
=
ChatCompletionResponseStreamChoice
(
choice_data
=
ChatCompletionResponseStreamChoice
(
index
=
index
,
index
=
index
,
delta
=
DeltaMessage
(
content
=
delta
),
delta
=
DeltaMessage
(
content
=
delta
),
finish_reason
=
format_finish_reason
(
finish_reason
=
(
content
[
"meta_info"
][
"finish_reason"
]
content
[
"meta_info"
][
"finish_reason"
][
"type"
]
if
content
[
"meta_info"
][
"finish_reason"
]
else
""
),
),
logprobs
=
choice_logprobs
,
logprobs
=
choice_logprobs
,
)
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment