Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
2422de51
Unverified
Commit
2422de51
authored
Oct 05, 2024
by
Byron Hsu
Committed by
GitHub
Oct 05, 2024
Browse files
Support min_tokens in sgl.gen (#1573)
parent
521f862d
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
82 additions
and
0 deletions
+82
-0
examples/frontend_language/usage/sgl_gen_min_tokens.py
examples/frontend_language/usage/sgl_gen_min_tokens.py
+35
-0
python/sglang/api.py
python/sglang/api.py
+4
-0
python/sglang/lang/interpreter.py
python/sglang/lang/interpreter.py
+1
-0
python/sglang/lang/ir.py
python/sglang/lang/ir.py
+5
-0
python/sglang/test/test_programs.py
python/sglang/test/test_programs.py
+33
-0
test/lang/test_srt_backend.py
test/lang/test_srt_backend.py
+4
-0
No files found.
examples/frontend_language/usage/sgl_gen_min_tokens.py
0 → 100644
View file @
2422de51
"""
This example demonstrates how to use `min_tokens` to enforce sgl.gen to generate a longer sequence
Usage:
python3 sgl_gen_min_tokens.py
"""
import
sglang
as
sgl
@sgl.function
def long_answer(sess):
    """Ask for the US capital while forcing at least 64 generated tokens.

    `min_tokens=64` makes the backend keep decoding past its natural stop,
    so the answer comes out longer than it otherwise would; `max_tokens=128`
    still caps the generation.
    """
    sess += sgl.user("What is the capital of the United States?")
    sess += sgl.assistant(sgl.gen("answer", min_tokens=64, max_tokens=128))
@sgl.function
def short_answer(sess):
    """Ask the same question with no length constraint, as a baseline.

    Without `min_tokens` the model is free to stop early, so the answer is
    expected to be short — contrast with `long_answer` above.
    """
    sess += sgl.user("What is the capital of the United States?")
    sess += sgl.assistant(sgl.gen("answer"))
if __name__ == "__main__":
    # Spin up a local runtime and register it as the default backend so the
    # @sgl.function programs above can run without an explicit backend arg.
    runtime = sgl.Runtime(model_path="meta-llama/Meta-Llama-3.1-8B-Instruct")
    sgl.set_default_backend(runtime)

    separator = "=" * 20

    # With min_tokens: the answer is padded out to >= 64 tokens.
    state = long_answer.run()
    print(separator)
    print("Longer Answer", state["answer"])

    # Without min_tokens: the model may stop as early as it likes.
    state = short_answer.run()
    print(separator)
    print("Short Answer", state["answer"])

    runtime.shutdown()
python/sglang/api.py
View file @
2422de51
...
@@ -69,6 +69,7 @@ def get_server_args(backend: Optional[BaseBackend] = None):
...
@@ -69,6 +69,7 @@ def get_server_args(backend: Optional[BaseBackend] = None):
def
gen
(
def
gen
(
name
:
Optional
[
str
]
=
None
,
name
:
Optional
[
str
]
=
None
,
max_tokens
:
Optional
[
int
]
=
None
,
max_tokens
:
Optional
[
int
]
=
None
,
min_tokens
:
Optional
[
int
]
=
None
,
stop
:
Optional
[
Union
[
str
,
List
[
str
]]]
=
None
,
stop
:
Optional
[
Union
[
str
,
List
[
str
]]]
=
None
,
stop_token_ids
:
Optional
[
List
[
int
]]
=
None
,
stop_token_ids
:
Optional
[
List
[
int
]]
=
None
,
temperature
:
Optional
[
float
]
=
None
,
temperature
:
Optional
[
float
]
=
None
,
...
@@ -108,6 +109,7 @@ def gen(
...
@@ -108,6 +109,7 @@ def gen(
return
SglGen
(
return
SglGen
(
name
,
name
,
max_tokens
,
max_tokens
,
min_tokens
,
stop
,
stop
,
stop_token_ids
,
stop_token_ids
,
temperature
,
temperature
,
...
@@ -147,6 +149,7 @@ def gen_int(
...
@@ -147,6 +149,7 @@ def gen_int(
return
SglGen
(
return
SglGen
(
name
,
name
,
max_tokens
,
max_tokens
,
None
,
stop
,
stop
,
stop_token_ids
,
stop_token_ids
,
temperature
,
temperature
,
...
@@ -185,6 +188,7 @@ def gen_string(
...
@@ -185,6 +188,7 @@ def gen_string(
return
SglGen
(
return
SglGen
(
name
,
name
,
max_tokens
,
max_tokens
,
None
,
stop
,
stop
,
stop_token_ids
,
stop_token_ids
,
temperature
,
temperature
,
...
...
python/sglang/lang/interpreter.py
View file @
2422de51
...
@@ -668,6 +668,7 @@ class StreamExecutor:
...
@@ -668,6 +668,7 @@ class StreamExecutor:
for
item
in
[
for
item
in
[
"max_new_tokens"
,
"max_new_tokens"
,
"min_new_tokens"
,
"stop"
,
"stop"
,
"stop_token_ids"
,
"stop_token_ids"
,
"temperature"
,
"temperature"
,
...
...
python/sglang/lang/ir.py
View file @
2422de51
...
@@ -17,6 +17,7 @@ REGEX_STR = r"\"[\w\d\s]*\"" # bugs with regex r"\".*\"" in interegular pkg
...
@@ -17,6 +17,7 @@ REGEX_STR = r"\"[\w\d\s]*\"" # bugs with regex r"\".*\"" in interegular pkg
@
dataclasses
.
dataclass
@
dataclasses
.
dataclass
class
SglSamplingParams
:
class
SglSamplingParams
:
max_new_tokens
:
int
=
128
max_new_tokens
:
int
=
128
min_new_tokens
:
int
=
0
stop
:
Union
[
str
,
List
[
str
]]
=
()
stop
:
Union
[
str
,
List
[
str
]]
=
()
stop_token_ids
:
Optional
[
List
[
int
]]
=
()
stop_token_ids
:
Optional
[
List
[
int
]]
=
()
temperature
:
float
=
1.0
temperature
:
float
=
1.0
...
@@ -39,6 +40,7 @@ class SglSamplingParams:
...
@@ -39,6 +40,7 @@ class SglSamplingParams:
def
clone
(
self
):
def
clone
(
self
):
return
SglSamplingParams
(
return
SglSamplingParams
(
self
.
max_new_tokens
,
self
.
max_new_tokens
,
self
.
min_new_tokens
,
self
.
stop
,
self
.
stop
,
self
.
stop_token_ids
,
self
.
stop_token_ids
,
self
.
temperature
,
self
.
temperature
,
...
@@ -113,6 +115,7 @@ class SglSamplingParams:
...
@@ -113,6 +115,7 @@ class SglSamplingParams:
def
to_srt_kwargs
(
self
):
def
to_srt_kwargs
(
self
):
return
{
return
{
"max_new_tokens"
:
self
.
max_new_tokens
,
"max_new_tokens"
:
self
.
max_new_tokens
,
"min_new_tokens"
:
self
.
min_new_tokens
,
"stop"
:
self
.
stop
,
"stop"
:
self
.
stop
,
"stop_token_ids"
:
self
.
stop_token_ids
,
"stop_token_ids"
:
self
.
stop_token_ids
,
"temperature"
:
self
.
temperature
,
"temperature"
:
self
.
temperature
,
...
@@ -424,6 +427,7 @@ class SglGen(SglExpr):
...
@@ -424,6 +427,7 @@ class SglGen(SglExpr):
self
,
self
,
name
:
Optional
[
str
]
=
None
,
name
:
Optional
[
str
]
=
None
,
max_new_tokens
:
Optional
[
int
]
=
None
,
max_new_tokens
:
Optional
[
int
]
=
None
,
min_new_tokens
:
Optional
[
int
]
=
None
,
stop
:
Optional
[
Union
[
str
,
List
[
str
]]]
=
None
,
stop
:
Optional
[
Union
[
str
,
List
[
str
]]]
=
None
,
stop_token_ids
:
Optional
[
List
[
int
]]
=
None
,
stop_token_ids
:
Optional
[
List
[
int
]]
=
None
,
temperature
:
Optional
[
float
]
=
None
,
temperature
:
Optional
[
float
]
=
None
,
...
@@ -446,6 +450,7 @@ class SglGen(SglExpr):
...
@@ -446,6 +450,7 @@ class SglGen(SglExpr):
self
.
name
=
name
self
.
name
=
name
self
.
sampling_params
=
SglSamplingParams
(
self
.
sampling_params
=
SglSamplingParams
(
max_new_tokens
=
max_new_tokens
,
max_new_tokens
=
max_new_tokens
,
min_new_tokens
=
min_new_tokens
,
stop
=
stop
,
stop
=
stop
,
stop_token_ids
=
stop_token_ids
,
stop_token_ids
=
stop_token_ids
,
temperature
=
temperature
,
temperature
=
temperature
,
...
...
python/sglang/test/test_programs.py
View file @
2422de51
...
@@ -517,3 +517,36 @@ def test_hellaswag_select():
...
@@ -517,3 +517,36 @@ def test_hellaswag_select():
accuracy
=
np
.
mean
(
np
.
array
(
preds
)
==
np
.
array
(
labels
))
accuracy
=
np
.
mean
(
np
.
array
(
preds
)
==
np
.
array
(
labels
))
return
accuracy
,
latency
return
accuracy
,
latency
def test_gen_min_new_tokens():
    """
    Validate sgl.gen(min_tokens) functionality.

    The prompt is a question whose natural answer is short; by passing
    min_tokens to sgl.gen we force the backend to keep generating, and we
    check that the tokenized answer length meets the requested minimum.
    """
    import sglang as sgl
    from sglang.srt.hf_transformers_utils import get_tokenizer

    MIN_TOKENS, MAX_TOKENS = 64, 128

    # Tokenize with the same model the default backend is serving, so the
    # token count we measure matches what the server enforced.
    model_path = sgl.global_config.default_backend.endpoint.get_model_name()
    tokenizer = get_tokenizer(model_path)

    @sgl.function
    def convo_1(s):
        s += sgl.user("What is the capital of the United States?")
        s += sgl.assistant(
            sgl.gen("answer", min_tokens=MIN_TOKENS, max_tokens=MAX_TOKENS)
        )

    state = convo_1.run()
    text = state["answer"]
    token_ids = tokenizer.encode(text)
    assert (
        len(token_ids) >= MIN_TOKENS
    ), f"Generated {len(token_ids)} tokens, min required: {MIN_TOKENS}. Text: {text}"
test/lang/test_srt_backend.py
View file @
2422de51
...
@@ -7,6 +7,7 @@ from sglang.test.test_programs import (
...
@@ -7,6 +7,7 @@ from sglang.test.test_programs import (
test_dtype_gen
,
test_dtype_gen
,
test_expert_answer
,
test_expert_answer
,
test_few_shot_qa
,
test_few_shot_qa
,
test_gen_min_new_tokens
,
test_hellaswag_select
,
test_hellaswag_select
,
test_mt_bench
,
test_mt_bench
,
test_parallel_decoding
,
test_parallel_decoding
,
...
@@ -69,6 +70,9 @@ class TestSRTBackend(unittest.TestCase):
...
@@ -69,6 +70,9 @@ class TestSRTBackend(unittest.TestCase):
accuracy
,
latency
=
test_hellaswag_select
()
accuracy
,
latency
=
test_hellaswag_select
()
assert
accuracy
>
0.71
,
f
"
{
accuracy
=
}
"
assert
accuracy
>
0.71
,
f
"
{
accuracy
=
}
"
    def test_gen_min_new_tokens(self):
        # Delegates to the shared implementation imported from
        # sglang.test.test_programs; the backend is set up by the class-level
        # fixtures (not visible in this view).
        test_gen_min_new_tokens()
# Standard unittest entry point: discover and run the TestSRTBackend cases
# when this file is executed directly.
if __name__ == "__main__":
    unittest.main()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment