"docs/source/vscode:/vscode.git/clone" did not exist on "1db7588739ded8461e236ad3bc3fd23c35eec4cd"
Unverified commit c4f9707e authored by Shi Shuai, committed by GitHub

Improve: Token-In Token-Out Usage for RLHF (#2843)

parent 197cbf9b
...@@ -348,6 +348,76 @@
"source": [
"terminate_process(reward_process)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Skip Tokenizer and Detokenizer\n",
"\n",
"SGLang Runtime also supports skip tokenizer and detokenizer. This is useful in cases like integrating with RLHF workflow."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"tokenizer_free_server_process = execute_shell_command(\n",
" \"\"\"\n",
"python3 -m sglang.launch_server --model-path meta-llama/Llama-3.2-1B-Instruct --port=30010 --skip-tokenizer-init\n",
"\"\"\"\n",
")\n",
"\n",
"wait_for_server(\"http://localhost:30010\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from transformers import AutoTokenizer\n",
"\n",
"tokenizer = AutoTokenizer.from_pretrained(\"meta-llama/Llama-3.2-1B-Instruct\")\n",
"\n",
"input_text = \"What is the capital of France?\"\n",
"\n",
"input_tokens = tokenizer.encode(input_text)\n",
"print_highlight(f\"Input Text: {input_text}\")\n",
"print_highlight(f\"Tokenized Input: {input_tokens}\")\n",
"\n",
"response = requests.post(\n",
" \"http://localhost:30010/generate\",\n",
" json={\n",
" \"input_ids\": input_tokens,\n",
" \"sampling_params\": {\n",
" \"temperature\": 0,\n",
" \"max_new_tokens\": 256,\n",
" \"stop_token_ids\": [tokenizer.eos_token_id],\n",
" },\n",
" \"stream\": False,\n",
" },\n",
")\n",
"output = response.json()\n",
"output_tokens = output[\"token_ids\"]\n",
"\n",
"output_text = tokenizer.decode(output_tokens, skip_special_tokens=False)\n",
"print_highlight(f\"Tokenized Output: {output_tokens}\")\n",
"print_highlight(f\"Decoded Output: {output_text}\")\n",
"print_highlight(f\"Output Text: {output['meta_info']['finish_reason']}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"terminate_process(tokenizer_free_server_process)"
]
}
],
"metadata": {
......
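The notebook section above is the user-facing half of this change: the client tokenizes, sends `input_ids`, and reads `token_ids` back. As a rough illustration of how an RLHF rollout loop might drive the same `/generate` endpoint, here is a minimal sketch that is not part of the diff; it assumes a server already launched with `--skip-tokenizer-init` on port 30010 and reuses only the request and response fields shown in the cells above.

```python
# Minimal token-in token-out rollout sketch (illustrative, not from this commit).
# Assumes: an SGLang server running with --skip-tokenizer-init on localhost:30010,
# and a Hugging Face tokenizer that matches the served model.
import requests
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B-Instruct")
prompts = ["What is the capital of France?", "Name a prime number."]

rollouts = []
for prompt in prompts:
    input_ids = tokenizer.encode(prompt)
    resp = requests.post(
        "http://localhost:30010/generate",
        json={
            "input_ids": input_ids,  # token-in: no server-side tokenization
            "sampling_params": {
                "temperature": 0.7,
                "max_new_tokens": 64,
                "stop_token_ids": [tokenizer.eos_token_id],
            },
            "stream": False,
        },
    )
    out = resp.json()
    # token-out: the response carries the generated token IDs, so a trainer can
    # score exactly the tokens the policy produced without re-tokenizing text.
    rollouts.append({"prompt_ids": input_ids, "output_ids": out["token_ids"]})

print(rollouts[0]["output_ids"][:8])
```

Keeping both directions in token space avoids the subtle mismatches that re-tokenizing decoded text can introduce into RLHF log-prob computations.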
...@@ -4,7 +4,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"# Structured Outputs (JSON, Regex, EBNF)"
"# Structured Outputs"
]
},
{
...@@ -43,6 +43,10 @@
" print_highlight,\n",
")\n",
"import openai\n",
"import os\n",
"\n",
"os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n",
"\n",
"\n", "\n",
"server_process = execute_shell_command(\n", "server_process = execute_shell_command(\n",
" \"python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct --port 30000 --host 0.0.0.0 --grammar-backend xgrammar\"\n", " \"python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct --port 30000 --host 0.0.0.0 --grammar-backend xgrammar\"\n",
......
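The only functional change to this notebook is setting `TOKENIZERS_PARALLELISM=false` before the server subprocess is spawned, which is the usual way to silence the Hugging Face `tokenizers` warning about forking after parallelism has been used. A standalone sketch of the same pattern follows; the command is illustrative, and the notebook itself uses its own `execute_shell_command` helper instead of `subprocess` directly.

```python
# Illustrative: disable tokenizers parallelism before any fork/subprocess,
# otherwise the Hugging Face tokenizers library may warn about forking after
# parallelism has already been used.
import os
import subprocess

os.environ["TOKENIZERS_PARALLELISM"] = "false"

server = subprocess.Popen(
    [
        "python", "-m", "sglang.launch_server",
        "--model-path", "meta-llama/Meta-Llama-3.1-8B-Instruct",
        "--port", "30000",
        "--host", "0.0.0.0",
        "--grammar-backend", "xgrammar",
    ]
)
```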
...@@ -56,10 +56,10 @@ The core features include:
references/hyperparameter_tuning.md
references/benchmark_and_profiling.md
references/custom_chat_template.md
references/deepseek.md
references/llama_405B.md
references/modelscope.md
references/contribution_guide.md
references/troubleshooting.md
references/faq.md
references/learn_more.md
references/deepseek.md
# DeepSeek Model Optimizations in SGLang
# DeepSeek Model Optimizations
SGLang provides several optimizations specifically designed for the DeepSeek model to boost its inference speed. This document outlines current optimizations for DeepSeek. Additionally, the SGLang team is actively developing enhancements for [DeepSeek-V3](https://github.com/sgl-project/sglang/issues/2591).
...@@ -16,7 +16,9 @@ SGLang provides several optimizations specifically designed for the DeepSeek mod
Overall, with these optimizations, we have achieved up to a 7x acceleration in output throughput compared to the previous version.
![Multi-head Latent Attention for DeepSeek Series Models](https://lmsys.org/images/blog/sglang_v0_3/deepseek_mla.svg)
<p align="center">
<img src="https://lmsys.org/images/blog/sglang_v0_3/deepseek_mla.svg" alt="Multi-head Latent Attention for DeepSeek Series Models">
</p>
**Usage**: MLA optimization is enabled by default; to disable it, use `--disable-mla`.
...@@ -26,7 +28,9 @@ Overall, with these optimizations, we have achieved up to a 7x acceleration in o
**Description**: This optimization involves data parallelism (DP) for the MLA attention mechanism of DeepSeek Series Models, which allows for a significant reduction in the KV cache size, enabling larger batch sizes. Each DP worker independently handles different types of batches (prefill, decode, idle), which are then synchronized before and after processing through the Mixture-of-Experts (MoE) layer.
![Data Parallelism Attention for DeepSeek Series Models](https://lmsys.org/images/blog/sglang_v0_4/dp_attention.svg).
<p align="center">
<img src="https://lmsys.org/images/blog/sglang_v0_4/dp_attention.svg" alt="Data Parallelism Attention for DeepSeek Series Models">
</p>
**Usage**: This optimization is aimed at improving throughput and should be used in scenarios with high QPS (queries per second). Data Parallelism Attention can be enabled with `--enable-dp-attention` for DeepSeek Series Models.
......
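The two usage notes above reduce to launch-time flags. Below is a hedged launch sketch in Python; the model path and `--tp` setting are placeholders rather than values from this document, while `--enable-dp-attention` and `--disable-mla` are the flags named above.

```python
# Illustrative launch sketch for the DeepSeek flags described above.
# The model path and tensor-parallel size are placeholders for your own setup.
import subprocess

cmd = [
    "python", "-m", "sglang.launch_server",
    "--model-path", "deepseek-ai/DeepSeek-V2-Lite",  # placeholder model
    "--tp", "8",                                      # placeholder parallelism
    "--enable-dp-attention",   # DP attention: helps throughput under high QPS
    # "--disable-mla",         # MLA is on by default; uncomment to turn it off
]
server = subprocess.Popen(cmd)
```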
...@@ -181,8 +181,6 @@ class DetokenizerManager:
finished_reasons=recv_obj.finished_reasons,
output_strs=output_strs,
prompt_tokens=recv_obj.prompt_tokens,
origin_input_ids=recv_obj.origin_input_ids,
output_ids=recv_obj.output_ids,
completion_tokens=recv_obj.completion_tokens,
cached_tokens=recv_obj.cached_tokens,
input_token_logprobs_val=recv_obj.input_token_logprobs_val,
......
...@@ -323,9 +323,7 @@ class BatchTokenIDOut:
decoded_texts: List[str]
decode_ids: List[int]
read_offsets: List[int]
# Only used when --return-token-ids` is set
# Only used when `--skip-tokenizer-init` is on
origin_input_ids: Optional[List[int]]
# Only used when `--skip-tokenizer-init` or `--return-token-ids` is set
output_ids: Optional[List[int]]
# Detokenization configs
skip_special_tokens: List[bool]
...@@ -356,10 +354,6 @@ class BatchStrOut:
# The output decoded strings
output_strs: List[str]
# The token ids
origin_input_ids: Optional[List[int]]
output_ids: Optional[List[int]]
# Token counts
# real input and output tokens can be get from # real input and output tokens can be get from
# origin_input_ids and output_ids by enabling --return_token_ids # origin_input_ids and output_ids by enabling --return_token_ids
......
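For orientation, after this hunk the token-ID plumbing on the output structs collapses to a single `output_ids` field that is populated only under `--skip-tokenizer-init`. Below is a rough sketch of the resulting shape, limited to fields visible in this diff; the real dataclasses in `sglang.srt.managers.io_struct` carry more fields.

```python
# Abridged sketch of the post-change structs; field list limited to what the
# hunks above show, so this is not the complete definition.
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class BatchTokenIDOut:
    decoded_texts: List[str]
    decode_ids: List[int]
    read_offsets: List[int]
    # Only used when `--skip-tokenizer-init` is on
    output_ids: Optional[List[int]]
    # Detokenization configs
    skip_special_tokens: List[bool]


@dataclass
class BatchStrOut:
    # The output decoded strings
    output_strs: List[str]
    # Token counts
    prompt_tokens: List[int]
    completion_tokens: List[int]
    cached_tokens: List[int]
```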
...@@ -1253,7 +1253,6 @@ class Scheduler:
decode_ids_list = []
read_offsets = []
output_ids = []
origin_input_ids = []
skip_special_tokens = []
spaces_between_special_tokens = []
...@@ -1305,14 +1304,8 @@
decode_ids, read_offset = req.init_incremental_detokenize()
decode_ids_list.append(decode_ids)
read_offsets.append(read_offset)
if self.skip_tokenizer_init or self.server_args.return_token_ids:
if self.skip_tokenizer_init:
output_ids.append(req.output_ids)
else:
output_ids = None
if self.server_args.return_token_ids:
origin_input_ids.append(req.origin_input_ids)
else:
origin_input_ids = None
skip_special_tokens.append(req.sampling_params.skip_special_tokens)
spaces_between_special_tokens.append(
req.sampling_params.spaces_between_special_tokens
...@@ -1344,7 +1337,6 @@
decoded_texts,
decode_ids_list,
read_offsets,
origin_input_ids,
output_ids,
skip_special_tokens,
spaces_between_special_tokens,
......
...@@ -663,13 +663,6 @@ class TokenizerManager:
"text": recv_obj.output_strs[i],
"meta_info": meta_info,
}
if self.server_args.return_token_ids:
out_dict.update(
{
"input_ids": recv_obj.origin_input_ids[i],
"output_ids": recv_obj.output_ids[i],
}
)
elif isinstance(recv_obj, BatchTokenIDOut):
out_dict = {
"token_ids": recv_obj.output_ids[i],
......
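The branch above is where the two response shapes diverge: a `BatchStrOut` yields a payload with `"text"`, while a `BatchTokenIDOut` (the `--skip-tokenizer-init` path) yields `"token_ids"`. A small client-side sketch of handling both, using only the field names visible in this hunk; the helper itself is hypothetical.

```python
# Hypothetical client-side helper: normalize the two payload shapes shown above.
def extract_completion(payload: dict, tokenizer=None) -> str:
    if "text" in payload:
        # Normal mode: the server already detokenized the output.
        return payload["text"]
    # Token-in token-out mode: decode locally with a tokenizer matching the model.
    return tokenizer.decode(payload["token_ids"], skip_special_tokens=True)
```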
...@@ -55,7 +55,6 @@ class ServerArgs:
is_embedding: bool = False
revision: Optional[str] = None
skip_tokenizer_init: bool = False
return_token_ids: bool = False
# Port for the HTTP server
host: str = "127.0.0.1"
...@@ -296,6 +295,11 @@
"tokenizer if available, and 'slow' will "
"always use the slow tokenizer.",
)
parser.add_argument(
"--skip-tokenizer-init",
action="store_true",
help="If set, skip init tokenizer and pass input_ids in generate request",
)
parser.add_argument(
"--load-format",
type=str,
...@@ -404,18 +408,6 @@
"name, a tag name, or a commit id. If unspecified, will use "
"the default version.",
)
parser.add_argument(
"--skip-tokenizer-init",
action="store_true",
help="If set, skip init tokenizer and pass input_ids in generate request",
)
parser.add_argument(
"--return-token-ids",
action="store_true",
default=ServerArgs.return_token_ids,
help="Whether to return token IDs in the output, this may introduce additional overhead.",
)
# Memory and scheduling
parser.add_argument(
"--mem-fraction-static",
......
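With `--return-token-ids` removed, `--skip-tokenizer-init` becomes the single switch for token-level I/O, and its argument definition moves next to the other tokenizer options. The sketch below is illustrative only: it shows how the CLI flag maps onto the dataclass field from the hunk above, though the usual entry point is the launch command shown in the notebook rather than constructing `ServerArgs` directly.

```python
# Illustrative only: the CLI flag added above maps onto the skip_tokenizer_init
# field of the ServerArgs dataclass.
from sglang.srt.server_args import ServerArgs

args = ServerArgs(
    model_path="meta-llama/Llama-3.2-1B-Instruct",  # placeholder model
    skip_tokenizer_init=True,  # clients send input_ids and receive token_ids
)
print(args.skip_tokenizer_init)
```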
...@@ -45,7 +45,6 @@ suites = {
"test_vision_chunked_prefill.py",
"test_vision_openai_server.py",
"test_session_control.py",
"test_engine_token_ids.py",
],
"nightly": [
"test_nightly_gsm8k_eval.py",
......
import unittest
from transformers import AutoTokenizer
import sglang as sgl
from sglang.test.test_utils import DEFAULT_SMALL_MODEL_NAME_FOR_TEST
class TestEngineTokenIds(unittest.TestCase):
def test_token_ids_in_generate(self):
llm = sgl.Engine(
model_path=DEFAULT_SMALL_MODEL_NAME_FOR_TEST, return_token_ids=True
)
tokenizer = AutoTokenizer.from_pretrained(DEFAULT_SMALL_MODEL_NAME_FOR_TEST)
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
sampling_params = {"temperature": 0, "top_p": 0.95}
outputs = llm.generate(prompts, sampling_params)
for prompt, output in zip(prompts, outputs):
deocode_input = tokenizer.decode(
output["input_ids"], skip_special_tokens=True
)
assert (deocode_input in prompt) or (
prompt in deocode_input
), f"Decode input: {deocode_input} mismatch for: {prompt}"
deocode_output = tokenizer.decode(
output["output_ids"], skip_special_tokens=True
)
assert (deocode_output in output["text"]) or (
output["text"] in deocode_output
), f"Decode output: {deocode_output} mismatch for: {output['text']}"
llm.shutdown()
if __name__ == "__main__":
unittest.main()
"""
python3 -m unittest test_skip_tokenizer_init.TestSkipTokenizerInit.test_parallel_sample
"""
import json
import unittest
import requests
from transformers import AutoTokenizer
from sglang.srt.utils import kill_process_tree
from sglang.test.test_utils import (
...@@ -15,35 +12,63 @@ from sglang.test.test_utils import (
popen_launch_server,
)
_server_process = None
_base_url = None
_tokenizer = None
def setUpModule():
"""
Launch the server once before all tests and initialize the tokenizer.
"""
global _server_process, _base_url, _tokenizer
_server_process = popen_launch_server(
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
DEFAULT_URL_FOR_TEST,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=["--skip-tokenizer-init"],
)
_base_url = DEFAULT_URL_FOR_TEST
_tokenizer = AutoTokenizer.from_pretrained(
DEFAULT_SMALL_MODEL_NAME_FOR_TEST, use_fast=False
)
print(">>> setUpModule: Server launched, tokenizer ready")
def tearDownModule():
"""
Terminate the server once after all tests have completed.
"""
global _server_process
if _server_process is not None:
kill_process_tree(_server_process.pid)
_server_process = None
print(">>> tearDownModule: Server terminated")
class TestSkipTokenizerInit(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
cls.base_url = DEFAULT_URL_FOR_TEST
cls.process = popen_launch_server(
cls.model,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=["--skip-tokenizer-init"],
)
@classmethod
def tearDownClass(cls):
kill_process_tree(cls.process.pid)
def run_decode(self, return_logprob=False, top_logprobs_num=0, n=1):
max_new_tokens = 32
input_ids = [128000, 791, 6864, 315, 9822, 374] # The capital of France is
class TestSkipTokenizerInit(unittest.TestCase):
def run_decode(
self,
prompt_text="The capital of France is",
max_new_tokens=32,
return_logprob=False,
top_logprobs_num=0,
n=1,
):
input_ids = _tokenizer(prompt_text, return_tensors="pt")["input_ids"][
0
].tolist()
response = requests.post(
self.base_url + "/generate",
_base_url + "/generate",
json={
"input_ids": input_ids,
"sampling_params": {
"temperature": 0 if n == 1 else 0.5,
"max_new_tokens": max_new_tokens,
"n": n,
"stop_token_ids": [119690],
"stop_token_ids": [_tokenizer.eos_token_id],
},
"stream": False,
"return_logprob": return_logprob,
...@@ -52,25 +77,37 @@ class TestSkipTokenizerInit(unittest.TestCase):
},
)
ret = response.json()
print(json.dumps(ret))
print(json.dumps(ret, indent=2))
def assert_one_item(item):
self.assertEqual(
len(item["token_ids"]), item["meta_info"]["completion_tokens"]
)
self.assertEqual(len(item["token_ids"]), max_new_tokens)
assert item["meta_info"]["prompt_tokens"] == len(input_ids)
if return_logprob:
assert len(item["meta_info"]["input_token_logprobs"]) == len(
input_ids
), f'{len(item["meta_info"]["input_token_logprobs"])} vs. f{len(input_ids)}'
assert len(item["meta_info"]["output_token_logprobs"]) == max_new_tokens
if item["meta_info"]["finish_reason"]["type"] == "stop":
self.assertEqual(
item["meta_info"]["finish_reason"]["matched"],
_tokenizer.eos_token_id,
)
elif item["meta_info"]["finish_reason"]["type"] == "length":
self.assertEqual(
len(item["token_ids"]), item["meta_info"]["completion_tokens"]
)
self.assertEqual(len(item["token_ids"]), max_new_tokens)
self.assertEqual(item["meta_info"]["prompt_tokens"], len(input_ids))
if return_logprob:
self.assertEqual(
len(item["meta_info"]["input_token_logprobs"]),
len(input_ids),
f'{len(item["meta_info"]["input_token_logprobs"])} mismatch with {len(input_ids)}',
)
self.assertEqual(
len(item["meta_info"]["output_token_logprobs"]),
max_new_tokens,
)
# Determine whether to assert a single item or multiple items based on n
if n == 1:
assert_one_item(ret)
else:
assert len(ret) == n
self.assertEqual(len(ret), n)
for i in range(n):
assert_one_item(ret[i])
...@@ -84,10 +121,10 @@ class TestSkipTokenizerInit(unittest.TestCase):
def test_logprob(self):
for top_logprobs_num in [0, 3]:
self.run_decode(
return_logprob=True,
top_logprobs_num=top_logprobs_num,
)
self.run_decode(return_logprob=True, top_logprobs_num=top_logprobs_num)
def test_eos_behavior(self):
self.run_decode(max_new_tokens=256)
if __name__ == "__main__":
......