Unverified commit d8476818 authored by Juwan Yoo, committed by GitHub

feat: allow streaming for multi-prompt and/or parallel sampling (#1134)

parent df191254
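
With this change a single request can carry multiple prompts and/or n > 1 parallel samples and still stream; every chunk identifies its sample via the choice index (two prompts with n=2 yield choice indices 0 through 3). A rough usage sketch with the openai client, mirroring the tests further down; the API key, base URL, and model name are placeholders, not taken from this diff:

    import openai

    # Placeholders: point this at a locally launched server and its served model.
    client = openai.Client(api_key="EMPTY", base_url="http://127.0.0.1:30000/v1")

    stream = client.completions.create(
        model="default",
        prompt=["The capital of France is", "The capital of Japan is"],
        max_tokens=32,
        n=2,  # two parallel samples per prompt
        stream=True,
        stream_options={"include_usage": True},
    )

    texts = {}  # accumulate text per choice index
    for chunk in stream:
        if chunk.usage is not None:  # final usage-only chunk
            print("usage:", chunk.usage)
            continue
        choice = chunk.choices[0]
        texts[choice.index] = texts.get(choice.index, "") + (choice.text or "")

    for index in sorted(texts):
        print(index, repr(texts[index]))
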
@@ -153,9 +153,6 @@ class TokenizerManager:
             async for response in self._handle_single_request(obj, request):
                 yield response
         else:
-            if hasattr(obj, "stream") and obj.stream:
-                raise ValueError("Do not support stream for batch mode.")
-
             async for response in self._handle_batch_request(obj, request):
                 yield response
@@ -311,6 +308,7 @@ class TokenizerManager:
             parallel_sample_num = 1

         # First send out all requests
+        generators = []
         for i in range(batch_size):
             for j in range(parallel_sample_num):
                 if j == 0 and parallel_sample_num != 1:
@@ -371,41 +369,47 @@ class TokenizerManager:
                 state = ReqState([], False, event)
                 self.rid_to_state[rid] = state

-        # Then wait for all responses
+                generators.append(
+                    self._wait_for_response(
+                        event,
+                        state,
+                        obj,
+                        rid,
+                        request,
+                        index=index,
+                        response_index=len(generators),
+                    )
+                )
+
+        # Then process the responses based on streaming option
+        is_stream = hasattr(obj, "stream") and obj.stream
+
+        tasks = [asyncio.create_task(gen.__anext__()) for gen in generators]
         output_list = []
-        for i in range(batch_size):
-            for j in range(parallel_sample_num):
-                if j == 0 and parallel_sample_num != 1:
-                    continue
-                index = i * parallel_sample_num + j
-                if parallel_sample_num != 1:
-                    index += batch_size - 1 - i
-                rid = obj.rid[index]
-                state = self.rid_to_state[rid]
-
-                while True:
-                    try:
-                        await asyncio.wait_for(state.event.wait(), timeout=4)
-                        break
-                    except asyncio.TimeoutError:
-                        if request is not None and await request.is_disconnected():
-                            for rid in obj.rid:
-                                self.abort_request(rid)
-                            raise ValueError(f"Abort request {rid}")
-                        continue
-
-                if self.is_generation:
-                    output_list.append(
-                        self.convert_logprob_style(
-                            state.out_list[-1],
-                            obj.return_logprob[index],
-                            obj.top_logprobs_num[index],
-                            obj.return_text_in_logprobs,
-                        )
-                    )
-                else:
-                    output_list.append(state.out_list[-1])
-                assert state.finished
-                del self.rid_to_state[rid]
-
-        yield output_list
+
+        while tasks:
+            done, _ = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
+
+            for task in done:
+                gen_index = tasks.index(task)
+
+                try:
+                    result = task.result()
+
+                    if is_stream:
+                        yield result
+                    else:
+                        output_list.append(result)
+
+                    tasks[gen_index] = asyncio.create_task(
+                        generators[gen_index].__anext__()
+                    )
+                except StopAsyncIteration:
+                    del generators[gen_index]
+                    del tasks[gen_index]
+
+        if not is_stream:
+            yield output_list

     def _validate_input_length(self, input_ids: List[int]):
@@ -437,12 +441,15 @@ class TokenizerManager:
         obj: Union[GenerateReqInput, EmbeddingReqInput],
         rid: str,
         request,
+        index: int = None,
+        response_index: int = 0,
     ):
         while True:
             try:
                 await asyncio.wait_for(event.wait(), timeout=4)
             except asyncio.TimeoutError:
                 if request is not None and await request.is_disconnected():
-                    self.abort_request(rid)
+                    for rid in [obj.rid] if obj.is_single else obj.rid:
+                        self.abort_request(rid)
                     raise ValueError(f"Abort request {rid}")
                 continue
@@ -450,13 +457,19 @@ class TokenizerManager:
             if self.is_generation:
                 out = self.convert_logprob_style(
                     state.out_list[-1],
-                    obj.return_logprob,
-                    obj.top_logprobs_num,
+                    obj.return_logprob if index is None else obj.return_logprob[index],
+                    (
+                        obj.top_logprobs_num
+                        if index is None
+                        else obj.top_logprobs_num[index]
+                    ),
                     obj.return_text_in_logprobs,
                 )
             else:  # isinstance(obj, EmbeddingReqInput)
                 out = state.out_list[-1]

+            out["index"] = response_index
+
             # Log requests
             if self.server_args.log_requests and state.finished:
                 if obj.text is None:
...
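
The tokenizer_manager change above turns each sub-request into its own _wait_for_response async generator and fans them in with asyncio.wait(..., return_when=FIRST_COMPLETED), so chunks are yielded in whatever order the sub-requests produce them. A self-contained sketch of that fan-in pattern with toy generators (illustration only, not sglang code):

    import asyncio

    # Toy async generators standing in for per-request streams.
    async def fake_stream(tag, n_chunks, delay):
        for i in range(n_chunks):
            await asyncio.sleep(delay)
            yield f"{tag}-chunk{i}"

    async def merge(generators):
        # Keep one pending __anext__() task per generator and race them.
        tasks = [asyncio.create_task(gen.__anext__()) for gen in generators]
        while tasks:
            done, _ = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
            for task in done:
                i = tasks.index(task)
                try:
                    yield task.result()
                    # Re-arm the generator that just produced a chunk.
                    tasks[i] = asyncio.create_task(generators[i].__anext__())
                except StopAsyncIteration:
                    # This generator is exhausted; drop it and its task slot.
                    del generators[i]
                    del tasks[i]

    async def main():
        gens = [fake_stream("a", 3, 0.05), fake_stream("b", 3, 0.02)]
        async for chunk in merge(gens):
            print(chunk)

    asyncio.run(main())
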
@@ -277,6 +277,12 @@ async def process_batch(tokenizer_manager, batch_id: str, batch_request: BatchRe
             request_data = json.loads(line)
             file_request_list.append(request_data)
             body = request_data["body"]
+
+            # Although streaming is supported for standalone completions, it is not supported in
+            # batch mode (multiple completions in single request).
+            if body.get("stream", False):
+                raise ValueError("Streaming requests are not supported in batch mode")
+
             if end_point == "/v1/chat/completions":
                 all_requests.append(ChatCompletionRequest(**body))
             elif end_point == "/v1/completions":
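
The guard above rejects streaming inside OpenAI-style batch files, since a batch job collects all results into one output file rather than an SSE stream. A toy illustration of the check on one parsed batch line; the request body below is made up:

    # Hypothetical parsed line from a batch input .jsonl file.
    request_data = {
        "custom_id": "req-1",
        "method": "POST",
        "url": "/v1/chat/completions",
        "body": {
            "model": "default",
            "messages": [{"role": "user", "content": "Hi"}],
            "stream": True,  # not allowed in batch mode
        },
    }

    body = request_data["body"]
    if body.get("stream", False):
        raise ValueError("Streaming requests are not supported in batch mode")
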
@@ -592,27 +598,45 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
     if adapted_request.stream:

         async def generate_stream_resp():
-            stream_buffer = ""
-            n_prev_token = 0
+            stream_buffers = {}
+            n_prev_tokens = {}
+            prompt_tokens = {}
+            completion_tokens = {}
             try:
                 async for content in tokenizer_manager.generate_request(
                     adapted_request, raw_request
                 ):
+                    index = content["index"]
+
+                    stream_buffer = stream_buffers.get(index, "")
+                    n_prev_token = n_prev_tokens.get(index, 0)
+
                     text = content["text"]
-                    prompt_tokens = content["meta_info"]["prompt_tokens"]
-                    completion_tokens = content["meta_info"]["completion_tokens"]
+                    prompt_tokens[index] = content["meta_info"]["prompt_tokens"]
+                    completion_tokens[index] = content["meta_info"]["completion_tokens"]

                     if not stream_buffer:  # The first chunk
                         if request.echo:
                             if isinstance(request.prompt, str):
                                 # for the case of single str prompts
                                 prompts = request.prompt
-                            elif isinstance(request.prompt, list) and isinstance(
-                                request.prompt[0], int
-                            ):
-                                prompts = tokenizer_manager.tokenizer.decode(
-                                    request.prompt, skip_special_tokens=True
-                                )
+                            elif isinstance(request.prompt, list):
+                                if isinstance(request.prompt[0], str):
+                                    # for the case of multiple str prompts
+                                    prompts = request.prompt[index // request.n]
+                                elif isinstance(request.prompt[0], int):
+                                    # for the case of single token ids prompt
+                                    prompts = tokenizer_manager.tokenizer.decode(
+                                        request.prompt, skip_special_tokens=True
+                                    )
+                                elif isinstance(request.prompt[0], list) and isinstance(
+                                    request.prompt[0][0], int
+                                ):
+                                    # for the case of multiple token ids prompts
+                                    prompts = tokenizer_manager.tokenizer.decode(
+                                        request.prompt[index // request.n],
+                                        skip_special_tokens=True,
+                                    )

                             # Prepend prompt in response text.
                             text = prompts + text
@@ -649,7 +673,7 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
                     delta = text[len(stream_buffer) :]
                     stream_buffer = stream_buffer + delta
                     choice_data = CompletionResponseStreamChoice(
-                        index=0,
+                        index=index,
                         text=delta,
                         logprobs=logprobs,
                         finish_reason=format_finish_reason(
@@ -662,12 +686,24 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
                         choices=[choice_data],
                         model=request.model,
                     )
+                    stream_buffers[index] = stream_buffer
+                    n_prev_tokens[index] = n_prev_token
+
                     yield f"data: {chunk.model_dump_json()}\n\n"

                 if request.stream_options and request.stream_options.include_usage:
+                    total_prompt_tokens = sum(
+                        tokens
+                        for i, tokens in prompt_tokens.items()
+                        if i % request.n == 0
+                    )
+                    total_completion_tokens = sum(
+                        tokens for tokens in completion_tokens.values()
+                    )
                     usage = UsageInfo(
-                        prompt_tokens=prompt_tokens,
-                        completion_tokens=completion_tokens,
-                        total_tokens=prompt_tokens + completion_tokens,
+                        prompt_tokens=total_prompt_tokens,
+                        completion_tokens=total_completion_tokens,
+                        total_tokens=total_prompt_tokens + total_completion_tokens,
                     )
                     final_usage_chunk = CompletionStreamResponse(
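
The include_usage branch above counts completion tokens for every choice but prompt tokens only once per prompt: choice index // request.n identifies the prompt, so index % request.n == 0 selects one representative sample per prompt. A toy calculation with made-up token counts (two prompts, request.n == 2, choice indices 0-3):

    request_n = 2
    prompt_tokens = {0: 7, 1: 7, 2: 9, 3: 9}          # per-choice prompt token counts (made up)
    completion_tokens = {0: 12, 1: 15, 2: 11, 3: 10}  # per-choice completion counts (made up)

    total_prompt_tokens = sum(t for i, t in prompt_tokens.items() if i % request_n == 0)
    total_completion_tokens = sum(completion_tokens.values())

    assert total_prompt_tokens == 7 + 9               # each prompt counted once
    assert total_completion_tokens == 12 + 15 + 11 + 10
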
@@ -914,16 +950,23 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
     if adapted_request.stream:

         async def generate_stream_resp():
-            is_first = True
-
-            stream_buffer = ""
-            n_prev_token = 0
+            is_firsts = {}
+            stream_buffers = {}
+            n_prev_tokens = {}
+            prompt_tokens = {}
+            completion_tokens = {}
             try:
                 async for content in tokenizer_manager.generate_request(
                     adapted_request, raw_request
                 ):
-                    prompt_tokens = content["meta_info"]["prompt_tokens"]
-                    completion_tokens = content["meta_info"]["completion_tokens"]
+                    index = content["index"]
+
+                    is_first = is_firsts.get(index, True)
+                    stream_buffer = stream_buffers.get(index, "")
+                    n_prev_token = n_prev_tokens.get(index, 0)
+
+                    prompt_tokens[index] = content["meta_info"]["prompt_tokens"]
+                    completion_tokens[index] = content["meta_info"]["completion_tokens"]
                     if request.logprobs:
                         logprobs = to_openai_style_logprobs(
                             output_token_logprobs=content["meta_info"][
@@ -973,7 +1016,7 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
                         # First chunk with role
                         is_first = False
                         choice_data = ChatCompletionResponseStreamChoice(
-                            index=0,
+                            index=index,
                             delta=DeltaMessage(role="assistant"),
                             finish_reason=format_finish_reason(
                                 content["meta_info"]["finish_reason"]
@@ -991,7 +1034,7 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
                     delta = text[len(stream_buffer) :]
                     stream_buffer = stream_buffer + delta
                     choice_data = ChatCompletionResponseStreamChoice(
-                        index=0,
+                        index=index,
                         delta=DeltaMessage(content=delta),
                         finish_reason=format_finish_reason(
                             content["meta_info"]["finish_reason"]
@@ -1003,12 +1046,25 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
                         choices=[choice_data],
                         model=request.model,
                     )
+                    is_firsts[index] = is_first
+                    stream_buffers[index] = stream_buffer
+                    n_prev_tokens[index] = n_prev_token
+
                     yield f"data: {chunk.model_dump_json()}\n\n"

                 if request.stream_options and request.stream_options.include_usage:
+                    total_prompt_tokens = sum(
+                        tokens
+                        for i, tokens in prompt_tokens.items()
+                        if i % request.n == 0
+                    )
+                    total_completion_tokens = sum(
+                        tokens for tokens in completion_tokens.values()
+                    )
                     usage = UsageInfo(
-                        prompt_tokens=prompt_tokens,
-                        completion_tokens=completion_tokens,
-                        total_tokens=prompt_tokens + completion_tokens,
+                        prompt_tokens=total_prompt_tokens,
+                        completion_tokens=total_completion_tokens,
+                        total_tokens=total_prompt_tokens + total_completion_tokens,
                     )
                     final_usage_chunk = ChatCompletionStreamResponse(
...
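
The chat handler gets the same per-index bookkeeping: the first chunk for each choice index is a role-only delta, later chunks carry content, and the index field is now the real choice index. A hedged client-side sketch that reassembles all n messages; the server address and model name are placeholders:

    import openai

    client = openai.Client(api_key="EMPTY", base_url="http://127.0.0.1:30000/v1")  # placeholder

    stream = client.chat.completions.create(
        model="default",  # placeholder
        messages=[{"role": "user", "content": "Name three rivers in Europe."}],
        max_tokens=64,
        n=2,
        stream=True,
    )

    messages = {}
    for chunk in stream:
        if not chunk.choices:
            continue
        choice = chunk.choices[0]
        if choice.delta.role is not None:  # first chunk for this index carries the role
            messages.setdefault(choice.index, "")
        if choice.delta.content:
            messages[choice.index] = messages.get(choice.index, "") + choice.delta.content

    for index, text in sorted(messages.items()):
        print(f"[choice {index}] {text}")
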
@@ -85,13 +85,26 @@ class TestOpenAIServer(unittest.TestCase):
         assert response.usage.completion_tokens > 0
         assert response.usage.total_tokens > 0

-    def run_completion_stream(self, echo, logprobs, token_input):
+    def run_completion_stream(
+        self, echo, logprobs, use_list_input, parallel_sample_num, token_input
+    ):
         client = openai.Client(api_key=self.api_key, base_url=self.base_url)
         prompt = "The capital of France is"
         if token_input:
-            prompt_arg = self.tokenizer.encode(prompt)
+            prompt_input = self.tokenizer.encode(prompt)
+            num_prompt_tokens = len(prompt_input)
         else:
-            prompt_arg = prompt
+            prompt_input = prompt
+            num_prompt_tokens = len(self.tokenizer.encode(prompt))
+
+        if use_list_input:
+            prompt_arg = [prompt_input, prompt_input]
+            num_choices = len(prompt_arg)
+            num_prompt_tokens *= 2
+        else:
+            prompt_arg = prompt_input
+            num_choices = 1
+
         generator = client.completions.create(
             model=self.model,
             prompt=prompt_arg,
@@ -101,9 +114,10 @@ class TestOpenAIServer(unittest.TestCase):
             logprobs=logprobs,
             stream=True,
             stream_options={"include_usage": True},
+            n=parallel_sample_num,
         )

-        first = True
+        is_firsts = {}
         for response in generator:
             usage = response.usage
             if usage is not None:
@@ -111,10 +125,14 @@ class TestOpenAIServer(unittest.TestCase):
                 assert usage.completion_tokens > 0
                 assert usage.total_tokens > 0
                 continue
+
+            index = response.choices[0].index
+            is_first = is_firsts.get(index, True)
+
             if logprobs:
                 assert response.choices[0].logprobs
                 assert isinstance(response.choices[0].logprobs.tokens[0], str)
-                if not (first and echo):
+                if not (is_first and echo):
                     assert isinstance(
                         response.choices[0].logprobs.top_logprobs[0], dict
                     )
@@ -125,15 +143,20 @@ class TestOpenAIServer(unittest.TestCase):
                 # assert ret_num_top_logprobs == logprobs, f"{ret_num_top_logprobs} vs {logprobs}"
                 assert ret_num_top_logprobs > 0

-            if first:
+            if is_first:
                 if echo:
                     assert response.choices[0].text.startswith(
                         prompt
-                    ), f"{response.choices[0].text} and all args {echo} {logprobs} {token_input} {first}"
-                first = False
+                    ), f"{response.choices[0].text} and all args {echo} {logprobs} {token_input} {is_first}"
+                is_firsts[index] = False

             assert response.id
             assert response.created

+        for index in [i for i in range(parallel_sample_num * num_choices)]:
+            assert not is_firsts.get(
+                index, True
+            ), f"index {index} is not found in the response"
+
     def run_chat_completion(self, logprobs, parallel_sample_num):
         client = openai.Client(api_key=self.api_key, base_url=self.base_url)
         response = client.chat.completions.create(
@@ -172,7 +195,7 @@ class TestOpenAIServer(unittest.TestCase):
         assert response.usage.completion_tokens > 0
         assert response.usage.total_tokens > 0

-    def run_chat_completion_stream(self, logprobs):
+    def run_chat_completion_stream(self, logprobs, parallel_sample_num=1):
         client = openai.Client(api_key=self.api_key, base_url=self.base_url)
         generator = client.chat.completions.create(
             model=self.model,
@@ -185,9 +208,10 @@ class TestOpenAIServer(unittest.TestCase):
             top_logprobs=logprobs,
             stream=True,
             stream_options={"include_usage": True},
+            n=parallel_sample_num,
         )

-        is_first = True
+        is_firsts = {}
         for response in generator:
             usage = response.usage
             if usage is not None:
@@ -196,11 +220,12 @@ class TestOpenAIServer(unittest.TestCase):
                 assert usage.total_tokens > 0
                 continue

+            index = response.choices[0].index
             data = response.choices[0].delta

-            if is_first:
-                data.role == "assistant"
-                is_first = False
+            if is_firsts.get(index, True):
+                assert data.role == "assistant"
+                is_firsts[index] = False
                 continue

             if logprobs:
@@ -222,6 +247,11 @@ class TestOpenAIServer(unittest.TestCase):
             assert response.id
             assert response.created

+        for index in [i for i in range(parallel_sample_num)]:
+            assert not is_firsts.get(
+                index, True
+            ), f"index {index} is not found in the response"
+
     def run_batch(self, mode):
         client = openai.Client(api_key=self.api_key, base_url=self.base_url)
         if mode == "completion":
@@ -320,7 +350,9 @@ class TestOpenAIServer(unittest.TestCase):
                 f"Batch job status: {batch_job.status}...trying again in 3 seconds..."
             )
             batch_job = client.batches.retrieve(batch_job.id)
-        assert batch_job.status == "completed"
+        assert (
+            batch_job.status == "completed"
+        ), f"Batch job status is not completed: {batch_job.status}"
         assert batch_job.request_counts.completed == len(content)
         assert batch_job.request_counts.failed == 0
         assert batch_job.request_counts.total == len(content)
@@ -353,8 +385,16 @@ class TestOpenAIServer(unittest.TestCase):
         # parallel sampling adn list input are not supported in streaming mode
         for echo in [False, True]:
             for logprobs in [None, 5]:
-                for token_input in [False, True]:
-                    self.run_completion_stream(echo, logprobs, token_input)
+                for use_list_input in [True, False]:
+                    for parallel_sample_num in [1, 2]:
+                        for token_input in [False, True]:
+                            self.run_completion_stream(
+                                echo,
+                                logprobs,
+                                use_list_input,
+                                parallel_sample_num,
+                                token_input,
+                            )

     def test_chat_completion(self):
         for logprobs in [None, 5]:
@@ -363,7 +403,8 @@ class TestOpenAIServer(unittest.TestCase):

     def test_chat_completion_stream(self):
         for logprobs in [None, 5]:
-            self.run_chat_completion_stream(logprobs)
+            for parallel_sample_num in [1, 2]:
+                self.run_chat_completion_stream(logprobs, parallel_sample_num)

     def test_batch(self):
         for mode in ["completion", "chat"]:
...
@@ -23,7 +23,12 @@ class TestSRTEndpoint(unittest.TestCase):
         kill_child_process(cls.process.pid)

     def run_decode(
-        self, return_logprob=False, top_logprobs_num=0, return_text=False, n=1
+        self,
+        return_logprob=False,
+        top_logprobs_num=0,
+        return_text=False,
+        n=1,
+        stream=False,
     ):
         response = requests.post(
             self.base_url + "/generate",
@@ -34,14 +39,21 @@ class TestSRTEndpoint(unittest.TestCase):
                     "max_new_tokens": 32,
                     "n": n,
                 },
-                "stream": False,
+                "stream": stream,
                 "return_logprob": return_logprob,
                 "top_logprobs_num": top_logprobs_num,
                 "return_text_in_logprobs": return_text,
                 "logprob_start_len": 0,
             },
         )
-        print(json.dumps(response.json()))
+        if not stream:
+            response_json = response.json()
+        else:
+            response_json = []
+            for line in response.iter_lines():
+                if line.startswith(b"data: ") and line[6:] != b"[DONE]":
+                    response_json.append(json.loads(line[6:]))
+        print(json.dumps(response_json))
         print("=" * 100)

     def test_simple_decode(self):
@@ -50,6 +62,9 @@ class TestSRTEndpoint(unittest.TestCase):
     def test_parallel_sample(self):
         self.run_decode(n=3)

+    def test_parallel_sample_stream(self):
+        self.run_decode(n=3, stream=True)
+
     def test_logprob(self):
         for top_logprobs_num in [0, 3]:
             for return_text in [True, False]:
...
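
The new test_parallel_sample_stream exercises the native /generate endpoint with n=3 and stream=True. Consuming that stream outside the test could look roughly like this; the host, port, prompt, and the "text" field name are placeholders based on the endpoint's usual request shape, the sampling_params fields mirror the test above, and the "index" field on each chunk comes from the tokenizer_manager change:

    import json
    import requests

    # Placeholder host/port and prompt; sampling_params mirror the test above.
    response = requests.post(
        "http://127.0.0.1:30000/generate",
        json={
            "text": "The capital of France is",
            "sampling_params": {"max_new_tokens": 32, "n": 3},
            "stream": True,
        },
        stream=True,  # let requests yield SSE lines as they arrive
    )

    for line in response.iter_lines():
        if line.startswith(b"data: ") and line[6:] != b"[DONE]":
            chunk = json.loads(line[6:])
            # "index" identifies which of the n parallel samples this chunk extends.
            print(chunk.get("index"), chunk["text"])
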