Unverified commit a34dd86a, authored by Liangsheng Yin, committed by GitHub
Browse files

Use `dtype` to control generate (#1082)


Co-authored-by: zhyncs <me@zhyncs.com>
parent 67c0d832
...@@ -6,11 +6,11 @@ from functools import partial ...@@ -6,11 +6,11 @@ from functools import partial
from tqdm import tqdm from tqdm import tqdm
from sglang.lang.ir import REGEX_FLOAT, REGEX_INT, REGEX_STRING from sglang.lang.ir import REGEX_FLOAT, REGEX_INT, REGEX_STR
from sglang.test.test_utils import add_common_other_args_and_parse, get_call_generate from sglang.test.test_utils import add_common_other_args_and_parse, get_call_generate
from sglang.utils import dump_state_text, read_jsonl from sglang.utils import dump_state_text, read_jsonl
REGEX_LIST = r"\[(" + REGEX_STRING + ", )*" + REGEX_STRING + r"\]" REGEX_LIST = r"\[(" + REGEX_STR + ", )*" + REGEX_STR + r"\]"
# fmt: off # fmt: off
...@@ -20,9 +20,9 @@ def json_decode(document, generate): ...@@ -20,9 +20,9 @@ def json_decode(document, generate):
s += "Here is the name, country, and symbol of the city in JSON format.\n" s += "Here is the name, country, and symbol of the city in JSON format.\n"
s += "{\n" s += "{\n"
s += ' "name": ' s += ' "name": '
s += generate(s, max_tokens=8, regex=REGEX_STRING + ",") + "\n" s += generate(s, max_tokens=8, regex=REGEX_STR + ",") + "\n"
s += ' "country": ' s += ' "country": '
s += generate(s, max_tokens=8, regex=REGEX_STRING + ",") + "\n" s += generate(s, max_tokens=8, regex=REGEX_STR + ",") + "\n"
s += ' "latitude": ' s += ' "latitude": '
s += generate(s, max_tokens=8, regex=REGEX_FLOAT + ",") + "\n" s += generate(s, max_tokens=8, regex=REGEX_FLOAT + ",") + "\n"
s += ' "population": ' s += ' "population": '
......
...@@ -3,14 +3,14 @@ import json ...@@ -3,14 +3,14 @@ import json
import time import time
import sglang as sgl import sglang as sgl
from sglang.lang.ir import REGEX_FLOAT, REGEX_INT, REGEX_STRING from sglang.lang.ir import REGEX_FLOAT, REGEX_INT, REGEX_STR
from sglang.test.test_utils import ( from sglang.test.test_utils import (
add_common_sglang_args_and_parse, add_common_sglang_args_and_parse,
select_sglang_backend, select_sglang_backend,
) )
from sglang.utils import dump_state_text, read_jsonl from sglang.utils import dump_state_text, read_jsonl
REGEX_LIST = r"\[(" + REGEX_STRING + ", )*" + REGEX_STRING + r"\]" REGEX_LIST = r"\[(" + REGEX_STR + ", )*" + REGEX_STR + r"\]"
# fmt: off # fmt: off
@sgl.function @sgl.function
...@@ -18,8 +18,8 @@ def json_warm_up(s): ...@@ -18,8 +18,8 @@ def json_warm_up(s):
s += "The information about Hogwarts is in the following JSON format.\n" s += "The information about Hogwarts is in the following JSON format.\n"
with s.var_scope("json_output"): with s.var_scope("json_output"):
s += "{\n" s += "{\n"
s += ' "name": ' + sgl.gen("name", max_tokens=8, regex=REGEX_STRING + ",") + "\n" s += ' "name": ' + sgl.gen("name", max_tokens=8, regex=REGEX_STR + ",") + "\n"
s += ' "country": ' + sgl.gen("country", max_tokens=8, regex=REGEX_STRING + ",") + "\n" s += ' "country": ' + sgl.gen("country", max_tokens=8, regex=REGEX_STR + ",") + "\n"
s += ' "latitude": ' + sgl.gen("latitude", max_tokens=8, regex=REGEX_FLOAT + ",") + "\n" s += ' "latitude": ' + sgl.gen("latitude", max_tokens=8, regex=REGEX_FLOAT + ",") + "\n"
s += ' "population": ' + sgl.gen("population", max_tokens=8, regex=REGEX_INT + ",") + "\n" s += ' "population": ' + sgl.gen("population", max_tokens=8, regex=REGEX_INT + ",") + "\n"
s += ' "top 3 landmarks": ' + sgl.gen( "landmarks", max_tokens=24, regex=REGEX_LIST) + "\n" s += ' "top 3 landmarks": ' + sgl.gen( "landmarks", max_tokens=24, regex=REGEX_LIST) + "\n"
...@@ -35,8 +35,8 @@ def json_decode(s, document): ...@@ -35,8 +35,8 @@ def json_decode(s, document):
s += "Here is the name, country, and symbol of the city in JSON format.\n" s += "Here is the name, country, and symbol of the city in JSON format.\n"
with s.var_scope("json_output"): with s.var_scope("json_output"):
s += "{\n" s += "{\n"
s += ' "name": ' + sgl.gen("name", max_tokens=8, regex=REGEX_STRING + ",") + "\n" s += ' "name": ' + sgl.gen("name", max_tokens=8, regex=REGEX_STR + ",") + "\n"
s += ' "country": ' + sgl.gen("country", max_tokens=8, regex=REGEX_STRING + ",") + "\n" s += ' "country": ' + sgl.gen("country", max_tokens=8, regex=REGEX_STR + ",") + "\n"
s += ' "latitude": ' + sgl.gen("latitude", max_tokens=8, regex=REGEX_FLOAT + ",") + "\n" s += ' "latitude": ' + sgl.gen("latitude", max_tokens=8, regex=REGEX_FLOAT + ",") + "\n"
s += ' "population": ' + sgl.gen("population", max_tokens=8, regex=REGEX_INT + ",") + "\n" s += ' "population": ' + sgl.gen("population", max_tokens=8, regex=REGEX_INT + ",") + "\n"
s += ' "top 3 landmarks": ' + sgl.gen( "landmarks", max_tokens=24, regex=REGEX_LIST) + "\n" s += ' "top 3 landmarks": ' + sgl.gen( "landmarks", max_tokens=24, regex=REGEX_LIST) + "\n"
......
...@@ -72,7 +72,7 @@ def gen( ...@@ -72,7 +72,7 @@ def gen(
logprob_start_len: Optional[int] = None, logprob_start_len: Optional[int] = None,
top_logprobs_num: Optional[int] = None, top_logprobs_num: Optional[int] = None,
return_text_in_logprobs: Optional[bool] = None, return_text_in_logprobs: Optional[bool] = None,
dtype: Optional[type] = None, dtype: Optional[Union[type, str]] = None,
choices: Optional[List[str]] = None, choices: Optional[List[str]] = None,
choices_method: Optional[ChoicesSamplingMethod] = None, choices_method: Optional[ChoicesSamplingMethod] = None,
regex: Optional[str] = None, regex: Optional[str] = None,
......
...@@ -195,7 +195,7 @@ def extend(reqs, model_runner): ...@@ -195,7 +195,7 @@ def extend(reqs, model_runner):
token_to_kv_pool=model_runner.token_to_kv_pool, token_to_kv_pool=model_runner.token_to_kv_pool,
tree_cache=None, tree_cache=None,
) )
batch.prepare_for_extend(model_runner.model_config.vocab_size, None) batch.prepare_for_extend(model_runner.model_config.vocab_size)
output = model_runner.forward(batch, ForwardMode.EXTEND) output = model_runner.forward(batch, ForwardMode.EXTEND)
next_token_ids = batch.sample(output.next_token_logits) next_token_ids = batch.sample(output.next_token_logits)
return next_token_ids, output.next_token_logits, batch return next_token_ids, output.next_token_logits, batch
......
import json import json
import warnings
from typing import List, Optional from typing import List, Optional
from sglang.global_config import global_config from sglang.global_config import global_config
from sglang.lang.backend.base_backend import BaseBackend from sglang.lang.backend.base_backend import BaseBackend
from sglang.lang.chat_template import get_chat_template_by_model_path from sglang.lang.chat_template import get_chat_template_by_model_path
from sglang.lang.choices import ( from sglang.lang.choices import ChoicesDecision, ChoicesSamplingMethod
ChoicesDecision,
ChoicesSamplingMethod,
token_length_normalized,
)
from sglang.lang.interpreter import StreamExecutor from sglang.lang.interpreter import StreamExecutor
from sglang.lang.ir import SglSamplingParams from sglang.lang.ir import (
REGEX_BOOL,
REGEX_FLOAT,
REGEX_INT,
REGEX_STR,
SglSamplingParams,
)
from sglang.utils import http_request from sglang.utils import http_request
class RuntimeEndpoint(BaseBackend): class RuntimeEndpoint(BaseBackend):
def __init__( def __init__(
self, self,
base_url: str, base_url: str,
...@@ -95,32 +97,52 @@ class RuntimeEndpoint(BaseBackend): ...@@ -95,32 +97,52 @@ class RuntimeEndpoint(BaseBackend):
) )
self._assert_success(res) self._assert_success(res)
def _handle_dtype_to_regex(self, sampling_params: SglSamplingParams):
    """Rewrite ``sampling_params.dtype`` into an equivalent ``regex`` constraint.

    Does nothing when ``dtype`` is ``None``. Otherwise ``sampling_params`` is
    updated in place:

    * ``regex`` is overwritten with the pattern matching the requested dtype
      (``REGEX_INT``, ``REGEX_FLOAT``, ``REGEX_STR`` or ``REGEX_BOOL``).
    * for ``int`` and ``float`` the stop strings ``" "`` and ``"\\n"`` are
      added to ``stop``.

    Args:
        sampling_params: Parameters of the current generation call. ``dtype``
            may be a Python type (``int``, ``float``, ``str``, ``bool``) or
            its name as a string.

    Raises:
        RuntimeError: If ``dtype`` is not one of the supported values.
    """
    dtype = sampling_params.dtype
    if dtype is None:
        return

    # Map the dtype to its regex and to the stop strings it needs.
    if dtype in ["int", int]:
        dtype_regex, extra_stop = REGEX_INT, [" ", "\n"]
    elif dtype in ["float", float]:
        dtype_regex, extra_stop = REGEX_FLOAT, [" ", "\n"]
    elif dtype in ["str", str]:
        dtype_regex, extra_stop = REGEX_STR, []
    elif dtype in ["bool", bool]:
        dtype_regex, extra_stop = REGEX_BOOL, []
    else:
        raise RuntimeError(f"Invalid dtype: {sampling_params.dtype}")

    current_stop = sampling_params.stop
    if extra_stop:
        # ``stop`` may be None, a single string, a tuple or a list. Build a
        # fresh list instead of calling ``.extend`` on it: that fails for
        # str/tuple/None and would mutate a list the caller may still share.
        if current_stop is None:
            merged_stop = []
        elif isinstance(current_stop, str):
            merged_stop = [current_stop]
        else:
            merged_stop = list(current_stop)
        # Skip entries already present so repeated calls stay idempotent.
        merged_stop.extend(s for s in extra_stop if s not in merged_stop)
        sampling_params.stop = merged_stop
    elif current_stop == ():
        # Keep the original normalisation of the empty-tuple default.
        sampling_params.stop = []

    # Warn only about a genuinely conflicting user regex; a regex equal to the
    # dtype pattern is what an earlier call of this method left behind.
    if sampling_params.regex is not None and sampling_params.regex != dtype_regex:
        warnings.warn(
            f"Both dtype and regex are set. Only dtype will be used. dtype: {sampling_params.dtype}, regex: {sampling_params.regex}"
        )

    sampling_params.regex = dtype_regex
def generate( def generate(
self, self,
s: StreamExecutor, s: StreamExecutor,
sampling_params: SglSamplingParams, sampling_params: SglSamplingParams,
): ):
if sampling_params.dtype is None: self._handle_dtype_to_regex(sampling_params)
data = { data = {
"text": s.text_, "text": s.text_,
"sampling_params": { "sampling_params": {
"skip_special_tokens": global_config.skip_special_tokens_in_output, "skip_special_tokens": global_config.skip_special_tokens_in_output,
"spaces_between_special_tokens": global_config.spaces_between_special_tokens_in_out, "spaces_between_special_tokens": global_config.spaces_between_special_tokens_in_out,
**sampling_params.to_srt_kwargs(), **sampling_params.to_srt_kwargs(),
}, },
} }
elif sampling_params.dtype in [int, "int"]:
data = {
"text": s.text_,
"sampling_params": {
"skip_special_tokens": global_config.skip_special_tokens_in_output,
"spaces_between_special_tokens": global_config.spaces_between_special_tokens_in_out,
"dtype": "int",
**sampling_params.to_srt_kwargs(),
},
}
else:
raise RuntimeError(f"Invalid dtype: {sampling_params.dtype}")
for item in [ for item in [
"return_logprob", "return_logprob",
...@@ -151,27 +173,16 @@ class RuntimeEndpoint(BaseBackend): ...@@ -151,27 +173,16 @@ class RuntimeEndpoint(BaseBackend):
s: StreamExecutor, s: StreamExecutor,
sampling_params: SglSamplingParams, sampling_params: SglSamplingParams,
): ):
if sampling_params.dtype is None: self._handle_dtype_to_regex(sampling_params)
data = {
"text": s.text_, data = {
"sampling_params": { "text": s.text_,
"skip_special_tokens": global_config.skip_special_tokens_in_output, "sampling_params": {
"spaces_between_special_tokens": global_config.spaces_between_special_tokens_in_out, "skip_special_tokens": global_config.skip_special_tokens_in_output,
**sampling_params.to_srt_kwargs(), "spaces_between_special_tokens": global_config.spaces_between_special_tokens_in_out,
}, **sampling_params.to_srt_kwargs(),
} },
elif sampling_params.dtype in [int, "int"]: }
data = {
"text": s.text_,
"sampling_params": {
"skip_special_tokens": global_config.skip_special_tokens_in_output,
"spaces_between_special_tokens": global_config.spaces_between_special_tokens_in_out,
"dtype": "int",
**sampling_params.to_srt_kwargs(),
},
}
else:
raise RuntimeError(f"Invalid dtype: {sampling_params.dtype}")
for item in [ for item in [
"return_logprob", "return_logprob",
......
...@@ -8,10 +8,10 @@ from typing import List, Optional, Union ...@@ -8,10 +8,10 @@ from typing import List, Optional, Union
from sglang.global_config import global_config from sglang.global_config import global_config
from sglang.lang.choices import ChoicesSamplingMethod from sglang.lang.choices import ChoicesSamplingMethod
REGEX_INT = r"[-+]?[0-9]+" REGEX_INT = r"[-+]?[0-9]+[ \n]*"
REGEX_FLOAT = r"[-+]?[0-9]*\.?[0-9]+" REGEX_FLOAT = r"[-+]?[0-9]*\.?[0-9]+[ \n]*"
REGEX_BOOL = r"(True|False)" REGEX_BOOL = r"(True|False)"
REGEX_STRING = r"\"[\w\d\s]*\"" # bugs with regex r"\".*\"" in interegular pkg REGEX_STR = r"\"[\w\d\s]*\"" # bugs with regex r"\".*\"" in interegular pkg
@dataclasses.dataclass @dataclasses.dataclass
......
...@@ -383,7 +383,7 @@ class ScheduleBatch: ...@@ -383,7 +383,7 @@ class ScheduleBatch:
return out_cache_loc return out_cache_loc
def batch_sampling_params(self, vocab_size, int_token_logit_bias): def batch_sampling_params(self, vocab_size):
device = "cuda" device = "cuda"
bs, reqs = self.batch_size(), self.reqs bs, reqs = self.batch_size(), self.reqs
self.temperatures = torch.tensor( self.temperatures = torch.tensor(
...@@ -419,15 +419,8 @@ class ScheduleBatch: ...@@ -419,15 +419,8 @@ class ScheduleBatch:
# Handle logit bias but only allocate when needed # Handle logit bias but only allocate when needed
self.logit_bias = None self.logit_bias = None
for i in range(bs):
if reqs[i].sampling_params.dtype == "int":
if self.logit_bias is None:
self.logit_bias = torch.zeros(
(bs, vocab_size), dtype=torch.float32, device=device
)
self.logit_bias[i][: len(int_token_logit_bias)] = int_token_logit_bias
def prepare_for_extend(self, vocab_size: int, int_token_logit_bias: torch.Tensor): def prepare_for_extend(self, vocab_size: int):
bs = self.batch_size() bs = self.batch_size()
reqs = self.reqs reqs = self.reqs
input_ids = [r.fill_ids[len(r.prefix_indices) :] for r in reqs] input_ids = [r.fill_ids[len(r.prefix_indices) :] for r in reqs]
...@@ -466,7 +459,7 @@ class ScheduleBatch: ...@@ -466,7 +459,7 @@ class ScheduleBatch:
self.out_cache_loc = out_cache_loc self.out_cache_loc = out_cache_loc
self.top_logprobs_nums = [r.top_logprobs_num for r in reqs] self.top_logprobs_nums = [r.top_logprobs_num for r in reqs]
self.batch_sampling_params(vocab_size, int_token_logit_bias) self.batch_sampling_params(vocab_size)
def check_decode_mem(self): def check_decode_mem(self):
bs = self.batch_size() bs = self.batch_size()
......
...@@ -54,7 +54,6 @@ from sglang.srt.model_executor.forward_batch_info import ForwardMode ...@@ -54,7 +54,6 @@ from sglang.srt.model_executor.forward_batch_info import ForwardMode
from sglang.srt.model_executor.model_runner import ModelRunner from sglang.srt.model_executor.model_runner import ModelRunner
from sglang.srt.server_args import ServerArgs from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import ( from sglang.srt.utils import (
get_int_token_logit_bias,
is_multimodal_model, is_multimodal_model,
set_random_seed, set_random_seed,
suppress_other_loggers, suppress_other_loggers,
...@@ -132,9 +131,6 @@ class ModelTpServer: ...@@ -132,9 +131,6 @@ class ModelTpServer:
), ),
self.model_runner.req_to_token_pool.size - 1, self.model_runner.req_to_token_pool.size - 1,
) )
self.int_token_logit_bias = torch.tensor(
get_int_token_logit_bias(self.tokenizer, self.model_config.vocab_size)
)
self.max_req_input_len = min( self.max_req_input_len = min(
self.model_config.context_len - 1, self.model_config.context_len - 1,
self.max_total_num_tokens - 1, self.max_total_num_tokens - 1,
...@@ -442,9 +438,7 @@ class ModelTpServer: ...@@ -442,9 +438,7 @@ class ModelTpServer:
def forward_prefill_batch(self, batch: ScheduleBatch): def forward_prefill_batch(self, batch: ScheduleBatch):
# Build batch tensors # Build batch tensors
batch.prepare_for_extend( batch.prepare_for_extend(self.model_config.vocab_size)
self.model_config.vocab_size, self.int_token_logit_bias
)
if self.model_runner.is_generation: if self.model_runner.is_generation:
# Forward and sample the next tokens # Forward and sample the next tokens
......
...@@ -32,6 +32,7 @@ from vllm.model_executor.layers.linear import ( ...@@ -32,6 +32,7 @@ from vllm.model_executor.layers.linear import (
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE,
ParallelLMHead, ParallelLMHead,
VocabParallelEmbedding, VocabParallelEmbedding,
) )
......
...@@ -36,7 +36,6 @@ class SamplingParams: ...@@ -36,7 +36,6 @@ class SamplingParams:
ignore_eos: bool = False, ignore_eos: bool = False,
skip_special_tokens: bool = True, skip_special_tokens: bool = True,
spaces_between_special_tokens: bool = True, spaces_between_special_tokens: bool = True,
dtype: Optional[str] = None,
regex: Optional[str] = None, regex: Optional[str] = None,
n: int = 1, n: int = 1,
) -> None: ) -> None:
...@@ -53,7 +52,6 @@ class SamplingParams: ...@@ -53,7 +52,6 @@ class SamplingParams:
self.ignore_eos = ignore_eos self.ignore_eos = ignore_eos
self.skip_special_tokens = skip_special_tokens self.skip_special_tokens = skip_special_tokens
self.spaces_between_special_tokens = spaces_between_special_tokens self.spaces_between_special_tokens = spaces_between_special_tokens
self.dtype = dtype
self.regex = regex self.regex = regex
self.n = n self.n = n
...@@ -63,8 +61,6 @@ class SamplingParams: ...@@ -63,8 +61,6 @@ class SamplingParams:
self.top_k = 1 self.top_k = 1
if self.top_k == -1: if self.top_k == -1:
self.top_k = 1 << 30 # whole vocabulary self.top_k = 1 << 30 # whole vocabulary
if self.dtype == "int":
self.stop_strs = [" ", "\n"]
def verify(self): def verify(self):
if self.temperature < 0.0: if self.temperature < 0.0:
......
...@@ -103,13 +103,13 @@ def test_decode_int(): ...@@ -103,13 +103,13 @@ def test_decode_int():
def test_decode_json_regex(): def test_decode_json_regex():
@sgl.function @sgl.function
def decode_json(s): def decode_json(s):
from sglang.lang.ir import REGEX_FLOAT, REGEX_INT, REGEX_STRING from sglang.lang.ir import REGEX_FLOAT, REGEX_INT, REGEX_STR
s += "Generate a JSON object to describe the basic city information of Paris.\n" s += "Generate a JSON object to describe the basic city information of Paris.\n"
with s.var_scope("json_output"): with s.var_scope("json_output"):
s += "{\n" s += "{\n"
s += ' "name": ' + sgl.gen(regex=REGEX_STRING + ",") + "\n" s += ' "name": ' + sgl.gen(regex=REGEX_STR + ",") + "\n"
s += ' "population": ' + sgl.gen(regex=REGEX_INT + ",") + "\n" s += ' "population": ' + sgl.gen(regex=REGEX_INT + ",") + "\n"
s += ' "area": ' + sgl.gen(regex=REGEX_INT + ",") + "\n" s += ' "area": ' + sgl.gen(regex=REGEX_INT + ",") + "\n"
s += ' "latitude": ' + sgl.gen(regex=REGEX_FLOAT) + "\n" s += ' "latitude": ' + sgl.gen(regex=REGEX_FLOAT) + "\n"
...@@ -359,6 +359,30 @@ def test_regex(): ...@@ -359,6 +359,30 @@ def test_regex():
assert re.match(regex, answer) assert re.match(regex, answer)
def test_dtype_gen():
    """Check that ``sgl.gen(..., dtype=...)`` yields text parseable as that dtype.

    Runs one program that generates a str, an int, a float and a bool answer,
    then converts each typed answer back to its Python type. On a conversion
    failure the whole state is printed and the ``ValueError`` is re-raised.
    """

    @sgl.function
    def dtype_gen(s):
        s += "Q: What is the full name of DNS?\n"
        s += "A: The full nams is " + sgl.gen("str_res", dtype=str, stop="\n") + "\n"
        s += "Q: Which year was DNS invented?\n"
        s += "A: " + sgl.gen("int_res", dtype=int) + "\n"
        s += "Q: What is the value of pi?\n"
        s += "A: " + sgl.gen("float_res", dtype=float) + "\n"
        s += "Q: Is the sky blue?\n"
        s += "A: " + sgl.gen("bool_res", dtype=bool) + "\n"

    state = dtype_gen.run()

    try:
        state["int_res"] = int(state["int_res"])
        state["float_res"] = float(state["float_res"])
        # bool("False") is True and bool() never raises, so parse the literal
        # explicitly instead of relying on truthiness.
        bool_text = state["bool_res"].strip()
        if bool_text not in ("True", "False"):
            raise ValueError(f"Invalid bool literal: {bool_text!r}")
        state["bool_res"] = bool_text == "True"
        # NOTE(review): the quoted-string check for "str_res" was disabled in
        # the original test; confirm whether it should be enforced:
        # assert state["str_res"].startswith('"') and state["str_res"].endswith('"')
    except ValueError:
        print(state)
        raise
def test_completion_speculative(): def test_completion_speculative():
@sgl.function(num_api_spec_tokens=64) @sgl.function(num_api_spec_tokens=64)
def gen_character_spec(s): def gen_character_spec(s):
......
import json
import unittest import unittest
import sglang as sgl import sglang as sgl
from sglang.test.test_programs import ( from sglang.test.test_programs import (
test_decode_int, test_decode_int,
test_decode_json_regex, test_decode_json_regex,
test_dtype_gen,
test_expert_answer, test_expert_answer,
test_few_shot_qa, test_few_shot_qa,
test_mt_bench, test_mt_bench,
...@@ -59,6 +59,9 @@ class TestSRTBackend(unittest.TestCase): ...@@ -59,6 +59,9 @@ class TestSRTBackend(unittest.TestCase):
def test_regex(self): def test_regex(self):
test_regex() test_regex()
def test_dtype_gen(self):
test_dtype_gen()
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment