Unverified Commit 583d6af7 authored by Mick, committed by GitHub

example: add vlm to token in & out example (#3941)


Co-authored-by: zhaochenyang20 <zhaochen20@outlook.com>
parent e074d84e
......@@ -52,7 +52,7 @@ Please consult the documentation below to learn more about the parameters you ma
* `chat_template`: The chat template to use. Deviating from the default might lead to unexpected responses. For multi-modal chat templates, refer to [here](https://docs.sglang.ai/backend/openai_api_vision.ipynb#Chat-Template).
* `is_embedding`: Set to true to perform [embedding](./openai_api_embeddings.ipynb) / [encode](https://docs.sglang.ai/backend/native_api#Encode-(embedding-model)) and [reward](https://docs.sglang.ai/backend/native_api#Classify-(reward-model)) tasks.
* `revision`: Adjust if a specific version of the model should be used.
* `skip_tokenizer_init`: Set to true to provide the tokens to the engine and get the output tokens directly, typically used in RLHF. Please see this [example for reference](https://github.com/sgl-project/sglang/blob/main/examples/runtime/engine/token_in_token_out_llm.py).
* `skip_tokenizer_init`: Set to true to provide the tokens to the engine and get the output tokens directly, typically used in RLHF. Please see this [example for reference](https://github.com/sgl-project/sglang/blob/main/examples/runtime/engine/token_in_token_out/) and the minimal sketch below this list.
* `json_model_override_args`: Override model config with the provided JSON.
* `delete_ckpt_after_loading`: Delete the model checkpoint after loading the model.
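
To make the `skip_tokenizer_init` flow concrete, here is a minimal token-in-token-out sketch in the spirit of the linked example; the model name is illustrative, and it assumes a matching Hugging Face tokenizer is available.

```python
# Minimal token-in-token-out sketch (model name illustrative).
import sglang as sgl
from transformers import AutoTokenizer

if __name__ == "__main__":
    model_path = "Qwen/Qwen2.5-1.5B-Instruct"
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    # skip_tokenizer_init=True: the engine consumes and returns token IDs directly.
    llm = sgl.Engine(model_path=model_path, skip_tokenizer_init=True)

    input_ids = tokenizer("The capital of France is")["input_ids"]
    output = llm.generate(
        input_ids=input_ids,
        sampling_params={"temperature": 0.8, "max_new_tokens": 16},
    )
    print("Output token ids:", output["output_ids"])
    print(tokenizer.decode(output["output_ids"], skip_special_tokens=True))
    llm.shutdown()
```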
......
......@@ -9,15 +9,15 @@ SGLang provides a direct inference engine without the need for an HTTP server. T
## Examples
### 1. [Offline Batch Inference](./offline_batch_inference.py)
### [Offline Batch Inference](./offline_batch_inference.py)
In this example, we launch an SGLang engine and feed it a batch of inputs for inference. If you provide a very large batch, the engine intelligently schedules the requests so they are processed efficiently without OOM (out-of-memory) errors.
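
A rough sketch of that flow, assuming a small instruct model (the model name and sampling parameters are placeholders rather than the exact values used in the script):

```python
# Offline batch inference sketch: one Engine, a batch of prompts, a batch of outputs.
import sglang as sgl

if __name__ == "__main__":
    llm = sgl.Engine(model_path="Qwen/Qwen2.5-1.5B-Instruct")
    prompts = [
        "Hello, my name is",
        "The capital of France is",
        "The future of AI is",
    ]
    outputs = llm.generate(
        prompts, sampling_params={"temperature": 0.8, "max_new_tokens": 32}
    )
    for prompt, output in zip(prompts, outputs):
        print(f"Prompt: {prompt}\nGenerated text: {output['text']}")
    llm.shutdown()
```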
### 2. [Embedding Generation](./embedding.py)
### [Embedding Generation](./embedding.py)
In this example, we launch an SGLang engine and feed a batch of inputs for embedding generation.
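
A rough sketch, assuming a model served in embedding mode (the model name is a placeholder):

```python
# Embedding generation sketch: launch the Engine in embedding mode and call encode().
import sglang as sgl

if __name__ == "__main__":
    llm = sgl.Engine(model_path="Alibaba-NLP/gte-Qwen2-7B-instruct", is_embedding=True)
    prompts = ["Hello, my name is", "The capital of France is"]
    outputs = llm.encode(prompts)
    for prompt, output in zip(prompts, outputs):
        print(f"Prompt: {prompt}\nEmbedding length: {len(output['embedding'])}")
    llm.shutdown()
```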
### 3. [Custom Server](./custom_server.py)
### [Custom Server](./custom_server.py)
This example demonstrates how to create a custom server on top of the SGLang Engine. We use [Sanic](https://sanic.dev/en/) as an example. The server supports both non-streaming and streaming endpoints.
......@@ -43,3 +43,7 @@ curl -X POST http://localhost:8000/generate_stream -H "Content-Type: applicatio
```
This will send both non-streaming and streaming requests to the server.
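
The overall shape of such a server is roughly the following sketch; it only covers the non-streaming endpoint, and the model name and Sanic wiring are illustrative rather than the exact code in the example:

```python
# Custom server sketch: expose the SGLang Engine behind a Sanic endpoint.
import sglang as sgl
from sanic import Sanic
from sanic.response import json as json_response

app = Sanic("sglang_custom_server")


@app.post("/generate")
async def generate(request):
    prompt = request.json["prompt"]
    output = await app.ctx.engine.async_generate(
        prompt, sampling_params={"max_new_tokens": 32}
    )
    return json_response({"text": output["text"]})


if __name__ == "__main__":
    # Create the engine in the main process (SGLang spawns worker subprocesses).
    app.ctx.engine = sgl.Engine(model_path="Qwen/Qwen2.5-1.5B-Instruct")
    app.run(host="0.0.0.0", port=8000, single_process=True)
```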
### [Token-In-Token-Out for RLHF](./token_in_token_out)
In this example, we launch an SGLang engine, feed it token IDs as input, and get token IDs back as output.
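
The same idea also works over HTTP when the server is launched with `--skip-tokenizer-init`: `/generate` then accepts `input_ids` and returns `output_ids`. A rough sketch (the URL, port, and token IDs are illustrative):

```python
# Token-in-token-out over HTTP against a server started with --skip-tokenizer-init.
import requests

response = requests.post(
    "http://localhost:30000/generate",
    json={
        "input_ids": [128000, 791, 6864, 315, 9822, 374],  # "The capital of France is"
        "sampling_params": {"temperature": 0, "max_new_tokens": 16},
    },
)
print(response.json()["output_ids"])
```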
......@@ -30,7 +30,7 @@ def main():
# Print the outputs.
for prompt, output in zip(prompts, outputs):
print("===============================")
print(f"Prompt: {prompt}\nGenerated token ids: {output['token_ids']}")
print(f"Prompt: {prompt}\nGenerated token ids: {output['output_ids']}")
# The __main__ condition is necessary here because we use "spawn" to create subprocesses
......
import argparse
import dataclasses
from io import BytesIO
from typing import Tuple

import requests
from PIL import Image
from transformers import AutoProcessor

from sglang import Engine
from sglang.lang.chat_template import get_chat_template_by_model_path
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.server_args import ServerArgs
from sglang.test.test_utils import DEFAULT_IMAGE_URL


def get_input_ids(
    server_args: ServerArgs, model_config: ModelConfig
) -> Tuple[list[int], list]:
    # Build the multimodal prompt with the model's image token and tokenize it
    # with the HF processor, so the engine only ever receives token IDs.
    chat_template = get_chat_template_by_model_path(model_config.model_path)
    text = f"{chat_template.image_token}What is in this picture?"
    images = [Image.open(BytesIO(requests.get(DEFAULT_IMAGE_URL).content))]
    image_data = [DEFAULT_IMAGE_URL]
    processor = AutoProcessor.from_pretrained(
        model_config.model_path, trust_remote_code=server_args.trust_remote_code
    )
    inputs = processor(
        text=[text],
        images=images,
        return_tensors="pt",
    )
    return inputs.input_ids[0].tolist(), image_data


def token_in_out_example(
    server_args: ServerArgs,
):
    input_ids, image_data = get_input_ids(
        server_args,
        ModelConfig(
            server_args.model_path,
            trust_remote_code=server_args.trust_remote_code,
            model_override_args=server_args.json_model_override_args,
        ),
    )
    backend = Engine(**dataclasses.asdict(server_args))
    # With skip_tokenizer_init set on the ServerArgs, the engine consumes token IDs
    # and returns token IDs ("output_ids") instead of detokenized text.
    output = backend.generate(
        input_ids=input_ids,
        image_data=image_data,
        sampling_params={
            "temperature": 0.8,
            "max_new_tokens": 32,
        },
    )
    print("===============================")
    print("Output token ids:", output["output_ids"])

    backend.shutdown()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    ServerArgs.add_cli_args(parser)
    args = [
        "--model-path=Qwen/Qwen2-VL-2B",
    ]
    args = parser.parse_args(args=args)
    server_args = ServerArgs.from_cli_args(args)
    server_args.skip_tokenizer_init = True
    token_in_out_example(server_args)
......@@ -40,7 +40,7 @@ class ModelConfig:
trust_remote_code: bool = True,
revision: Optional[str] = None,
context_length: Optional[int] = None,
model_override_args: Optional[dict] = None,
model_override_args: Optional[str] = None,
is_embedding: Optional[bool] = None,
dtype: str = "auto",
quantization: Optional[str] = None,
......
......@@ -42,7 +42,6 @@ from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.managers.io_struct import (
AbortReq,
BatchEmbeddingOut,
BatchMultimodalDecodeReq,
BatchTokenIDOut,
CloseSessionReqInput,
FlushCacheReq,
......@@ -104,7 +103,6 @@ from sglang.srt.utils import (
crash_on_warnings,
get_bool_env_var,
get_zmq_socket,
kill_itself_when_parent_died,
pyspy_dump_schedulers,
set_gpu_proc_affinity,
set_random_seed,
......@@ -1199,7 +1197,6 @@ class Scheduler:
self.spec_num_total_forward_ct += batch.batch_size()
self.num_generated_tokens += num_accepted_tokens
batch.output_ids = next_token_ids
# These 2 values are needed for processing the output, but the values can be
# modified by overlap schedule. So we have to copy them here so that
# we can use the correct values in output processing.
......@@ -1480,7 +1477,6 @@ class Scheduler:
batch.next_batch_sampling_info.update_regex_vocab_mask()
self.current_stream.synchronize()
batch.next_batch_sampling_info.sampling_info_done.set()
self.stream_output(batch.reqs, batch.return_logprob)
self.token_to_kv_pool.free_group_end()
......@@ -1580,11 +1576,11 @@ class Scheduler:
if req.top_logprobs_num > 0:
req.input_top_logprobs_val = [None]
req.input_top_logprobs_idx = [None]
assert len(req.temp_input_token_ids_logprobs_val) == len(
req.temp_input_token_ids_logprobs_idx
)
for val, idx in zip(
req.temp_input_top_logprobs_val,
req.temp_input_top_logprobs_idx,
strict=True,
req.temp_input_top_logprobs_val, req.temp_input_top_logprobs_idx
):
req.input_top_logprobs_val.extend(val)
req.input_top_logprobs_idx.extend(idx)
......@@ -1779,7 +1775,6 @@ class Scheduler:
if rids:
if self.model_config.is_multimodal_gen:
raise NotImplementedError()
self.send_to_detokenizer.send_pyobj(
BatchTokenIDOut(
rids,
......
......@@ -11,7 +11,7 @@ import math
import os
from abc import ABC, abstractmethod
from contextlib import contextmanager
from typing import Any, Dict, Generator, Iterable, List, Optional, Tuple, Type, cast
from typing import Any, Dict, Generator, Iterable, List, Optional, Tuple, cast
import gguf
import huggingface_hub
......@@ -19,7 +19,7 @@ import numpy as np
import torch
from huggingface_hub import HfApi, hf_hub_download
from torch import nn
from transformers import AutoModelForCausalLM, PretrainedConfig
from transformers import AutoModelForCausalLM
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
from sglang.srt.configs.device_config import DeviceConfig
......@@ -197,7 +197,7 @@ class DefaultModelLoader(BaseModelLoader):
Returns the path to the downloaded model, or None if the model is not
downloaded from ModelScope."""
if "SGLANG_USE_MODELSCOPE" in os.environ:
if os.environ.get("SGLANG_USE_MODELSCOPE", None) == "True":
# download model from ModelScope hub,
# lazy import so that modelscope is not required for normal use.
# pylint: disable=C.
......
......@@ -43,10 +43,15 @@ DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_FP8_TP1 = "neuralmagic/Meta-Llama-3.1-8B-Ins
DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_FP8_TP2 = "neuralmagic/Meta-Llama-3.1-70B-Instruct-FP8,neuralmagic/Mixtral-8x7B-Instruct-v0.1-FP8,neuralmagic/Qwen2-72B-Instruct-FP8,neuralmagic/Qwen2-57B-A14B-Instruct-FP8,neuralmagic/DeepSeek-Coder-V2-Lite-Instruct-FP8"
DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_QUANT_TP1 = "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4,hugging-quants/Meta-Llama-3.1-8B-Instruct-GPTQ-INT4"
DEFAULT_SMALL_MODEL_NAME_FOR_TEST_QWEN = "Qwen/Qwen2.5-1.5B-Instruct"
DEFAULT_SMALL_VLM_MODEL_NAME = "Qwen/Qwen2-VL-2B"
DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST = "meta-llama/Llama-2-7b-chat-hf"
DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST = "lmsys/sglang-EAGLE-llama2-chat-7B"
DEFAULT_IMAGE_URL = "https://github.com/sgl-project/sglang/blob/main/test/lang/example_image.png?raw=true"
DEFAULT_VIDEO_URL = "https://raw.githubusercontent.com/EvolvingLMMs-Lab/sglang/dev/onevision_local/assets/jobs.mp4"
def is_in_ci():
"""Return whether it is in CI runner."""
......
......@@ -5,13 +5,18 @@ python3 -m unittest test_skip_tokenizer_init.TestSkipTokenizerInit.run_decode_st
import json
import unittest
from io import BytesIO
import requests
from transformers import AutoTokenizer
from PIL import Image
from transformers import AutoProcessor, AutoTokenizer
from sglang.lang.chat_template import get_chat_template_by_model_path
from sglang.srt.utils import kill_process_tree
from sglang.test.test_utils import (
DEFAULT_IMAGE_URL,
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
DEFAULT_SMALL_VLM_MODEL_NAME,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
popen_launch_server,
......@@ -29,6 +34,7 @@ class TestSkipTokenizerInit(unittest.TestCase):
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=["--skip-tokenizer-init", "--stream-output"],
)
cls.eos_token_id = [119690]
cls.tokenizer = AutoTokenizer.from_pretrained(
DEFAULT_SMALL_MODEL_NAME_FOR_TEST, use_fast=False
)
......@@ -45,9 +51,7 @@ class TestSkipTokenizerInit(unittest.TestCase):
top_logprobs_num=0,
n=1,
):
input_ids = self.tokenizer(prompt_text, return_tensors="pt")["input_ids"][
0
].tolist()
input_ids = self.get_input_ids(prompt_text)
response = requests.post(
self.base_url + "/generate",
......@@ -104,7 +108,7 @@ class TestSkipTokenizerInit(unittest.TestCase):
def run_decode_stream(self, return_logprob=False, top_logprobs_num=0, n=1):
max_new_tokens = 32
input_ids = [128000, 791, 6864, 315, 9822, 374] # The capital of France is
input_ids = self.get_input_ids("The capital of France is")
requests.post(self.base_url + "/flush_cache")
response = requests.post(
self.base_url + "/generate",
......@@ -114,7 +118,7 @@ class TestSkipTokenizerInit(unittest.TestCase):
"temperature": 0 if n == 1 else 0.5,
"max_new_tokens": max_new_tokens,
"n": n,
"stop_token_ids": [119690],
"stop_token_ids": self.eos_token_id,
},
"stream": False,
"return_logprob": return_logprob,
......@@ -125,6 +129,9 @@ class TestSkipTokenizerInit(unittest.TestCase):
ret = response.json()
print(json.dumps(ret))
output_ids = ret["output_ids"]
print("output from non-streaming request:")
print(output_ids)
print(self.tokenizer.decode(output_ids, skip_special_tokens=True))
requests.post(self.base_url + "/flush_cache")
response_stream = requests.post(
......@@ -135,7 +142,7 @@ class TestSkipTokenizerInit(unittest.TestCase):
"temperature": 0 if n == 1 else 0.5,
"max_new_tokens": max_new_tokens,
"n": n,
"stop_token_ids": [119690],
"stop_token_ids": self.eos_token_id,
},
"stream": True,
"return_logprob": return_logprob,
......@@ -143,13 +150,10 @@ class TestSkipTokenizerInit(unittest.TestCase):
"logprob_start_len": 0,
},
)
ret = response.json()
output_ids = ret["output_ids"]
print("output from non-streaming request:")
print(output_ids)
response_stream_json = []
for line in response_stream.iter_lines():
print(line)
if line.startswith(b"data: ") and line[6:] != b"[DONE]":
response_stream_json.append(json.loads(line[6:]))
out_stream_ids = []
......@@ -157,6 +161,8 @@ class TestSkipTokenizerInit(unittest.TestCase):
out_stream_ids += x["output_ids"]
print("output from streaming request:")
print(out_stream_ids)
print(self.tokenizer.decode(out_stream_ids, skip_special_tokens=True))
assert output_ids == out_stream_ids
def test_simple_decode(self):
......@@ -175,6 +181,46 @@ class TestSkipTokenizerInit(unittest.TestCase):
def test_simple_decode_stream(self):
self.run_decode_stream()
    def get_input_ids(self, prompt_text) -> list[int]:
        input_ids = self.tokenizer(prompt_text, return_tensors="pt")["input_ids"][
            0
        ].tolist()
        return input_ids


class TestSkipTokenizerInitVLM(TestSkipTokenizerInit):
    @classmethod
    def setUpClass(cls):
        cls.image_url = DEFAULT_IMAGE_URL
        response = requests.get(cls.image_url)
        cls.image = Image.open(BytesIO(response.content))
        cls.model = DEFAULT_SMALL_VLM_MODEL_NAME
        cls.tokenizer = AutoTokenizer.from_pretrained(cls.model, use_fast=False)
        cls.processor = AutoProcessor.from_pretrained(cls.model, trust_remote_code=True)
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=["--skip-tokenizer-init"],
        )
        cls.eos_token_id = [cls.tokenizer.eos_token_id]

    def get_input_ids(self, _prompt_text) -> list[int]:
        chat_template = get_chat_template_by_model_path(self.model)
        text = f"{chat_template.image_token}What is in this picture?"
        inputs = self.processor(
            text=[text],
            images=[self.image],
            return_tensors="pt",
        )
        return inputs.input_ids[0].tolist()

    def test_simple_decode_stream(self):
        # TODO mick
        pass


if __name__ == "__main__":
    unittest.main()