Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
dde8bb16
Unverified
Commit
dde8bb16
authored
Oct 05, 2024
by
Byron Hsu
Committed by
GitHub
Oct 05, 2024
Browse files
default sampling param should be deepcopied (#1581)
parent
8ac3ccc0
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
34 additions
and
12 deletions
+34
-12
python/sglang/lang/interpreter.py
python/sglang/lang/interpreter.py
+15
-6
python/sglang/lang/ir.py
python/sglang/lang/ir.py
+15
-4
python/sglang/srt/sampling/sampling_params.py
python/sglang/srt/sampling/sampling_params.py
+3
-1
python/sglang/test/test_utils.py
python/sglang/test/test_utils.py
+1
-1
No files found.
python/sglang/lang/interpreter.py
View file @
dde8bb16
...
@@ -2,6 +2,7 @@
...
@@ -2,6 +2,7 @@
import
asyncio
import
asyncio
import
contextvars
import
contextvars
import
copy
import
multiprocessing
import
multiprocessing
import
queue
import
queue
import
threading
import
threading
...
@@ -652,7 +653,19 @@ class StreamExecutor:
...
@@ -652,7 +653,19 @@ class StreamExecutor:
self
.
_init_var_event
(
e
)
self
.
_init_var_event
(
e
)
def
_resolve_sampling_params
(
self
,
sampling_params
):
def
_resolve_sampling_params
(
self
,
sampling_params
):
clone
=
None
"""
Construct sampling param based on default + override values
The default values of sampling are populated in `default_sampling_para` via sgl.function.run(...sampling_args)
, and `sampling_params` contains the override values from sgl.gen().
Here we use default_sampling_para as the base and override the values if they exist in `sampling_params`.
It also extends the stop tokens based on the chat template.
"""
# deepcopy is required because the dict has lists inside
clone
=
copy
.
deepcopy
(
self
.
default_sampling_para
)
for
item
in
[
for
item
in
[
"max_new_tokens"
,
"max_new_tokens"
,
"stop"
,
"stop"
,
...
@@ -674,20 +687,16 @@ class StreamExecutor:
...
@@ -674,20 +687,16 @@ class StreamExecutor:
]:
]:
value
=
getattr
(
sampling_params
,
item
,
None
)
value
=
getattr
(
sampling_params
,
item
,
None
)
if
value
is
not
None
:
if
value
is
not
None
:
if
clone
is
None
:
clone
=
self
.
default_sampling_para
.
clone
()
setattr
(
clone
,
item
,
value
)
setattr
(
clone
,
item
,
value
)
if
self
.
chat_template
.
stop_str
:
if
self
.
chat_template
.
stop_str
:
if
not
clone
:
clone
=
self
.
default_sampling_para
.
clone
()
if
clone
.
stop
==
():
if
clone
.
stop
==
():
clone
.
stop
=
[]
clone
.
stop
=
[]
elif
isinstance
(
clone
.
stop
,
str
):
elif
isinstance
(
clone
.
stop
,
str
):
clone
.
stop
=
[
clone
.
stop
]
clone
.
stop
=
[
clone
.
stop
]
clone
.
stop
+=
self
.
chat_template
.
stop_str
clone
.
stop
+=
self
.
chat_template
.
stop_str
return
clone
or
self
.
default_sampling_para
return
clone
def
__del__
(
self
):
def
__del__
(
self
):
self
.
end
()
self
.
end
()
...
...
python/sglang/lang/ir.py
View file @
dde8bb16
...
@@ -150,8 +150,8 @@ class SglFunction:
...
@@ -150,8 +150,8 @@ class SglFunction:
self
,
self
,
*
args
,
*
args
,
max_new_tokens
:
int
=
128
,
max_new_tokens
:
int
=
128
,
stop
:
Union
[
str
,
List
[
str
]]
=
[]
,
stop
:
Union
[
str
,
List
[
str
]]
=
None
,
stop_token_ids
:
Optional
[
List
[
int
]]
=
[]
,
stop_token_ids
:
Optional
[
List
[
int
]]
=
None
,
temperature
:
float
=
1.0
,
temperature
:
float
=
1.0
,
top_p
:
float
=
1.0
,
top_p
:
float
=
1.0
,
top_k
:
int
=
-
1
,
top_k
:
int
=
-
1
,
...
@@ -169,6 +169,12 @@ class SglFunction:
...
@@ -169,6 +169,12 @@ class SglFunction:
):
):
from
sglang.lang.interpreter
import
run_program
from
sglang.lang.interpreter
import
run_program
# avoid using [] as the default arg: https://nikos7am.com/posts/mutable-default-arguments/
if
stop
is
None
:
stop
=
[]
if
stop_token_ids
is
None
:
stop_token_ids
=
[]
default_sampling_para
=
SglSamplingParams
(
default_sampling_para
=
SglSamplingParams
(
max_new_tokens
=
max_new_tokens
,
max_new_tokens
=
max_new_tokens
,
stop
=
stop
,
stop
=
stop
,
...
@@ -193,8 +199,8 @@ class SglFunction:
...
@@ -193,8 +199,8 @@ class SglFunction:
batch_kwargs
,
batch_kwargs
,
*
,
*
,
max_new_tokens
:
int
=
128
,
max_new_tokens
:
int
=
128
,
stop
:
Union
[
str
,
List
[
str
]]
=
()
,
stop
:
Union
[
str
,
List
[
str
]]
=
None
,
stop_token_ids
:
Optional
[
List
[
int
]]
=
[]
,
stop_token_ids
:
Optional
[
List
[
int
]]
=
None
,
temperature
:
float
=
1.0
,
temperature
:
float
=
1.0
,
top_p
:
float
=
1.0
,
top_p
:
float
=
1.0
,
top_k
:
int
=
-
1
,
top_k
:
int
=
-
1
,
...
@@ -212,6 +218,11 @@ class SglFunction:
...
@@ -212,6 +218,11 @@ class SglFunction:
):
):
from
sglang.lang.interpreter
import
run_program_batch
from
sglang.lang.interpreter
import
run_program_batch
if
stop
is
None
:
stop
=
[]
if
stop_token_ids
is
None
:
stop_token_ids
=
[]
assert
isinstance
(
batch_kwargs
,
(
list
,
tuple
))
assert
isinstance
(
batch_kwargs
,
(
list
,
tuple
))
if
len
(
batch_kwargs
)
==
0
:
if
len
(
batch_kwargs
)
==
0
:
return
[]
return
[]
...
...
python/sglang/srt/sampling/sampling_params.py
View file @
dde8bb16
...
@@ -26,7 +26,7 @@ class SamplingParams:
...
@@ -26,7 +26,7 @@ class SamplingParams:
max_new_tokens
:
int
=
128
,
max_new_tokens
:
int
=
128
,
min_new_tokens
:
int
=
0
,
min_new_tokens
:
int
=
0
,
stop
:
Optional
[
Union
[
str
,
List
[
str
]]]
=
None
,
stop
:
Optional
[
Union
[
str
,
List
[
str
]]]
=
None
,
stop_token_ids
:
Optional
[
List
[
int
]]
=
[]
,
stop_token_ids
:
Optional
[
List
[
int
]]
=
None
,
temperature
:
float
=
1.0
,
temperature
:
float
=
1.0
,
top_p
:
float
=
1.0
,
top_p
:
float
=
1.0
,
top_k
:
int
=
-
1
,
top_k
:
int
=
-
1
,
...
@@ -41,6 +41,8 @@ class SamplingParams:
...
@@ -41,6 +41,8 @@ class SamplingParams:
n
:
int
=
1
,
n
:
int
=
1
,
json_schema
:
Optional
[
str
]
=
None
,
json_schema
:
Optional
[
str
]
=
None
,
)
->
None
:
)
->
None
:
if
stop_token_ids
is
None
:
stop_token_ids
=
[]
self
.
temperature
=
temperature
self
.
temperature
=
temperature
self
.
top_p
=
top_p
self
.
top_p
=
top_p
self
.
top_k
=
top_k
self
.
top_k
=
top_k
...
...
python/sglang/test/test_utils.py
View file @
dde8bb16
...
@@ -85,7 +85,7 @@ def call_generate_vllm(prompt, temperature, max_tokens, stop=None, n=1, url=None
...
@@ -85,7 +85,7 @@ def call_generate_vllm(prompt, temperature, max_tokens, stop=None, n=1, url=None
def
call_generate_outlines
(
def
call_generate_outlines
(
prompt
,
temperature
,
max_tokens
,
stop
=
[]
,
regex
=
None
,
n
=
1
,
url
=
None
prompt
,
temperature
,
max_tokens
,
stop
=
None
,
regex
=
None
,
n
=
1
,
url
=
None
):
):
assert
url
is
not
None
assert
url
is
not
None
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment