Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
a34dd86a
Unverified
Commit
a34dd86a
authored
Aug 14, 2024
by
Liangsheng Yin
Committed by
GitHub
Aug 14, 2024
Browse files
Use `dtype` to control generate (#1082)
Co-authored-by:
zhyncs
<
me@zhyncs.com
>
parent
67c0d832
Changes
12
Hide whitespace changes
Inline
Side-by-side
Showing
12 changed files
with
110 additions
and
88 deletions
+110
-88
benchmark/json_decode_regex/bench_other.py
benchmark/json_decode_regex/bench_other.py
+4
-4
benchmark/json_decode_regex/bench_sglang.py
benchmark/json_decode_regex/bench_sglang.py
+6
-6
python/sglang/api.py
python/sglang/api.py
+1
-1
python/sglang/bench_latency.py
python/sglang/bench_latency.py
+1
-1
python/sglang/lang/backend/runtime_endpoint.py
python/sglang/lang/backend/runtime_endpoint.py
+60
-49
python/sglang/lang/ir.py
python/sglang/lang/ir.py
+3
-3
python/sglang/srt/managers/schedule_batch.py
python/sglang/srt/managers/schedule_batch.py
+3
-10
python/sglang/srt/managers/tp_worker.py
python/sglang/srt/managers/tp_worker.py
+1
-7
python/sglang/srt/models/mixtral.py
python/sglang/srt/models/mixtral.py
+1
-0
python/sglang/srt/sampling_params.py
python/sglang/srt/sampling_params.py
+0
-4
python/sglang/test/test_programs.py
python/sglang/test/test_programs.py
+26
-2
test/lang/test_srt_backend.py
test/lang/test_srt_backend.py
+4
-1
No files found.
benchmark/json_decode_regex/bench_other.py
View file @
a34dd86a
...
@@ -6,11 +6,11 @@ from functools import partial
...
@@ -6,11 +6,11 @@ from functools import partial
from
tqdm
import
tqdm
from
tqdm
import
tqdm
from
sglang.lang.ir
import
REGEX_FLOAT
,
REGEX_INT
,
REGEX_STR
ING
from
sglang.lang.ir
import
REGEX_FLOAT
,
REGEX_INT
,
REGEX_STR
from
sglang.test.test_utils
import
add_common_other_args_and_parse
,
get_call_generate
from
sglang.test.test_utils
import
add_common_other_args_and_parse
,
get_call_generate
from
sglang.utils
import
dump_state_text
,
read_jsonl
from
sglang.utils
import
dump_state_text
,
read_jsonl
REGEX_LIST
=
r
"\[("
+
REGEX_STR
ING
+
", )*"
+
REGEX_STR
ING
+
r
"\]"
REGEX_LIST
=
r
"\[("
+
REGEX_STR
+
", )*"
+
REGEX_STR
+
r
"\]"
# fmt: off
# fmt: off
...
@@ -20,9 +20,9 @@ def json_decode(document, generate):
...
@@ -20,9 +20,9 @@ def json_decode(document, generate):
s
+=
"Here is the name, country, and symbol of the city in JSON format.
\n
"
s
+=
"Here is the name, country, and symbol of the city in JSON format.
\n
"
s
+=
"{
\n
"
s
+=
"{
\n
"
s
+=
' "name": '
s
+=
' "name": '
s
+=
generate
(
s
,
max_tokens
=
8
,
regex
=
REGEX_STR
ING
+
","
)
+
"
\n
"
s
+=
generate
(
s
,
max_tokens
=
8
,
regex
=
REGEX_STR
+
","
)
+
"
\n
"
s
+=
' "country": '
s
+=
' "country": '
s
+=
generate
(
s
,
max_tokens
=
8
,
regex
=
REGEX_STR
ING
+
","
)
+
"
\n
"
s
+=
generate
(
s
,
max_tokens
=
8
,
regex
=
REGEX_STR
+
","
)
+
"
\n
"
s
+=
' "latitude": '
s
+=
' "latitude": '
s
+=
generate
(
s
,
max_tokens
=
8
,
regex
=
REGEX_FLOAT
+
","
)
+
"
\n
"
s
+=
generate
(
s
,
max_tokens
=
8
,
regex
=
REGEX_FLOAT
+
","
)
+
"
\n
"
s
+=
' "population": '
s
+=
' "population": '
...
...
benchmark/json_decode_regex/bench_sglang.py
View file @
a34dd86a
...
@@ -3,14 +3,14 @@ import json
...
@@ -3,14 +3,14 @@ import json
import
time
import
time
import
sglang
as
sgl
import
sglang
as
sgl
from
sglang.lang.ir
import
REGEX_FLOAT
,
REGEX_INT
,
REGEX_STR
ING
from
sglang.lang.ir
import
REGEX_FLOAT
,
REGEX_INT
,
REGEX_STR
from
sglang.test.test_utils
import
(
from
sglang.test.test_utils
import
(
add_common_sglang_args_and_parse
,
add_common_sglang_args_and_parse
,
select_sglang_backend
,
select_sglang_backend
,
)
)
from
sglang.utils
import
dump_state_text
,
read_jsonl
from
sglang.utils
import
dump_state_text
,
read_jsonl
REGEX_LIST
=
r
"\[("
+
REGEX_STR
ING
+
", )*"
+
REGEX_STR
ING
+
r
"\]"
REGEX_LIST
=
r
"\[("
+
REGEX_STR
+
", )*"
+
REGEX_STR
+
r
"\]"
# fmt: off
# fmt: off
@
sgl
.
function
@
sgl
.
function
...
@@ -18,8 +18,8 @@ def json_warm_up(s):
...
@@ -18,8 +18,8 @@ def json_warm_up(s):
s
+=
"The information about Hogwarts is in the following JSON format.
\n
"
s
+=
"The information about Hogwarts is in the following JSON format.
\n
"
with
s
.
var_scope
(
"json_output"
):
with
s
.
var_scope
(
"json_output"
):
s
+=
"{
\n
"
s
+=
"{
\n
"
s
+=
' "name": '
+
sgl
.
gen
(
"name"
,
max_tokens
=
8
,
regex
=
REGEX_STR
ING
+
","
)
+
"
\n
"
s
+=
' "name": '
+
sgl
.
gen
(
"name"
,
max_tokens
=
8
,
regex
=
REGEX_STR
+
","
)
+
"
\n
"
s
+=
' "country": '
+
sgl
.
gen
(
"country"
,
max_tokens
=
8
,
regex
=
REGEX_STR
ING
+
","
)
+
"
\n
"
s
+=
' "country": '
+
sgl
.
gen
(
"country"
,
max_tokens
=
8
,
regex
=
REGEX_STR
+
","
)
+
"
\n
"
s
+=
' "latitude": '
+
sgl
.
gen
(
"latitude"
,
max_tokens
=
8
,
regex
=
REGEX_FLOAT
+
","
)
+
"
\n
"
s
+=
' "latitude": '
+
sgl
.
gen
(
"latitude"
,
max_tokens
=
8
,
regex
=
REGEX_FLOAT
+
","
)
+
"
\n
"
s
+=
' "population": '
+
sgl
.
gen
(
"population"
,
max_tokens
=
8
,
regex
=
REGEX_INT
+
","
)
+
"
\n
"
s
+=
' "population": '
+
sgl
.
gen
(
"population"
,
max_tokens
=
8
,
regex
=
REGEX_INT
+
","
)
+
"
\n
"
s
+=
' "top 3 landmarks": '
+
sgl
.
gen
(
"landmarks"
,
max_tokens
=
24
,
regex
=
REGEX_LIST
)
+
"
\n
"
s
+=
' "top 3 landmarks": '
+
sgl
.
gen
(
"landmarks"
,
max_tokens
=
24
,
regex
=
REGEX_LIST
)
+
"
\n
"
...
@@ -35,8 +35,8 @@ def json_decode(s, document):
...
@@ -35,8 +35,8 @@ def json_decode(s, document):
s
+=
"Here is the name, country, and symbol of the city in JSON format.
\n
"
s
+=
"Here is the name, country, and symbol of the city in JSON format.
\n
"
with
s
.
var_scope
(
"json_output"
):
with
s
.
var_scope
(
"json_output"
):
s
+=
"{
\n
"
s
+=
"{
\n
"
s
+=
' "name": '
+
sgl
.
gen
(
"name"
,
max_tokens
=
8
,
regex
=
REGEX_STR
ING
+
","
)
+
"
\n
"
s
+=
' "name": '
+
sgl
.
gen
(
"name"
,
max_tokens
=
8
,
regex
=
REGEX_STR
+
","
)
+
"
\n
"
s
+=
' "country": '
+
sgl
.
gen
(
"country"
,
max_tokens
=
8
,
regex
=
REGEX_STR
ING
+
","
)
+
"
\n
"
s
+=
' "country": '
+
sgl
.
gen
(
"country"
,
max_tokens
=
8
,
regex
=
REGEX_STR
+
","
)
+
"
\n
"
s
+=
' "latitude": '
+
sgl
.
gen
(
"latitude"
,
max_tokens
=
8
,
regex
=
REGEX_FLOAT
+
","
)
+
"
\n
"
s
+=
' "latitude": '
+
sgl
.
gen
(
"latitude"
,
max_tokens
=
8
,
regex
=
REGEX_FLOAT
+
","
)
+
"
\n
"
s
+=
' "population": '
+
sgl
.
gen
(
"population"
,
max_tokens
=
8
,
regex
=
REGEX_INT
+
","
)
+
"
\n
"
s
+=
' "population": '
+
sgl
.
gen
(
"population"
,
max_tokens
=
8
,
regex
=
REGEX_INT
+
","
)
+
"
\n
"
s
+=
' "top 3 landmarks": '
+
sgl
.
gen
(
"landmarks"
,
max_tokens
=
24
,
regex
=
REGEX_LIST
)
+
"
\n
"
s
+=
' "top 3 landmarks": '
+
sgl
.
gen
(
"landmarks"
,
max_tokens
=
24
,
regex
=
REGEX_LIST
)
+
"
\n
"
...
...
python/sglang/api.py
View file @
a34dd86a
...
@@ -72,7 +72,7 @@ def gen(
...
@@ -72,7 +72,7 @@ def gen(
logprob_start_len
:
Optional
[
int
]
=
None
,
logprob_start_len
:
Optional
[
int
]
=
None
,
top_logprobs_num
:
Optional
[
int
]
=
None
,
top_logprobs_num
:
Optional
[
int
]
=
None
,
return_text_in_logprobs
:
Optional
[
bool
]
=
None
,
return_text_in_logprobs
:
Optional
[
bool
]
=
None
,
dtype
:
Optional
[
type
]
=
None
,
dtype
:
Optional
[
Union
[
type
,
str
]
]
=
None
,
choices
:
Optional
[
List
[
str
]]
=
None
,
choices
:
Optional
[
List
[
str
]]
=
None
,
choices_method
:
Optional
[
ChoicesSamplingMethod
]
=
None
,
choices_method
:
Optional
[
ChoicesSamplingMethod
]
=
None
,
regex
:
Optional
[
str
]
=
None
,
regex
:
Optional
[
str
]
=
None
,
...
...
python/sglang/bench_latency.py
View file @
a34dd86a
...
@@ -195,7 +195,7 @@ def extend(reqs, model_runner):
...
@@ -195,7 +195,7 @@ def extend(reqs, model_runner):
token_to_kv_pool
=
model_runner
.
token_to_kv_pool
,
token_to_kv_pool
=
model_runner
.
token_to_kv_pool
,
tree_cache
=
None
,
tree_cache
=
None
,
)
)
batch
.
prepare_for_extend
(
model_runner
.
model_config
.
vocab_size
,
None
)
batch
.
prepare_for_extend
(
model_runner
.
model_config
.
vocab_size
)
output
=
model_runner
.
forward
(
batch
,
ForwardMode
.
EXTEND
)
output
=
model_runner
.
forward
(
batch
,
ForwardMode
.
EXTEND
)
next_token_ids
=
batch
.
sample
(
output
.
next_token_logits
)
next_token_ids
=
batch
.
sample
(
output
.
next_token_logits
)
return
next_token_ids
,
output
.
next_token_logits
,
batch
return
next_token_ids
,
output
.
next_token_logits
,
batch
...
...
python/sglang/lang/backend/runtime_endpoint.py
View file @
a34dd86a
import
json
import
json
import
warnings
from
typing
import
List
,
Optional
from
typing
import
List
,
Optional
from
sglang.global_config
import
global_config
from
sglang.global_config
import
global_config
from
sglang.lang.backend.base_backend
import
BaseBackend
from
sglang.lang.backend.base_backend
import
BaseBackend
from
sglang.lang.chat_template
import
get_chat_template_by_model_path
from
sglang.lang.chat_template
import
get_chat_template_by_model_path
from
sglang.lang.choices
import
(
from
sglang.lang.choices
import
ChoicesDecision
,
ChoicesSamplingMethod
ChoicesDecision
,
ChoicesSamplingMethod
,
token_length_normalized
,
)
from
sglang.lang.interpreter
import
StreamExecutor
from
sglang.lang.interpreter
import
StreamExecutor
from
sglang.lang.ir
import
SglSamplingParams
from
sglang.lang.ir
import
(
REGEX_BOOL
,
REGEX_FLOAT
,
REGEX_INT
,
REGEX_STR
,
SglSamplingParams
,
)
from
sglang.utils
import
http_request
from
sglang.utils
import
http_request
class
RuntimeEndpoint
(
BaseBackend
):
class
RuntimeEndpoint
(
BaseBackend
):
def
__init__
(
def
__init__
(
self
,
self
,
base_url
:
str
,
base_url
:
str
,
...
@@ -95,32 +97,52 @@ class RuntimeEndpoint(BaseBackend):
...
@@ -95,32 +97,52 @@ class RuntimeEndpoint(BaseBackend):
)
)
self
.
_assert_success
(
res
)
self
.
_assert_success
(
res
)
def
_handle_dtype_to_regex
(
self
,
sampling_params
:
SglSamplingParams
):
if
sampling_params
.
dtype
is
None
:
return
if
sampling_params
.
stop
==
():
sampling_params
.
stop
=
[]
dtype_regex
=
None
if
sampling_params
.
dtype
in
[
"int"
,
int
]:
dtype_regex
=
REGEX_INT
sampling_params
.
stop
.
extend
([
" "
,
"
\n
"
])
elif
sampling_params
.
dtype
in
[
"float"
,
float
]:
dtype_regex
=
REGEX_FLOAT
sampling_params
.
stop
.
extend
([
" "
,
"
\n
"
])
elif
sampling_params
.
dtype
in
[
"str"
,
str
]:
dtype_regex
=
REGEX_STR
elif
sampling_params
.
dtype
in
[
"bool"
,
bool
]:
dtype_regex
=
REGEX_BOOL
else
:
raise
RuntimeError
(
f
"Invalid dtype:
{
sampling_params
.
dtype
}
"
)
if
dtype_regex
is
not
None
and
sampling_params
.
regex
is
not
None
:
warnings
.
warn
(
f
"Both dtype and regex are set. Only dtype will be used. dtype:
{
sampling_params
.
dtype
}
, regex:
{
sampling_params
.
regex
}
"
)
sampling_params
.
regex
=
dtype_regex
def
generate
(
def
generate
(
self
,
self
,
s
:
StreamExecutor
,
s
:
StreamExecutor
,
sampling_params
:
SglSamplingParams
,
sampling_params
:
SglSamplingParams
,
):
):
if
sampling_params
.
dtype
is
None
:
self
.
_handle_dtype_to_regex
(
sampling_params
)
data
=
{
data
=
{
"text"
:
s
.
text_
,
"text"
:
s
.
text_
,
"sampling_params"
:
{
"sampling_params"
:
{
"skip_special_tokens"
:
global_config
.
skip_special_tokens_in_output
,
"skip_special_tokens"
:
global_config
.
skip_special_tokens_in_output
,
"spaces_between_special_tokens"
:
global_config
.
spaces_between_special_tokens_in_out
,
"spaces_between_special_tokens"
:
global_config
.
spaces_between_special_tokens_in_out
,
**
sampling_params
.
to_srt_kwargs
(),
**
sampling_params
.
to_srt_kwargs
(),
},
},
}
}
elif
sampling_params
.
dtype
in
[
int
,
"int"
]:
data
=
{
"text"
:
s
.
text_
,
"sampling_params"
:
{
"skip_special_tokens"
:
global_config
.
skip_special_tokens_in_output
,
"spaces_between_special_tokens"
:
global_config
.
spaces_between_special_tokens_in_out
,
"dtype"
:
"int"
,
**
sampling_params
.
to_srt_kwargs
(),
},
}
else
:
raise
RuntimeError
(
f
"Invalid dtype:
{
sampling_params
.
dtype
}
"
)
for
item
in
[
for
item
in
[
"return_logprob"
,
"return_logprob"
,
...
@@ -151,27 +173,16 @@ class RuntimeEndpoint(BaseBackend):
...
@@ -151,27 +173,16 @@ class RuntimeEndpoint(BaseBackend):
s
:
StreamExecutor
,
s
:
StreamExecutor
,
sampling_params
:
SglSamplingParams
,
sampling_params
:
SglSamplingParams
,
):
):
if
sampling_params
.
dtype
is
None
:
self
.
_handle_dtype_to_regex
(
sampling_params
)
data
=
{
"text"
:
s
.
text_
,
data
=
{
"sampling_params"
:
{
"text"
:
s
.
text_
,
"skip_special_tokens"
:
global_config
.
skip_special_tokens_in_output
,
"sampling_params"
:
{
"spaces_between_special_tokens"
:
global_config
.
spaces_between_special_tokens_in_out
,
"skip_special_tokens"
:
global_config
.
skip_special_tokens_in_output
,
**
sampling_params
.
to_srt_kwargs
(),
"spaces_between_special_tokens"
:
global_config
.
spaces_between_special_tokens_in_out
,
},
**
sampling_params
.
to_srt_kwargs
(),
}
},
elif
sampling_params
.
dtype
in
[
int
,
"int"
]:
}
data
=
{
"text"
:
s
.
text_
,
"sampling_params"
:
{
"skip_special_tokens"
:
global_config
.
skip_special_tokens_in_output
,
"spaces_between_special_tokens"
:
global_config
.
spaces_between_special_tokens_in_out
,
"dtype"
:
"int"
,
**
sampling_params
.
to_srt_kwargs
(),
},
}
else
:
raise
RuntimeError
(
f
"Invalid dtype:
{
sampling_params
.
dtype
}
"
)
for
item
in
[
for
item
in
[
"return_logprob"
,
"return_logprob"
,
...
...
python/sglang/lang/ir.py
View file @
a34dd86a
...
@@ -8,10 +8,10 @@ from typing import List, Optional, Union
...
@@ -8,10 +8,10 @@ from typing import List, Optional, Union
from
sglang.global_config
import
global_config
from
sglang.global_config
import
global_config
from
sglang.lang.choices
import
ChoicesSamplingMethod
from
sglang.lang.choices
import
ChoicesSamplingMethod
REGEX_INT
=
r
"[-+]?[0-9]+"
REGEX_INT
=
r
"[-+]?[0-9]+
[ \n]*
"
REGEX_FLOAT
=
r
"[-+]?[0-9]*\.?[0-9]+"
REGEX_FLOAT
=
r
"[-+]?[0-9]*\.?[0-9]+
[ \n]*
"
REGEX_BOOL
=
r
"(True|False)"
REGEX_BOOL
=
r
"(True|False)"
REGEX_STR
ING
=
r
"\"[\w\d\s]*\""
# bugs with regex r"\".*\"" in interegular pkg
REGEX_STR
=
r
"\"[\w\d\s]*\""
# bugs with regex r"\".*\"" in interegular pkg
@
dataclasses
.
dataclass
@
dataclasses
.
dataclass
...
...
python/sglang/srt/managers/schedule_batch.py
View file @
a34dd86a
...
@@ -383,7 +383,7 @@ class ScheduleBatch:
...
@@ -383,7 +383,7 @@ class ScheduleBatch:
return
out_cache_loc
return
out_cache_loc
def
batch_sampling_params
(
self
,
vocab_size
,
int_token_logit_bias
):
def
batch_sampling_params
(
self
,
vocab_size
):
device
=
"cuda"
device
=
"cuda"
bs
,
reqs
=
self
.
batch_size
(),
self
.
reqs
bs
,
reqs
=
self
.
batch_size
(),
self
.
reqs
self
.
temperatures
=
torch
.
tensor
(
self
.
temperatures
=
torch
.
tensor
(
...
@@ -419,15 +419,8 @@ class ScheduleBatch:
...
@@ -419,15 +419,8 @@ class ScheduleBatch:
# Handle logit bias but only allocate when needed
# Handle logit bias but only allocate when needed
self
.
logit_bias
=
None
self
.
logit_bias
=
None
for
i
in
range
(
bs
):
if
reqs
[
i
].
sampling_params
.
dtype
==
"int"
:
if
self
.
logit_bias
is
None
:
self
.
logit_bias
=
torch
.
zeros
(
(
bs
,
vocab_size
),
dtype
=
torch
.
float32
,
device
=
device
)
self
.
logit_bias
[
i
][:
len
(
int_token_logit_bias
)]
=
int_token_logit_bias
def
prepare_for_extend
(
self
,
vocab_size
:
int
,
int_token_logit_bias
:
torch
.
Tensor
):
def
prepare_for_extend
(
self
,
vocab_size
:
int
):
bs
=
self
.
batch_size
()
bs
=
self
.
batch_size
()
reqs
=
self
.
reqs
reqs
=
self
.
reqs
input_ids
=
[
r
.
fill_ids
[
len
(
r
.
prefix_indices
)
:]
for
r
in
reqs
]
input_ids
=
[
r
.
fill_ids
[
len
(
r
.
prefix_indices
)
:]
for
r
in
reqs
]
...
@@ -466,7 +459,7 @@ class ScheduleBatch:
...
@@ -466,7 +459,7 @@ class ScheduleBatch:
self
.
out_cache_loc
=
out_cache_loc
self
.
out_cache_loc
=
out_cache_loc
self
.
top_logprobs_nums
=
[
r
.
top_logprobs_num
for
r
in
reqs
]
self
.
top_logprobs_nums
=
[
r
.
top_logprobs_num
for
r
in
reqs
]
self
.
batch_sampling_params
(
vocab_size
,
int_token_logit_bias
)
self
.
batch_sampling_params
(
vocab_size
)
def
check_decode_mem
(
self
):
def
check_decode_mem
(
self
):
bs
=
self
.
batch_size
()
bs
=
self
.
batch_size
()
...
...
python/sglang/srt/managers/tp_worker.py
View file @
a34dd86a
...
@@ -54,7 +54,6 @@ from sglang.srt.model_executor.forward_batch_info import ForwardMode
...
@@ -54,7 +54,6 @@ from sglang.srt.model_executor.forward_batch_info import ForwardMode
from
sglang.srt.model_executor.model_runner
import
ModelRunner
from
sglang.srt.model_executor.model_runner
import
ModelRunner
from
sglang.srt.server_args
import
ServerArgs
from
sglang.srt.server_args
import
ServerArgs
from
sglang.srt.utils
import
(
from
sglang.srt.utils
import
(
get_int_token_logit_bias
,
is_multimodal_model
,
is_multimodal_model
,
set_random_seed
,
set_random_seed
,
suppress_other_loggers
,
suppress_other_loggers
,
...
@@ -132,9 +131,6 @@ class ModelTpServer:
...
@@ -132,9 +131,6 @@ class ModelTpServer:
),
),
self
.
model_runner
.
req_to_token_pool
.
size
-
1
,
self
.
model_runner
.
req_to_token_pool
.
size
-
1
,
)
)
self
.
int_token_logit_bias
=
torch
.
tensor
(
get_int_token_logit_bias
(
self
.
tokenizer
,
self
.
model_config
.
vocab_size
)
)
self
.
max_req_input_len
=
min
(
self
.
max_req_input_len
=
min
(
self
.
model_config
.
context_len
-
1
,
self
.
model_config
.
context_len
-
1
,
self
.
max_total_num_tokens
-
1
,
self
.
max_total_num_tokens
-
1
,
...
@@ -442,9 +438,7 @@ class ModelTpServer:
...
@@ -442,9 +438,7 @@ class ModelTpServer:
def
forward_prefill_batch
(
self
,
batch
:
ScheduleBatch
):
def
forward_prefill_batch
(
self
,
batch
:
ScheduleBatch
):
# Build batch tensors
# Build batch tensors
batch
.
prepare_for_extend
(
batch
.
prepare_for_extend
(
self
.
model_config
.
vocab_size
)
self
.
model_config
.
vocab_size
,
self
.
int_token_logit_bias
)
if
self
.
model_runner
.
is_generation
:
if
self
.
model_runner
.
is_generation
:
# Forward and sample the next tokens
# Forward and sample the next tokens
...
...
python/sglang/srt/models/mixtral.py
View file @
a34dd86a
...
@@ -32,6 +32,7 @@ from vllm.model_executor.layers.linear import (
...
@@ -32,6 +32,7 @@ from vllm.model_executor.layers.linear import (
from
vllm.model_executor.layers.quantization.base_config
import
QuantizationConfig
from
vllm.model_executor.layers.quantization.base_config
import
QuantizationConfig
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
DEFAULT_VOCAB_PADDING_SIZE
,
ParallelLMHead
,
ParallelLMHead
,
VocabParallelEmbedding
,
VocabParallelEmbedding
,
)
)
...
...
python/sglang/srt/sampling_params.py
View file @
a34dd86a
...
@@ -36,7 +36,6 @@ class SamplingParams:
...
@@ -36,7 +36,6 @@ class SamplingParams:
ignore_eos
:
bool
=
False
,
ignore_eos
:
bool
=
False
,
skip_special_tokens
:
bool
=
True
,
skip_special_tokens
:
bool
=
True
,
spaces_between_special_tokens
:
bool
=
True
,
spaces_between_special_tokens
:
bool
=
True
,
dtype
:
Optional
[
str
]
=
None
,
regex
:
Optional
[
str
]
=
None
,
regex
:
Optional
[
str
]
=
None
,
n
:
int
=
1
,
n
:
int
=
1
,
)
->
None
:
)
->
None
:
...
@@ -53,7 +52,6 @@ class SamplingParams:
...
@@ -53,7 +52,6 @@ class SamplingParams:
self
.
ignore_eos
=
ignore_eos
self
.
ignore_eos
=
ignore_eos
self
.
skip_special_tokens
=
skip_special_tokens
self
.
skip_special_tokens
=
skip_special_tokens
self
.
spaces_between_special_tokens
=
spaces_between_special_tokens
self
.
spaces_between_special_tokens
=
spaces_between_special_tokens
self
.
dtype
=
dtype
self
.
regex
=
regex
self
.
regex
=
regex
self
.
n
=
n
self
.
n
=
n
...
@@ -63,8 +61,6 @@ class SamplingParams:
...
@@ -63,8 +61,6 @@ class SamplingParams:
self
.
top_k
=
1
self
.
top_k
=
1
if
self
.
top_k
==
-
1
:
if
self
.
top_k
==
-
1
:
self
.
top_k
=
1
<<
30
# whole vocabulary
self
.
top_k
=
1
<<
30
# whole vocabulary
if
self
.
dtype
==
"int"
:
self
.
stop_strs
=
[
" "
,
"
\n
"
]
def
verify
(
self
):
def
verify
(
self
):
if
self
.
temperature
<
0.0
:
if
self
.
temperature
<
0.0
:
...
...
python/sglang/test/test_programs.py
View file @
a34dd86a
...
@@ -103,13 +103,13 @@ def test_decode_int():
...
@@ -103,13 +103,13 @@ def test_decode_int():
def
test_decode_json_regex
():
def
test_decode_json_regex
():
@
sgl
.
function
@
sgl
.
function
def
decode_json
(
s
):
def
decode_json
(
s
):
from
sglang.lang.ir
import
REGEX_FLOAT
,
REGEX_INT
,
REGEX_STR
ING
from
sglang.lang.ir
import
REGEX_FLOAT
,
REGEX_INT
,
REGEX_STR
s
+=
"Generate a JSON object to describe the basic city information of Paris.
\n
"
s
+=
"Generate a JSON object to describe the basic city information of Paris.
\n
"
with
s
.
var_scope
(
"json_output"
):
with
s
.
var_scope
(
"json_output"
):
s
+=
"{
\n
"
s
+=
"{
\n
"
s
+=
' "name": '
+
sgl
.
gen
(
regex
=
REGEX_STR
ING
+
","
)
+
"
\n
"
s
+=
' "name": '
+
sgl
.
gen
(
regex
=
REGEX_STR
+
","
)
+
"
\n
"
s
+=
' "population": '
+
sgl
.
gen
(
regex
=
REGEX_INT
+
","
)
+
"
\n
"
s
+=
' "population": '
+
sgl
.
gen
(
regex
=
REGEX_INT
+
","
)
+
"
\n
"
s
+=
' "area": '
+
sgl
.
gen
(
regex
=
REGEX_INT
+
","
)
+
"
\n
"
s
+=
' "area": '
+
sgl
.
gen
(
regex
=
REGEX_INT
+
","
)
+
"
\n
"
s
+=
' "latitude": '
+
sgl
.
gen
(
regex
=
REGEX_FLOAT
)
+
"
\n
"
s
+=
' "latitude": '
+
sgl
.
gen
(
regex
=
REGEX_FLOAT
)
+
"
\n
"
...
@@ -359,6 +359,30 @@ def test_regex():
...
@@ -359,6 +359,30 @@ def test_regex():
assert
re
.
match
(
regex
,
answer
)
assert
re
.
match
(
regex
,
answer
)
def
test_dtype_gen
():
@
sgl
.
function
def
dtype_gen
(
s
):
s
+=
"Q: What is the full name of DNS?
\n
"
s
+=
"A: The full nams is "
+
sgl
.
gen
(
"str_res"
,
dtype
=
str
,
stop
=
"
\n
"
)
+
"
\n
"
s
+=
"Q: Which year was DNS invented?
\n
"
s
+=
"A: "
+
sgl
.
gen
(
"int_res"
,
dtype
=
int
)
+
"
\n
"
s
+=
"Q: What is the value of pi?
\n
"
s
+=
"A: "
+
sgl
.
gen
(
"float_res"
,
dtype
=
float
)
+
"
\n
"
s
+=
"Q: Is the sky blue?
\n
"
s
+=
"A: "
+
sgl
.
gen
(
"bool_res"
,
dtype
=
bool
)
+
"
\n
"
state
=
dtype_gen
.
run
()
try
:
state
[
"int_res"
]
=
int
(
state
[
"int_res"
])
state
[
"float_res"
]
=
float
(
state
[
"float_res"
])
state
[
"bool_res"
]
=
bool
(
state
[
"bool_res"
])
# assert state["str_res"].startswith('"') and state["str_res"].endswith('"')
except
ValueError
:
print
(
state
)
raise
def
test_completion_speculative
():
def
test_completion_speculative
():
@
sgl
.
function
(
num_api_spec_tokens
=
64
)
@
sgl
.
function
(
num_api_spec_tokens
=
64
)
def
gen_character_spec
(
s
):
def
gen_character_spec
(
s
):
...
...
test/lang/test_srt_backend.py
View file @
a34dd86a
import
json
import
unittest
import
unittest
import
sglang
as
sgl
import
sglang
as
sgl
from
sglang.test.test_programs
import
(
from
sglang.test.test_programs
import
(
test_decode_int
,
test_decode_int
,
test_decode_json_regex
,
test_decode_json_regex
,
test_dtype_gen
,
test_expert_answer
,
test_expert_answer
,
test_few_shot_qa
,
test_few_shot_qa
,
test_mt_bench
,
test_mt_bench
,
...
@@ -59,6 +59,9 @@ class TestSRTBackend(unittest.TestCase):
...
@@ -59,6 +59,9 @@ class TestSRTBackend(unittest.TestCase):
def
test_regex
(
self
):
def
test_regex
(
self
):
test_regex
()
test_regex
()
def
test_dtype_gen
(
self
):
test_dtype_gen
()
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
unittest
.
main
()
unittest
.
main
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment