Commit dde8bb16 (unverified), authored by Byron Hsu, committed by GitHub
Browse files

default sampling param should be deepcopied (#1581)

parent 8ac3ccc0
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
import asyncio import asyncio
import contextvars import contextvars
import copy
import multiprocessing import multiprocessing
import queue import queue
import threading import threading
...@@ -652,7 +653,19 @@ class StreamExecutor: ...@@ -652,7 +653,19 @@ class StreamExecutor:
self._init_var_event(e) self._init_var_event(e)
def _resolve_sampling_params(self, sampling_params): def _resolve_sampling_params(self, sampling_params):
clone = None """
Construct sampling param based on default + override values
The default sampling values are populated in `default_sampling_para` via sgl.function.run(...sampling_args),
and `sampling_params` contains the override values from sgl.gen().
Here we use default_sampling_para as the base and override the values if they exist in `sampling_params`.
It also extends the stop tokens based on the chat template.
"""
# deepcopy is required because the dict has lists inside
clone = copy.deepcopy(self.default_sampling_para)
for item in [ for item in [
"max_new_tokens", "max_new_tokens",
"stop", "stop",
...@@ -674,20 +687,16 @@ class StreamExecutor: ...@@ -674,20 +687,16 @@ class StreamExecutor:
]: ]:
value = getattr(sampling_params, item, None) value = getattr(sampling_params, item, None)
if value is not None: if value is not None:
if clone is None:
clone = self.default_sampling_para.clone()
setattr(clone, item, value) setattr(clone, item, value)
if self.chat_template.stop_str: if self.chat_template.stop_str:
if not clone:
clone = self.default_sampling_para.clone()
if clone.stop == (): if clone.stop == ():
clone.stop = [] clone.stop = []
elif isinstance(clone.stop, str): elif isinstance(clone.stop, str):
clone.stop = [clone.stop] clone.stop = [clone.stop]
clone.stop += self.chat_template.stop_str clone.stop += self.chat_template.stop_str
return clone or self.default_sampling_para return clone
def __del__(self): def __del__(self):
self.end() self.end()
......
...@@ -150,8 +150,8 @@ class SglFunction: ...@@ -150,8 +150,8 @@ class SglFunction:
self, self,
*args, *args,
max_new_tokens: int = 128, max_new_tokens: int = 128,
stop: Union[str, List[str]] = [], stop: Union[str, List[str]] = None,
stop_token_ids: Optional[List[int]] = [], stop_token_ids: Optional[List[int]] = None,
temperature: float = 1.0, temperature: float = 1.0,
top_p: float = 1.0, top_p: float = 1.0,
top_k: int = -1, top_k: int = -1,
...@@ -169,6 +169,12 @@ class SglFunction: ...@@ -169,6 +169,12 @@ class SglFunction:
): ):
from sglang.lang.interpreter import run_program from sglang.lang.interpreter import run_program
# avoid using [] as the default arg: https://nikos7am.com/posts/mutable-default-arguments/
if stop is None:
stop = []
if stop_token_ids is None:
stop_token_ids = []
default_sampling_para = SglSamplingParams( default_sampling_para = SglSamplingParams(
max_new_tokens=max_new_tokens, max_new_tokens=max_new_tokens,
stop=stop, stop=stop,
...@@ -193,8 +199,8 @@ class SglFunction: ...@@ -193,8 +199,8 @@ class SglFunction:
batch_kwargs, batch_kwargs,
*, *,
max_new_tokens: int = 128, max_new_tokens: int = 128,
stop: Union[str, List[str]] = (), stop: Union[str, List[str]] = None,
stop_token_ids: Optional[List[int]] = [], stop_token_ids: Optional[List[int]] = None,
temperature: float = 1.0, temperature: float = 1.0,
top_p: float = 1.0, top_p: float = 1.0,
top_k: int = -1, top_k: int = -1,
...@@ -212,6 +218,11 @@ class SglFunction: ...@@ -212,6 +218,11 @@ class SglFunction:
): ):
from sglang.lang.interpreter import run_program_batch from sglang.lang.interpreter import run_program_batch
if stop is None:
stop = []
if stop_token_ids is None:
stop_token_ids = []
assert isinstance(batch_kwargs, (list, tuple)) assert isinstance(batch_kwargs, (list, tuple))
if len(batch_kwargs) == 0: if len(batch_kwargs) == 0:
return [] return []
......
...@@ -26,7 +26,7 @@ class SamplingParams: ...@@ -26,7 +26,7 @@ class SamplingParams:
max_new_tokens: int = 128, max_new_tokens: int = 128,
min_new_tokens: int = 0, min_new_tokens: int = 0,
stop: Optional[Union[str, List[str]]] = None, stop: Optional[Union[str, List[str]]] = None,
stop_token_ids: Optional[List[int]] = [], stop_token_ids: Optional[List[int]] = None,
temperature: float = 1.0, temperature: float = 1.0,
top_p: float = 1.0, top_p: float = 1.0,
top_k: int = -1, top_k: int = -1,
...@@ -41,6 +41,8 @@ class SamplingParams: ...@@ -41,6 +41,8 @@ class SamplingParams:
n: int = 1, n: int = 1,
json_schema: Optional[str] = None, json_schema: Optional[str] = None,
) -> None: ) -> None:
if stop_token_ids is None:
stop_token_ids = []
self.temperature = temperature self.temperature = temperature
self.top_p = top_p self.top_p = top_p
self.top_k = top_k self.top_k = top_k
......
...@@ -85,7 +85,7 @@ def call_generate_vllm(prompt, temperature, max_tokens, stop=None, n=1, url=None ...@@ -85,7 +85,7 @@ def call_generate_vllm(prompt, temperature, max_tokens, stop=None, n=1, url=None
def call_generate_outlines( def call_generate_outlines(
prompt, temperature, max_tokens, stop=[], regex=None, n=1, url=None prompt, temperature, max_tokens, stop=None, regex=None, n=1, url=None
): ):
assert url is not None assert url is not None
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment