Unverified Commit 04c0b214 authored by Shannon Shen, committed by GitHub

Allow `input_ids` in the input of the `/generate` endpoint (#363)

parent 6e09cf6a
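As a quick orientation before the diff: `/generate` keeps accepting a `text` prompt, and this change adds a mutually exclusive `input_ids` field for prompts that are already tokenized on the client. A minimal client sketch follows; the server URL and model id are assumptions for illustration, not part of this commit.

# Sketch only: assumed server address and model id, not part of this commit.
import requests
from transformers import AutoTokenizer

url = "http://localhost:30000"  # assumed default sglang server address
tok = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")  # substitute the served model

# Tokenize on the client and send token ids instead of raw text.
# A single prompt is a List[int]; a batch is a List[List[int]].
response = requests.post(
    url + "/generate",
    json={
        "input_ids": tok.encode("The capital of France is"),
        "sampling_params": {"temperature": 0, "max_new_tokens": 16},
    },
)
print(response.json())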
@@ -30,7 +30,7 @@ if __name__ == "__main__":
     response = requests.post(
         url + "/generate",
         json={
-            "text": f"{a}, ",
+            "input_ids": [[1,2,3], [1,2,3]],
             "sampling_params": {
                 "temperature": 0,
                 "max_new_tokens": max_new_tokens,
...
@@ -8,6 +8,8 @@ The `/generate` endpoint accepts the following arguments in the JSON format.
 class GenerateReqInput:
     # The input prompt
     text: Union[List[str], str]
+    # The token ids for text; one can either specify text or input_ids
+    input_ids: Optional[Union[List[List[int]], List[int]]] = None
     # The image input
     image_data: Optional[Union[List[str], str]] = None
     # The sampling_params
...
@@ -8,7 +8,9 @@ from sglang.srt.sampling_params import SamplingParams
 @dataclass
 class GenerateReqInput:
     # The input prompt
-    text: Union[List[str], str]
+    text: Optional[Union[List[str], str]] = None
+    # The token ids for text; one can either specify text or input_ids
+    input_ids: Optional[Union[List[List[int]], List[int]]] = None
     # The image input
     image_data: Optional[Union[List[str], str]] = None
     # The sampling_params
@@ -28,7 +30,17 @@ class GenerateReqInput:
     # TODO: make all parameters a Union[List[T], T] to allow for batched requests

     def post_init(self):
-        is_single = isinstance(self.text, str)
+        if self.text is None:
+            assert self.input_ids is not None, "Either text or input_ids should be provided"
+        else:
+            assert self.input_ids is None, "Either text or input_ids should be provided"
+        if self.text is not None:
+            is_single = isinstance(self.text, str)
+        else:
+            is_single = isinstance(self.input_ids[0], int)
+        self.is_single = is_single

         if is_single:
             if self.sampling_params is None:
@@ -42,7 +54,7 @@ class GenerateReqInput:
             if self.top_logprobs_num is None:
                 self.top_logprobs_num = 0
         else:
-            num = len(self.text)
+            num = len(self.text) if self.text is not None else len(self.input_ids)
             if self.image_data is None:
                 self.image_data = [None] * num
...
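The new `post_init` logic above does two things: it enforces that exactly one of `text` / `input_ids` is set, and it infers whether the request is a single prompt or a batch from the shape of whichever field is present. A condensed, standalone sketch of that decision (illustration only, not the code in the diff):

def resolve_is_single(text, input_ids):
    # Exactly one of text / input_ids must be provided.
    if text is None:
        assert input_ids is not None, "Either text or input_ids should be provided"
    else:
        assert input_ids is None, "Either text or input_ids should be provided"
    if text is not None:
        # str -> single prompt, List[str] -> batch
        return isinstance(text, str)
    # List[int] (first element is an int) -> single prompt, List[List[int]] -> batch
    return isinstance(input_ids[0], int)

assert resolve_is_single("hello", None) is True
assert resolve_is_single(["a", "b"], None) is False
assert resolve_is_single(None, [1, 2, 3]) is True
assert resolve_is_single(None, [[1, 2], [3, 4]]) is False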
@@ -85,6 +85,9 @@ class Req:
         )
         if first_token.startswith("▁"):
             old_output_str = " " + old_output_str
+        if self.input_text is None:
+            # TODO(lmzheng): This can be wrong. Check with Liangsheng.
+            self.input_text = self.tokenizer.decode(self.input_ids)
         new_input_string = (
             self.input_text
             + self.output_and_jump_forward_str
...
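The decode fallback above exists because requests submitted as `input_ids` carry no `input_text`, while the jump-forward path needs a string to re-tokenize; the TODO flags that reconstructing text from token ids is not always lossless. A small illustration with an assumed Hugging Face tokenizer (model id is illustrative):

# Why decode(encode(s)) may not equal s (assumed tokenizer, illustration only).
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")  # substitute the served model
s = "Hello world"
ids = tok.encode(s)           # typically prepends special tokens such as BOS
print(repr(tok.decode(ids)))  # e.g. '<s> Hello world' -- not identical to s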
@@ -147,11 +147,15 @@ class TokenizerManager:
         if self.to_create_loop:
             await self.create_handle_loop()

-        is_single = isinstance(obj.text, str)
+        is_single = obj.is_single

         if is_single:
             rid = obj.rid
-            input_ids = self.tokenizer.encode(obj.text)
+            if obj.input_ids is None:
+                input_ids = self.tokenizer.encode(obj.text)
+            else:
+                input_ids = obj.input_ids
+
             sampling_params = SamplingParams(**obj.sampling_params)
             if sampling_params.max_new_tokens != 0:
                 sampling_params.normalize(self.tokenizer)
@@ -204,10 +208,22 @@ class TokenizerManager:
                     event.clear()
         else:
             assert obj.stream is False
-            bs = len(obj.text)
+
+            if obj.input_ids is None:
+                bs = len(obj.text)
+            else:
+                bs = len(obj.input_ids)
+
             for i in range(bs):
                 rid = obj.rid[i]
-                input_ids = self.tokenizer.encode(obj.text[i])
+
+                if obj.input_ids is None:
+                    input_text = obj.text[i]
+                    input_ids = self.tokenizer.encode(obj.text[i])
+                else:
+                    input_text = None
+                    input_ids = obj.input_ids[i]
+
                 sampling_params = SamplingParams(**obj.sampling_params[i])
                 if sampling_params.max_new_tokens != 0:
                     sampling_params.normalize(self.tokenizer)
@@ -220,7 +236,7 @@ class TokenizerManager:
                 )
                 tokenized_obj = TokenizedGenerateReqInput(
                     rid=rid,
-                    input_text=obj.text[i],
+                    input_text=input_text,
                     input_ids=input_ids,
                     pixel_values=pixel_values,
                     image_hash=image_hash,
...
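With the batched branch above, several pre-tokenized prompts can be submitted in one request, matching the updated example script at the top of this commit; note that `input_text` is set to `None` for these requests, which is what later triggers the decode fallback in `Req`. A minimal sketch (server URL is an assumption):

# Batched request with pre-tokenized prompts (sketch; assumed server address).
import requests

response = requests.post(
    "http://localhost:30000/generate",
    json={
        # List[List[int]]: one entry of token ids per prompt in the batch.
        "input_ids": [[1, 2, 3], [1, 2, 3]],
        "sampling_params": {"temperature": 0, "max_new_tokens": 16},
    },
)
print(response.json())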