Unverified Commit 1aea19f6 authored by Rin Intachuen's avatar Rin Intachuen Committed by GitHub
Browse files

Input_embeds support (#2052)

parent 1f76fc6e
......@@ -11,21 +11,23 @@ The `/generate` endpoint accepts the following arguments in the JSON format.
class GenerateReqInput:
# The input prompt. It can be a single prompt or a batch of prompts.
text: Optional[Union[List[str], str]] = None
# The token ids for text; one can either specify text or input_ids.
# The token ids for text; one can specify either text or input_ids
input_ids: Optional[Union[List[List[int]], List[int]]] = None
# The embeddings for input_ids; one can specify either text or input_ids or input_embeds.
input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None
# The image input. It can be a file name, a url, or base64 encoded string.
# See also python/sglang/srt/utils.py:load_image.
image_data: Optional[Union[List[str], str]] = None
# The sampling_params. See descriptions below.
sampling_params: Union[List[Dict], Dict] = None
sampling_params: Optional[Union[List[Dict], Dict]] = None
# The request id.
rid: Optional[Union[List[str], str]] = None
# Whether to return logprobs.
return_logprob: Optional[Union[List[bool], bool]] = None
# The start location of the prompt for return_logprob.
# If return logprobs, the start location in the prompt for returning logprobs.
# By default, this value is "-1", which means it will only return logprobs for output tokens.
logprob_start_len: Optional[Union[List[int], int]] = None
# The number of top logprobs to return.
# If return logprobs, the number of top logprobs to return at each position.
top_logprobs_num: Optional[Union[List[int], int]] = None
# Whether to detokenize tokens in text in the returned logprobs.
return_text_in_logprobs: bool = False
......
......@@ -29,8 +29,10 @@ from sglang.srt.sampling.sampling_params import SamplingParams
class GenerateReqInput:
# The input prompt. It can be a single prompt or a batch of prompts.
text: Optional[Union[List[str], str]] = None
# The token ids for text; one can either specify text or input_ids.
# The token ids for text; one can specify either text or input_ids
input_ids: Optional[Union[List[List[int]], List[int]]] = None
# The embeddings for input_ids; one can specify either text or input_ids or input_embeds.
input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None
# The image input. It can be a file name, a url, or base64 encoded string.
# See also python/sglang/srt/utils.py:load_image.
image_data: Optional[Union[List[str], str]] = None
......@@ -60,10 +62,16 @@ class GenerateReqInput:
] = None
def normalize_batch_and_arguments(self):
if (self.text is None and self.input_ids is None) or (
self.text is not None and self.input_ids is not None
if (
self.text is None and self.input_ids is None and self.input_embeds is None
) or (
self.text is not None
and self.input_ids is not None
and self.input_embeds is not None
):
raise ValueError("Either text or input_ids should be provided.")
raise ValueError(
"Either text, input_ids or input_embeds should be provided."
)
# Derive the batch size
if self.text is not None:
......@@ -73,13 +81,21 @@ class GenerateReqInput:
else:
self.is_single = False
self.batch_size = len(self.text)
else:
self.input_embeds = None
elif self.input_ids is not None:
if isinstance(self.input_ids[0], int):
self.is_single = True
self.batch_size = 1
else:
self.is_single = False
self.batch_size = len(self.input_ids)
self.input_embeds = None
else:
if isinstance(self.input_embeds[0][0], float):
self.is_single = True
self.batch_size = 1
else:
self.batch_size = len(self.input_embeds)
# Handle parallel sampling
# When parallel sampling is used, we always treat the input as a batch.
......@@ -202,6 +218,8 @@ class TokenizedGenerateReqInput:
# LoRA related
lora_path: Optional[str] = None # None means just use the base model
# The input embeds
input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None
# Session id info for continual prompting
session_id: Optional[str] = None
......@@ -218,6 +236,8 @@ class EmbeddingReqInput:
rid: Optional[Union[List[str], str]] = None
# Dummy sampling params for compatibility
sampling_params: Union[List[Dict], Dict] = None
# Dummy input embeds for compatibility
input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None
def normalize_batch_and_arguments(self):
if (self.text is None and self.input_ids is None) or (
......
......@@ -178,6 +178,7 @@ class Req:
origin_input_ids: Tuple[int],
sampling_params: SamplingParams,
lora_path: Optional[str] = None,
input_embeds: Optional[List[List[float]]] = None,
session_id: Optional[str] = None,
):
# Input and output info
......@@ -191,6 +192,7 @@ class Req:
self.sampling_params = sampling_params
self.lora_path = lora_path
self.input_embeds = input_embeds
# Memory pool info
self.req_pool_idx = None
......@@ -448,6 +450,7 @@ class ScheduleBatch:
# Batched arguments to model runner
input_ids: torch.Tensor = None
input_embeds: torch.Tensor = None
req_pool_indices: torch.Tensor = None
seq_lens: torch.Tensor = None
# The output locations of the KV cache
......@@ -631,6 +634,9 @@ class ScheduleBatch:
req_pool_indices = self.alloc_req_slots(bs)
out_cache_loc = self.alloc_token_slots(extend_num_tokens)
input_embeds = []
pt = 0
for i, req in enumerate(reqs):
already_computed = (
req.extend_logprob_start_len + 1 + req.cached_tokens
......@@ -649,6 +655,11 @@ class ScheduleBatch:
(req.req_pool_idx, slice(0, pre_len)), req.prefix_indices
)
# If input_embeds are available, store them
if req.input_embeds is not None:
# If req.input_embeds is already a list, append its content directly
input_embeds.extend(req.input_embeds) # Use extend to avoid nesting
# Compute the relative logprob_start_len in an extend batch
if req.logprob_start_len >= pre_len:
extend_logprob_start_len = min(
......@@ -671,6 +682,12 @@ class ScheduleBatch:
self.seq_lens = torch.tensor(seq_lens, dtype=torch.int32).to(
self.device, non_blocking=True
)
self.input_embeds = (
torch.tensor(input_embeds).to(self.device, non_blocking=True)
if input_embeds
else None
)
self.out_cache_loc = out_cache_loc
self.seq_lens_sum = sum(seq_lens)
......@@ -1053,6 +1070,7 @@ class ScheduleBatch:
encoder_out_cache_loc=self.encoder_out_cache_loc,
lora_paths=[req.lora_path for req in self.reqs],
sampling_info=self.sampling_info,
input_embeds=self.input_embeds,
)
def copy(self):
......@@ -1123,6 +1141,9 @@ class ModelWorkerBatch:
# Sampling info
sampling_info: SamplingBatchInfo
# The input Embeds
input_embeds: Optional[torch.tensor] = None
@triton.jit
def write_req_to_token_pool_triton(
......
......@@ -526,12 +526,20 @@ class Scheduler:
recv_req: TokenizedGenerateReqInput,
):
if recv_req.session_id is None or recv_req.session_id not in self.sessions:
# Check if input_embeds is present and create dummy input_ids
if recv_req.input_embeds is not None:
# Generate fake input_ids based on the length of input_embeds
seq_length = len(recv_req.input_embeds)
fake_input_ids = [1] * seq_length
recv_req.input_ids = fake_input_ids
req = Req(
recv_req.rid,
recv_req.input_text,
recv_req.input_ids,
recv_req.sampling_params,
lora_path=recv_req.lora_path,
input_embeds=recv_req.input_embeds,
)
req.tokenizer = self.tokenizer
if recv_req.session_id is not None:
......
......@@ -201,8 +201,18 @@ class TokenizerManager:
):
"""Tokenize one request."""
# Tokenize
input_embeds = None
input_text = obj.text
if obj.input_ids is None:
if obj.input_embeds is not None:
if not self.server_args.disable_radix_cache:
raise ValueError(
"input_embeds is provided while disable_radix_cache is False. "
"Please add `--disable-radix-cach` when you launch the server "
"if you want to use input_embeds as inputs."
)
input_embeds = obj.input_embeds
input_ids = obj.input_ids
elif obj.input_ids is None:
input_ids = self.tokenizer.encode(input_text)
else:
input_ids = obj.input_ids
......@@ -219,7 +229,7 @@ class TokenizerManager:
session_id = obj.session[0] if obj.session else None
session_rid = obj.session[1] if obj.session else None
if len(input_ids) >= self.context_len:
if obj.input_ids is not None and len(input_ids) >= self.context_len:
raise ValueError(
f"The input ({len(input_ids)} tokens) is longer than the "
f"model's context length ({self.context_len} tokens)."
......@@ -242,7 +252,8 @@ class TokenizerManager:
logprob_start_len,
top_logprobs_num,
obj.stream,
obj.lora_path,
lora_path=obj.lora_path,
input_embeds=input_embeds,
session_id=session_id,
session_rid=session_rid,
)
......
......@@ -130,6 +130,9 @@ class ForwardBatch:
# For LoRA
lora_paths: Optional[List[str]] = None
# For input embeddings
input_embeds: Optional[torch.tensor] = None
# Sampling info
sampling_info: SamplingBatchInfo = None
......@@ -231,6 +234,7 @@ class ForwardBatch:
can_run_dp_cuda_graph=batch.can_run_dp_cuda_graph,
lora_paths=batch.lora_paths,
sampling_info=batch.sampling_info,
input_embeds=batch.input_embeds,
)
if ret.global_num_tokens is not None:
......
......@@ -606,9 +606,17 @@ class ModelRunner:
def forward_extend(self, forward_batch: ForwardBatch):
self.attn_backend.init_forward_metadata(forward_batch)
if self.is_generation:
return self.model.forward(
forward_batch.input_ids, forward_batch.positions, forward_batch
)
if forward_batch.input_embeds is None:
return self.model.forward(
forward_batch.input_ids, forward_batch.positions, forward_batch
)
else:
return self.model.forward(
forward_batch.input_ids,
forward_batch.positions,
forward_batch,
input_embeds=forward_batch.input_embeds.bfloat16(),
)
else:
# Only embedding models have get_embedding parameter
return self.model.forward(
......
......@@ -14,6 +14,7 @@ suites = {
"test_double_sparsity.py",
"test_embedding_openai_server.py",
"test_eval_accuracy_mini.py",
"test_input_embeddings.py",
"test_json_constrained.py",
"test_large_max_new_tokens.py",
"test_metrics.py",
......
import json
import unittest
import requests
from transformers import AutoModelForCausalLM, AutoTokenizer
from sglang.srt.utils import kill_child_process
from sglang.test.test_utils import (
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
popen_launch_server,
)
class TestInputEmbeds(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
cls.base_url = DEFAULT_URL_FOR_TEST
cls.tokenizer = AutoTokenizer.from_pretrained(cls.model)
cls.ref_model = AutoModelForCausalLM.from_pretrained(cls.model)
cls.process = popen_launch_server(
cls.model,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=["--disable-radix"],
)
cls.texts = [
"The capital of France is",
"What is the best time of year to visit Japan for cherry blossoms?",
]
def generate_input_embeddings(self, text):
"""Generate input embeddings for a given text."""
input_ids = self.tokenizer(text, return_tensors="pt")["input_ids"]
embeddings = self.ref_model.get_input_embeddings()(input_ids)
return embeddings.squeeze().tolist() # Convert tensor to a list for API use
def send_request(self, payload):
"""Send a POST request to the API and return the response."""
response = requests.post(
self.base_url + "/generate",
json=payload,
timeout=30, # Set a reasonable timeout for the API request
)
if response.status_code == 200:
return response.json()
return {
"error": f"Request failed with status {response.status_code}: {response.text}"
}
def test_text_based_response(self):
"""Print API response using text-based input."""
for text in self.texts:
payload = {
"model": self.model,
"text": text,
"sampling_params": {"temperature": 0, "max_new_tokens": 50},
}
response = self.send_request(payload)
print(
f"Text Input: {text}\nResponse: {json.dumps(response, indent=2)}\n{'-' * 80}"
)
def test_embedding_based_response(self):
"""Print API response using input embeddings."""
for text in self.texts:
embeddings = self.generate_input_embeddings(text)
payload = {
"model": self.model,
"input_embeds": embeddings,
"sampling_params": {"temperature": 0, "max_new_tokens": 50},
}
response = self.send_request(payload)
print(
f"Embeddings Input (for text '{text}'):\nResponse: {json.dumps(response, indent=2)}\n{'-' * 80}"
)
def test_compare_text_vs_embedding(self):
"""Print responses for both text-based and embedding-based inputs."""
for text in self.texts:
# Text-based payload
text_payload = {
"model": self.model,
"text": text,
"sampling_params": {"temperature": 0, "max_new_tokens": 50},
}
# Embedding-based payload
embeddings = self.generate_input_embeddings(text)
embed_payload = {
"model": self.model,
"input_embeds": embeddings,
"sampling_params": {"temperature": 0, "max_new_tokens": 50},
}
# Get responses
text_response = self.send_request(text_payload)
embed_response = self.send_request(embed_payload)
# Print responses
print(
f"Text Input: {text}\nText-Based Response: {json.dumps(text_response, indent=2)}\n"
)
print(
f"Embeddings Input (for text '{text}'):\nEmbedding-Based Response: {json.dumps(embed_response, indent=2)}\n{'-' * 80}"
)
self.assertEqual(text_response["text"], embed_response["text"])
@classmethod
def tearDownClass(cls):
kill_child_process(cls.process.pid, include_self=True)
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment