Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
dde8bb16
"src/array/cuda/vscode:/vscode.git/clone" did not exist on "0114f4fd79ce8552533b063b5a75ac9c2a3f9b54"
Unverified
Commit
dde8bb16
authored
Oct 05, 2024
by
Byron Hsu
Committed by
GitHub
Oct 05, 2024
Browse files
default sampling param should be deepcopied (#1581)
parent
8ac3ccc0
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
34 additions
and
12 deletions
+34
-12
python/sglang/lang/interpreter.py
python/sglang/lang/interpreter.py
+15
-6
python/sglang/lang/ir.py
python/sglang/lang/ir.py
+15
-4
python/sglang/srt/sampling/sampling_params.py
python/sglang/srt/sampling/sampling_params.py
+3
-1
python/sglang/test/test_utils.py
python/sglang/test/test_utils.py
+1
-1
No files found.
python/sglang/lang/interpreter.py
View file @
dde8bb16
...
...
@@ -2,6 +2,7 @@
import
asyncio
import
contextvars
import
copy
import
multiprocessing
import
queue
import
threading
...
...
@@ -652,7 +653,19 @@ class StreamExecutor:
self
.
_init_var_event
(
e
)
def
_resolve_sampling_params
(
self
,
sampling_params
):
clone
=
None
"""
Construct the sampling params from the default values plus any override values.
The default values are populated in `default_sampling_para` via sgl.function.run(...sampling_args),
and `sampling_params` contains the override values from sgl.gen().
Here we use `default_sampling_para` as the base and override each value that is set in `sampling_params`.
The stop tokens are also extended based on the chat template.
"""
# deepcopy is required because the default sampling params object holds lists (e.g. stop) that must not be shared
clone
=
copy
.
deepcopy
(
self
.
default_sampling_para
)
for
item
in
[
"max_new_tokens"
,
"stop"
,
...
...
@@ -674,20 +687,16 @@ class StreamExecutor:
]:
value
=
getattr
(
sampling_params
,
item
,
None
)
if
value
is
not
None
:
if
clone
is
None
:
clone
=
self
.
default_sampling_para
.
clone
()
setattr
(
clone
,
item
,
value
)
if
self
.
chat_template
.
stop_str
:
if
not
clone
:
clone
=
self
.
default_sampling_para
.
clone
()
if
clone
.
stop
==
():
clone
.
stop
=
[]
elif
isinstance
(
clone
.
stop
,
str
):
clone
.
stop
=
[
clone
.
stop
]
clone
.
stop
+=
self
.
chat_template
.
stop_str
return
clone
or
self
.
default_sampling_para
return
clone
def
__del__
(
self
):
self
.
end
()
...
...
python/sglang/lang/ir.py
View file @
dde8bb16
...
...
@@ -150,8 +150,8 @@ class SglFunction:
self
,
*
args
,
max_new_tokens
:
int
=
128
,
stop
:
Union
[
str
,
List
[
str
]]
=
[]
,
stop_token_ids
:
Optional
[
List
[
int
]]
=
[]
,
stop
:
Union
[
str
,
List
[
str
]]
=
None
,
stop_token_ids
:
Optional
[
List
[
int
]]
=
None
,
temperature
:
float
=
1.0
,
top_p
:
float
=
1.0
,
top_k
:
int
=
-
1
,
...
...
@@ -169,6 +169,12 @@ class SglFunction:
):
from
sglang.lang.interpreter
import
run_program
# avoid using [] as the default arg: https://nikos7am.com/posts/mutable-default-arguments/
if
stop
is
None
:
stop
=
[]
if
stop_token_ids
is
None
:
stop_token_ids
=
[]
default_sampling_para
=
SglSamplingParams
(
max_new_tokens
=
max_new_tokens
,
stop
=
stop
,
...
...
@@ -193,8 +199,8 @@ class SglFunction:
batch_kwargs
,
*
,
max_new_tokens
:
int
=
128
,
stop
:
Union
[
str
,
List
[
str
]]
=
()
,
stop_token_ids
:
Optional
[
List
[
int
]]
=
[]
,
stop
:
Union
[
str
,
List
[
str
]]
=
None
,
stop_token_ids
:
Optional
[
List
[
int
]]
=
None
,
temperature
:
float
=
1.0
,
top_p
:
float
=
1.0
,
top_k
:
int
=
-
1
,
...
...
@@ -212,6 +218,11 @@ class SglFunction:
):
from
sglang.lang.interpreter
import
run_program_batch
if
stop
is
None
:
stop
=
[]
if
stop_token_ids
is
None
:
stop_token_ids
=
[]
assert
isinstance
(
batch_kwargs
,
(
list
,
tuple
))
if
len
(
batch_kwargs
)
==
0
:
return
[]
...
...
python/sglang/srt/sampling/sampling_params.py
View file @
dde8bb16
...
...
@@ -26,7 +26,7 @@ class SamplingParams:
max_new_tokens
:
int
=
128
,
min_new_tokens
:
int
=
0
,
stop
:
Optional
[
Union
[
str
,
List
[
str
]]]
=
None
,
stop_token_ids
:
Optional
[
List
[
int
]]
=
[]
,
stop_token_ids
:
Optional
[
List
[
int
]]
=
None
,
temperature
:
float
=
1.0
,
top_p
:
float
=
1.0
,
top_k
:
int
=
-
1
,
...
...
@@ -41,6 +41,8 @@ class SamplingParams:
n
:
int
=
1
,
json_schema
:
Optional
[
str
]
=
None
,
)
->
None
:
if
stop_token_ids
is
None
:
stop_token_ids
=
[]
self
.
temperature
=
temperature
self
.
top_p
=
top_p
self
.
top_k
=
top_k
...
...
python/sglang/test/test_utils.py
View file @
dde8bb16
...
...
@@ -85,7 +85,7 @@ def call_generate_vllm(prompt, temperature, max_tokens, stop=None, n=1, url=None
def
call_generate_outlines
(
prompt
,
temperature
,
max_tokens
,
stop
=
[]
,
regex
=
None
,
n
=
1
,
url
=
None
prompt
,
temperature
,
max_tokens
,
stop
=
None
,
regex
=
None
,
n
=
1
,
url
=
None
):
assert
url
is
not
None
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment