Unverified commit 49c5e0ec authored by yichuan~, committed by GitHub

Add support for OpenAI API parallel sampling (#640)

parent ec2150b2
import openai
client = openai.Client(base_url="http://127.0.0.1:30000/v1", api_key="EMPTY")

# Text completion: single prompt, n=1
response = client.completions.create(
model="default",
prompt="I am a robot and I want to study like humans. Now let's tell a story. Once upon a time, there was a little",
n=1,
temperature=0.8,
max_tokens=32,
)
print(response)

# Text completion: single prompt with parallel sampling (n=3)
response = client.completions.create(
model="default",
prompt="I am a robot and I want to study like humans. Now let's tell a story. Once upon a time, there was a little",
n=3,
temperature=0.8,
max_tokens=32,
)
print(response)

# Text completion: batch of two prompts, n=1
response = client.completions.create(
model="default",
prompt=["The name of the famous soccer player is ", "The capital of US is"],
n=1,
temperature=0.8,
max_tokens=32,
)
print(response)

# Text completion: batch of two prompts with parallel sampling (n=3)
response = client.completions.create(
model="default",
prompt=["The name of the famous soccer player is ", "The capital of US is"],
n=3,
temperature=0.8,
max_tokens=32,
)
print(response)

# Text completion: batch of three prompts with parallel sampling (n=3)
response = client.completions.create(
model="default",
prompt=[
"The capital of France is",
"The capital of Germany is",
"The capital of US is",
],
n=3,
temperature=0.8,
max_tokens=32,
)
print(response)

# Chat completion with parallel sampling (n=4)
response = client.chat.completions.create(
model="default",
messages=[
{"role": "system", "content": "You are a helpful AI assistant"},
{"role": "user", "content": "List 3 countries and their capitals."},
],
temperature=0.8,
max_tokens=64,
logprobs=True,
n=4,
)
print(response)
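# A quick way to inspect the parallel samples from the last (chat) request above:
# with n=4, response.choices should contain four entries (a minimal sketch, assuming
# the standard OpenAI client response shape).
for choice in response.choices:
    print(choice.index, choice.message.content)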
@@ -40,7 +40,9 @@ class GenerateReqInput:
self.text is not None and self.input_ids is not None
):
raise ValueError("Either text or input_ids should be provided.")
if "n" in self.sampling_params and self.sampling_params["n"] != 1:
is_single = False
else:
if self.text is not None:
is_single = isinstance(self.text, str)
else:
@@ -59,7 +61,22 @@ class GenerateReqInput:
if self.top_logprobs_num is None:
self.top_logprobs_num = 0
else:
parallel_sample_num = self.sampling_params.get("n", 1)
if parallel_sample_num != 1:
# For parallel sampling, the +1 accounts for the original prefill request
num = parallel_sample_num + 1
if isinstance(self.text, List):
## support batch operation
self.batch_size = len(self.text)
num = num * len(self.text)
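# e.g. text=["p0", "p1"] with n=3 -> num = (3 + 1) * 2 = 8 request slots in total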
else:
self.batch_size = 1
else:
## support select operation
num = len(self.text) if self.text is not None else len(self.input_ids)
self.batch_size = num
if self.image_data is None:
self.image_data = [None] * num
...
@@ -122,54 +122,217 @@ class TokenizerManager:
obj.post_init()
is_single = obj.is_single
if is_single:
async for response in self._handle_single_request(obj, request):
yield response
else:
if obj.stream:
raise ValueError("Do not support stream for batch mode.")
async for response in self._handle_batch_request(obj, request):
yield response
async def _handle_single_request(self, obj, request, index=None, is_prefill=False):
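# index: position of this request within a batch (None for a plain single request).
# is_prefill: when True, send a prefill-only request (max_new_tokens is forced to 0)
# and yield the tokenized input_ids so the batch handler can reuse them for the
# parallel sampling requests.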
if is_prefill:
if isinstance(obj.text, list):
input_text = obj.text[index]
rid = obj.rid[index]
else:
input_text = obj.text
rid = obj.rid[0]
input_ids = self.tokenizer.encode(input_text)
sampling_params = SamplingParams(**obj.sampling_params[0])
sampling_params.max_new_tokens = 0
pixel_values, image_hash, image_size = await self._get_pixel_values(
obj.image_data[0]
)
return_logprob = obj.return_logprob[0]
logprob_start_len = obj.logprob_start_len[0]
top_logprobs_num = obj.top_logprobs_num[0]
else:
rid = obj.rid if index is None else obj.rid[index]
input_text = obj.text if index is None else obj.text[index]
input_ids = (
self.tokenizer.encode(input_text)
if obj.input_ids is None
else obj.input_ids
)
if index is not None and obj.input_ids:
input_ids = obj.input_ids[index]
self._validate_input_length(input_ids)
sampling_params = self._get_sampling_params(
obj.sampling_params if index is None else obj.sampling_params[index]
)
pixel_values, image_hash, image_size = await self._get_pixel_values(
obj.image_data if index is None else obj.image_data[index]
)
return_logprob = (
obj.return_logprob if index is None else obj.return_logprob[index]
)
logprob_start_len = (
obj.logprob_start_len if index is None else obj.logprob_start_len[index]
)
top_logprobs_num = (
obj.top_logprobs_num if index is None else obj.top_logprobs_num[index]
)
tokenized_obj = TokenizedGenerateReqInput(
rid,
input_text,
input_ids,
pixel_values,
image_hash,
image_size,
sampling_params,
return_logprob,
logprob_start_len,
top_logprobs_num,
obj.stream,
)
self.send_to_router.send_pyobj(tokenized_obj)
event = asyncio.Event()
state = ReqState([], False, event)
self.rid_to_state[rid] = state
if is_prefill == False:
async for response in self._wait_for_response(
event, state, obj, rid, request
):
yield response
else:
await self._wait_for_prefill_response(event, state, obj, request, rid)
yield input_ids
async def _handle_batch_request(self, obj, request):
batch_size = obj.batch_size
parallel_sample_num = obj.sampling_params[0].get("n", 1)
if parallel_sample_num != 1:
## send prefill requests
parallel_sample_num += 1
input_id_result = [] if obj.input_ids is None else None
for i in range(batch_size):
async for input_id in self._handle_single_request(
obj, request, index=i, is_prefill=True
):
if input_id_result is not None:
input_id_result.append(input_id)
pass
if input_id_result is not None and len(input_id_result) > 1:
obj.input_ids = input_id_result
elif input_id_result is not None:
obj.input_ids = input_id_result[0]
# First send out all requests
for i in range(batch_size):
for j in range(parallel_sample_num):
if j == 0 and parallel_sample_num != 1:
continue
index = i * parallel_sample_num + j
if parallel_sample_num != 1:
# With parallel sampling, the prefill requests occupy the first batch_size rids,
# so the effective index is: j + i * (parallel_sample_num - 1) + batch_size - 1
index += batch_size - 1 - i
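# e.g. batch_size=2, n=3 (parallel_sample_num=4 here): rid[0:2] are the prefill
# requests, rid[2:5] are the samples for prompt 0 and rid[5:8] for prompt 1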
rid = obj.rid[index]
if parallel_sample_num == 1:
## select operation
if obj.input_ids is None:
input_text = obj.text[i]
input_ids = self.tokenizer.encode(obj.text[i])
else:
input_text = None
input_ids = obj.input_ids[i]
else:
if batch_size == 1:
input_text = obj.text
input_ids = obj.input_ids
else:
input_text = obj.text[i]
input_ids = obj.input_ids[i]
sampling_params = self._get_sampling_params(obj.sampling_params[index])
pixel_values, image_hash, image_size = await self._get_pixel_values(
obj.image_data[index]
)
tokenized_obj = TokenizedGenerateReqInput(
rid,
input_text,
input_ids,
pixel_values,
image_hash,
image_size,
sampling_params,
obj.return_logprob[index],
obj.logprob_start_len[index],
obj.top_logprobs_num[index],
obj.stream,
)
self.send_to_router.send_pyobj(tokenized_obj)
event = asyncio.Event()
state = ReqState([], False, event)
self.rid_to_state[rid] = state
# Then wait for all responses
output_list = []
for i in range(batch_size):
for j in range(parallel_sample_num):
if j == 0 and parallel_sample_num != 1:
continue
index = i * parallel_sample_num + j
if parallel_sample_num != 1:
index += batch_size - 1 - i
rid = obj.rid[index]
state = self.rid_to_state[rid]
while True:
try:
await asyncio.wait_for(state.event.wait(), timeout=4)
break
except asyncio.TimeoutError:
if request is not None and await request.is_disconnected():
for rid in obj.rid:
self.abort_request(rid)
raise ValueError(f"Abort request {rid}")
continue
output_list.append(
self.convert_logprob_style(
state.out_list[-1],
obj.return_logprob[index],
obj.top_logprobs_num[index],
obj.return_text_in_logprobs,
)
)
assert state.finished
del self.rid_to_state[rid]
yield output_list
def _validate_input_length(self, input_ids):
if len(input_ids) >= self.context_len:
raise ValueError(
f"The input ({len(input_ids)} tokens) is longer than the "
f"model's context length ({self.context_len} tokens)."
)
def _get_sampling_params(self, sampling_params_data, max_new_tokens=None):
sampling_params = SamplingParams(**sampling_params_data)
if max_new_tokens is not None:
sampling_params.max_new_tokens = max_new_tokens
if sampling_params.max_new_tokens != 0:
sampling_params.normalize(self.tokenizer)
sampling_params.verify()
return sampling_params
async def _get_pixel_values(self, image_data):
if isinstance(image_data, list) and len(image_data) > 0:
return await self.get_pixel_values(image_data[0])
elif isinstance(image_data, str):
return await self.get_pixel_values(image_data)
else:
return None, None, None
tokenized_obj = TokenizedGenerateReqInput(
rid=rid,
input_text=obj.text,
input_ids=input_ids,
pixel_values=pixel_values,
image_hash=image_hash,
image_size=image_size,
sampling_params=sampling_params,
return_logprob=obj.return_logprob,
logprob_start_len=obj.logprob_start_len,
top_logprobs_num=obj.top_logprobs_num,
stream=obj.stream,
)
self.send_to_router.send_pyobj(tokenized_obj)
event = asyncio.Event()
state = ReqState([], False, event)
self.rid_to_state[rid] = state
async def _wait_for_response(self, event, state, obj, rid, request):
while True:
try:
await asyncio.wait_for(event.wait(), timeout=4)
@@ -192,67 +355,13 @@ class TokenizerManager:
state.out_list = []
if state.finished:
del self.rid_to_state[rid]
yield out
break
event.clear()
yield out
else:
if obj.stream:
raise ValueError("Do not support stream for batch mode.")
if obj.input_ids is None:
bs = len(obj.text)
else:
bs = len(obj.input_ids)
for i in range(bs):
rid = obj.rid[i]
if obj.input_ids is None:
input_text = obj.text[i]
input_ids = self.tokenizer.encode(obj.text[i])
else:
input_text = None
input_ids = obj.input_ids[i]
sampling_params = SamplingParams(**obj.sampling_params[i])
if sampling_params.max_new_tokens != 0:
sampling_params.normalize(self.tokenizer)
sampling_params.verify()
if obj.image_data[i] is None:
pixel_values, image_hash, image_size = None, None, None
else:
pixel_values, image_hash, image_size = await self.get_pixel_values(
obj.image_data[i]
)
tokenized_obj = TokenizedGenerateReqInput(
rid=rid,
input_text=input_text,
input_ids=input_ids,
pixel_values=pixel_values,
image_hash=image_hash,
image_size=image_size,
sampling_params=sampling_params,
return_logprob=obj.return_logprob[i],
logprob_start_len=obj.logprob_start_len[i],
top_logprobs_num=obj.top_logprobs_num[i],
stream=obj.stream,
)
self.send_to_router.send_pyobj(tokenized_obj)
event = asyncio.Event()
state = ReqState([], False, event)
self.rid_to_state[rid] = state
output_list = []
for i in range(bs):
rid = obj.rid[i]
state = self.rid_to_state[rid]
async def _wait_for_prefill_response(self, event, state, obj, request, rid):
while True:
try:
await asyncio.wait_for(state.event.wait(), timeout=4)
@@ -264,19 +373,9 @@ class TokenizerManager:
raise ValueError(f"Abort request {rid}")
continue
output_list.append(
self.convert_logprob_style(
state.out_list[-1],
obj.return_logprob[i],
obj.top_logprobs_num[i],
obj.return_text_in_logprobs,
)
)
assert state.finished
del self.rid_to_state[rid]
yield output_list
def flush_cache(self):
req = FlushCacheReq()
self.send_to_router.send_pyobj(req)
...
@@ -95,9 +95,6 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
request_json = await raw_request.json()
request = CompletionRequest(**request_json)
if request.n != 1:
return create_error_response("n != 1 is not supported")
adapted_request = GenerateReqInput(
text=request.prompt,
sampling_params={
@@ -108,6 +105,7 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
"presence_penalty": request.presence_penalty,
"frequency_penalty": request.frequency_penalty,
"regex": request.regex,
"n": request.n,
},
return_logprob=request.logprobs is not None and request.logprobs > 0,
top_logprobs_num=request.logprobs if request.logprobs is not None else 0,
@@ -202,17 +200,20 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
except ValueError as e:
return create_error_response(str(e))
if not isinstance(ret, list):
ret = [ret]
choices = []
for idx, ret_item in enumerate(ret):
text = ret_item["text"]
if request.echo:
text = request.prompt + text
if request.logprobs:
if request.echo:
prefill_token_logprobs = ret_item["meta_info"]["prefill_token_logprobs"]
prefill_top_logprobs = ret_item["meta_info"]["prefill_top_logprobs"]
else:
prefill_token_logprobs = None
prefill_top_logprobs = None
@@ -220,28 +221,35 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
logprobs = to_openai_style_logprobs(
prefill_token_logprobs=prefill_token_logprobs,
prefill_top_logprobs=prefill_top_logprobs,
decode_token_logprobs=ret_item["meta_info"]["decode_token_logprobs"],
decode_top_logprobs=ret_item["meta_info"]["decode_top_logprobs"],
)
else:
logprobs = None
choice_data = CompletionResponseChoice(
index=idx,
text=text,
logprobs=logprobs,
finish_reason=ret_item["meta_info"]["finish_reason"],
)
choices.append(choice_data)
response = CompletionResponse(
id=ret[0]["meta_info"]["id"],
model=request.model,
choices=choices,
usage=UsageInfo(
prompt_tokens=ret[0]["meta_info"]["prompt_tokens"],
completion_tokens=sum(
item["meta_info"]["completion_tokens"] for item in ret
),
total_tokens=ret[0]["meta_info"]["prompt_tokens"]
+ sum(item["meta_info"]["completion_tokens"] for item in ret),
),
)
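# e.g. with n=3, one 10-token prompt and completions of 20, 22 and 19 tokens, usage is
# prompt_tokens=10, completion_tokens=61, total_tokens=71 (the prompt is counted once;
# a hypothetical illustration of the aggregation above)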
return response
@@ -249,9 +257,6 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
request_json = await raw_request.json()
request = ChatCompletionRequest(**request_json)
if request.n != 1:
return create_error_response("n != 1 is not supported")
# Prep the data needed for the underlying GenerateReqInput:
# - prompt: The full prompt string.
# - stop: Custom stop tokens.
@@ -292,6 +297,7 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
"presence_penalty": request.presence_penalty,
"frequency_penalty": request.frequency_penalty,
"regex": request.regex,
"n": request.n,
},
stream=request.stream,
)
@@ -354,23 +360,37 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
except ValueError as e:
return create_error_response(str(e))
if not isinstance(ret, list):
ret = [ret]
choices = []
total_prompt_tokens = 0
total_completion_tokens = 0
for idx, ret_item in enumerate(ret):
prompt_tokens = ret_item["meta_info"]["prompt_tokens"]
completion_tokens = ret_item["meta_info"]["completion_tokens"]
choice_data = ChatCompletionResponseChoice(
index=idx,
message=ChatMessage(role="assistant", content=ret_item["text"]),
finish_reason=ret_item["meta_info"]["finish_reason"],
)
choices.append(choice_data)
total_prompt_tokens = prompt_tokens
total_completion_tokens += completion_tokens
response = ChatCompletionResponse(
id=ret[0]["meta_info"]["id"],
model=request.model,
choices=choices,
usage=UsageInfo(
prompt_tokens=total_prompt_tokens,
completion_tokens=total_completion_tokens,
total_tokens=total_prompt_tokens + total_completion_tokens,
),
)
return response
...
@@ -20,6 +20,7 @@ class SamplingParams:
spaces_between_special_tokens: bool = True,
dtype: Optional[str] = None,
regex: Optional[str] = None,
n: int = 1,
) -> None:
self.temperature = temperature
self.top_p = top_p
@@ -33,6 +34,7 @@ class SamplingParams:
self.spaces_between_special_tokens = spaces_between_special_tokens
self.dtype = dtype
self.regex = regex
self.n = n
# Process some special cases
if self.temperature < _SAMPLING_EPS:
...