Unverified commit 357671e2, authored by Qiaolin Yu and committed by GitHub

Add examples for server token-in-token-out (#4103)


Co-authored-by: zhaochenyang20 <zhaochen20@outlook.com>
parent e70fa279
@@ -52,7 +52,7 @@ Please consult the documentation below to learn more about the parameters you may
* `chat_template`: The chat template to use. Deviating from the default might lead to unexpected responses. For multi-modal chat templates, refer to [here](https://docs.sglang.ai/backend/openai_api_vision.ipynb#Chat-Template).
* `is_embedding`: Set to true to perform [embedding](./openai_api_embeddings.ipynb) / [encode](https://docs.sglang.ai/backend/native_api#Encode-(embedding-model)) and [reward](https://docs.sglang.ai/backend/native_api#Classify-(reward-model)) tasks.
* `revision`: Adjust if a specific version of the model should be used.
-* `skip_tokenizer_init`: Set to true to provide the tokens to the engine and get the output tokens directly, typically used in RLHF. Please see this [example for reference](https://github.com/sgl-project/sglang/blob/main/examples/runtime/engine/token_in_token_out/).
+* `skip_tokenizer_init`: Set to true to provide the tokens to the engine and get the output tokens directly, typically used in RLHF. Please see this [example for reference](https://github.com/sgl-project/sglang/blob/main/examples/runtime/token_in_token_out/).
* `json_model_override_args`: Override model config with the provided JSON.
* `delete_ckpt_after_loading`: Delete the model checkpoint after loading the model.
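
To illustrate `skip_tokenizer_init` concretely, here is a minimal, hedged sketch of the token-in-token-out flow over HTTP. It assumes a server already launched with `--skip-tokenizer-init` and listening on port 30000, and uses `meta-llama/Llama-3.1-8B-Instruct` as a placeholder model; see the linked example for the full, maintained version.

```python
import requests
from sglang.srt.hf_transformers_utils import get_tokenizer

MODEL_PATH = "meta-llama/Llama-3.1-8B-Instruct"  # placeholder model
tokenizer = get_tokenizer(MODEL_PATH)

# Tokenize on the client side and send token ids directly.
input_ids = tokenizer.encode("The capital of France is")
response = requests.post(
    "http://localhost:30000/generate",  # assumes --skip-tokenizer-init and port 30000
    json={"input_ids": input_ids, "sampling_params": {"temperature": 0.8, "top_p": 0.95}},
)
output = response.json()

# The server returns generated token ids; decode them locally.
print(output["output_ids"])
print(tokenizer.decode(output["output_ids"]))
```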
@@ -44,6 +44,6 @@ curl -X POST http://localhost:8000/generate_stream -H "Content-Type: application/json"
This will send both non-streaming and streaming requests to the server.
-### [Token-In-Token-Out for RLHF](./token_in_token_out)
+### [Token-In-Token-Out for RLHF](../token_in_token_out)
In this example, we launch an SGLang engine, feed tokens as input and generate tokens as output.
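
For orientation, a minimal sketch of what this example does at the engine level (the model path and sampling parameters below are placeholders; the script in this directory is the runnable reference):

```python
import sglang as sgl
from sglang.srt.hf_transformers_utils import get_tokenizer

MODEL_PATH = "meta-llama/Llama-3.1-8B-Instruct"  # placeholder model
tokenizer = get_tokenizer(MODEL_PATH)

# The tokenizer is skipped inside the engine, so inputs and outputs are token ids.
llm = sgl.Engine(model_path=MODEL_PATH, skip_tokenizer_init=True)
token_ids_list = [tokenizer.encode("The capital of France is")]
outputs = llm.generate(input_ids=token_ids_list, sampling_params={"temperature": 0.8, "top_p": 0.95})
for output in outputs:
    print(output["output_ids"], tokenizer.decode(output["output_ids"]))
llm.shutdown()
```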
@@ -29,8 +29,12 @@ def main():
    outputs = llm.generate(input_ids=token_ids_list, sampling_params=sampling_params)
    # Print the outputs.
    for prompt, output in zip(prompts, outputs):
+       decode_output = tokenizer.decode(output["output_ids"])
        print("===============================")
-       print(f"Prompt: {prompt}\nGenerated token ids: {output['output_ids']}")
+       print(
+           f"Prompt: {prompt}\nGenerated token ids: {output['output_ids']}\nGenerated text: {decode_output}"
+       )
        print()
# The __main__ condition is necessary here because we use "spawn" to create subprocesses
"""
Usage:
python token_in_token_out_llm_server.py
"""
import requests
from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.test.test_utils import is_in_ci
from sglang.utils import print_highlight, terminate_process, wait_for_server
if is_in_ci():
from docs.backend.patch import launch_server_cmd
else:
from sglang.utils import launch_server_cmd
MODEL_PATH = "meta-llama/Llama-3.1-8B-Instruct"
def main():
# Launch the server
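# With --skip-tokenizer-init, /generate accepts `input_ids` and returns `output_ids`
# instead of detokenized text.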
server_process, port = launch_server_cmd(
f"python -m sglang.launch_server --model-path {MODEL_PATH} --skip-tokenizer-init --host 0.0.0.0"
)
wait_for_server(f"http://localhost:{port}")
# Sample prompts.
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
# Create a sampling params object.
sampling_params = {"temperature": 0.8, "top_p": 0.95}
# Tokenize inputs
tokenizer = get_tokenizer(MODEL_PATH)
token_ids_list = [tokenizer.encode(prompt) for prompt in prompts]
json_data = {
"input_ids": token_ids_list,
"sampling_params": sampling_params,
}
response = requests.post(
f"http://localhost:{port}/generate",
json=json_data,
)
outputs = response.json()
for prompt, output in zip(prompts, outputs):
print("===============================")
decode_output = tokenizer.decode(output["output_ids"])
print(
f"Prompt: {prompt}\nGenerated token ids: {output['output_ids']}\nGenerated text: {decode_output}"
)
print()
terminate_process(server_process)
if __name__ == "__main__":
main()
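
Note that this LLM example sends a batch: `input_ids` is a list of token-id lists, so `/generate` returns a list of outputs, one per prompt. The VLM example below (`token_in_token_out_vlm_server.py`) sends a single token-id list instead and therefore reads a single output object from the response.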
"""
Usage:
python token_in_token_out_vlm_server.py
"""
from io import BytesIO
from typing import Tuple
import requests
from PIL import Image
from transformers import AutoProcessor
from sglang.lang.chat_template import get_chat_template_by_model_path
from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.test.test_utils import DEFAULT_IMAGE_URL, is_in_ci
from sglang.utils import print_highlight, terminate_process, wait_for_server
if is_in_ci():
from docs.backend.patch import launch_server_cmd
else:
from sglang.utils import launch_server_cmd
MODEL_PATH = "Qwen/Qwen2-VL-2B"
def get_input_ids() -> Tuple[list[int], list]:
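# Build the prompt with the model's image token and let the HF processor produce the
# input token ids; the image itself is passed to the server separately as `image_data`.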
chat_template = get_chat_template_by_model_path(MODEL_PATH)
text = f"{chat_template.image_token}What is in this picture?"
images = [Image.open(BytesIO(requests.get(DEFAULT_IMAGE_URL).content))]
image_data = [DEFAULT_IMAGE_URL]
processor = AutoProcessor.from_pretrained(MODEL_PATH)
inputs = processor(
text=[text],
images=images,
return_tensors="pt",
)
return inputs.input_ids[0].tolist(), image_data
def main():
# Launch the server
server_process, port = launch_server_cmd(
f"python -m sglang.launch_server --model-path {MODEL_PATH} --skip-tokenizer-init --host 0.0.0.0"
)
wait_for_server(f"http://localhost:{port}")
input_ids, image_data = get_input_ids()
sampling_params = {
"temperature": 0.8,
"max_new_tokens": 32,
}
json_data = {
"input_ids": input_ids,
"image_data": image_data,
"sampling_params": sampling_params,
}
response = requests.post(
f"http://localhost:{port}/generate",
json=json_data,
)
output = response.json()
print("===============================")
print(f"Output token ids: ", output["output_ids"])
terminate_process(server_process)
if __name__ == "__main__":
main()
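
As in the LLM example, the returned ids can be decoded back to text. A minimal sketch of the extra lines one might add at the end of `main()`, before `terminate_process` (a hypothetical addition that reuses the `get_tokenizer` import above):

```python
    # Hypothetical addition: decode the generated ids, mirroring the LLM example.
    tokenizer = get_tokenizer(MODEL_PATH)
    print("Generated text:", tokenizer.decode(output["output_ids"]))
```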