"src/array/cuda/vscode:/vscode.git/clone" did not exist on "0114f4fd79ce8552533b063b5a75ac9c2a3f9b54"
Unverified Commit dde8bb16 authored by Byron Hsu's avatar Byron Hsu Committed by GitHub
Browse files

default sampling param should be deepcopied (#1581)

parent 8ac3ccc0
......@@ -2,6 +2,7 @@
import asyncio
import contextvars
import copy
import multiprocessing
import queue
import threading
......@@ -652,7 +653,19 @@ class StreamExecutor:
self._init_var_event(e)
def _resolve_sampling_params(self, sampling_params):
clone = None
"""
Construct sampling param based on default + override values
The default values of sampling are populated in `default_sampling_para` via sgl.function.run(...sampling_args)
, and `sampling_params` contains the override values from sgl.gen().
Here we use default_sampling_para as the base and override the values if they exist in `sampling_params`.
It also extends the stop tokens based on the chat template.
"""
# deepcopy is required because the dict has lists inside
clone = copy.deepcopy(self.default_sampling_para)
for item in [
"max_new_tokens",
"stop",
......@@ -674,20 +687,16 @@ class StreamExecutor:
]:
value = getattr(sampling_params, item, None)
if value is not None:
if clone is None:
clone = self.default_sampling_para.clone()
setattr(clone, item, value)
if self.chat_template.stop_str:
if not clone:
clone = self.default_sampling_para.clone()
if clone.stop == ():
clone.stop = []
elif isinstance(clone.stop, str):
clone.stop = [clone.stop]
clone.stop += self.chat_template.stop_str
return clone or self.default_sampling_para
return clone
def __del__(self):
self.end()
......
......@@ -150,8 +150,8 @@ class SglFunction:
self,
*args,
max_new_tokens: int = 128,
stop: Union[str, List[str]] = [],
stop_token_ids: Optional[List[int]] = [],
stop: Union[str, List[str]] = None,
stop_token_ids: Optional[List[int]] = None,
temperature: float = 1.0,
top_p: float = 1.0,
top_k: int = -1,
......@@ -169,6 +169,12 @@ class SglFunction:
):
from sglang.lang.interpreter import run_program
# avoid using [] as the default arg: https://nikos7am.com/posts/mutable-default-arguments/
if stop is None:
stop = []
if stop_token_ids is None:
stop_token_ids = []
default_sampling_para = SglSamplingParams(
max_new_tokens=max_new_tokens,
stop=stop,
......@@ -193,8 +199,8 @@ class SglFunction:
batch_kwargs,
*,
max_new_tokens: int = 128,
stop: Union[str, List[str]] = (),
stop_token_ids: Optional[List[int]] = [],
stop: Union[str, List[str]] = None,
stop_token_ids: Optional[List[int]] = None,
temperature: float = 1.0,
top_p: float = 1.0,
top_k: int = -1,
......@@ -212,6 +218,11 @@ class SglFunction:
):
from sglang.lang.interpreter import run_program_batch
if stop is None:
stop = []
if stop_token_ids is None:
stop_token_ids = []
assert isinstance(batch_kwargs, (list, tuple))
if len(batch_kwargs) == 0:
return []
......
......@@ -26,7 +26,7 @@ class SamplingParams:
max_new_tokens: int = 128,
min_new_tokens: int = 0,
stop: Optional[Union[str, List[str]]] = None,
stop_token_ids: Optional[List[int]] = [],
stop_token_ids: Optional[List[int]] = None,
temperature: float = 1.0,
top_p: float = 1.0,
top_k: int = -1,
......@@ -41,6 +41,8 @@ class SamplingParams:
n: int = 1,
json_schema: Optional[str] = None,
) -> None:
if stop_token_ids is None:
stop_token_ids = []
self.temperature = temperature
self.top_p = top_p
self.top_k = top_k
......
......@@ -85,7 +85,7 @@ def call_generate_vllm(prompt, temperature, max_tokens, stop=None, n=1, url=None
def call_generate_outlines(
prompt, temperature, max_tokens, stop=[], regex=None, n=1, url=None
prompt, temperature, max_tokens, stop=None, regex=None, n=1, url=None
):
assert url is not None
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment