"integration-tests/vscode:/vscode.git/clone" did not exist on "06c3d4b1eccbfa9134082d45bc69afa6c43a3e2f"
Unverified commit 7599bade, authored by Ying Sheng and committed by GitHub

Support embedding input as a list (#1014)

parent 62757db6
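
This change lets a single embedding request carry a list of prompts end to end: the tokenizer manager's batch path now accepts `EmbeddingReqInput`, the test runner encodes all prompts in one call, and the OpenAI-compatible test exercises list input. Below is a minimal usage sketch through the OpenAI-compatible endpoint; the base URL, API key, and model name are illustrative assumptions and not part of this commit.

```python
import openai

# Assumed local sglang server; adjust host/port to wherever the server runs.
client = openai.OpenAI(base_url="http://127.0.0.1:30000/v1", api_key="EMPTY")

# A list input is now accepted in one request instead of one prompt per call.
response = client.embeddings.create(
    model="intfloat/e5-mistral-7b-instruct",  # illustrative embedding model
    input=["The capital of France is", "A list of prompts in a single request"],
)

vectors = [item.embedding for item in response.data]
print(len(vectors), len(vectors[0]))  # 2 embeddings, one per prompt
```
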
@@ -153,9 +153,7 @@ class TokenizerManager:
             async for response in self._handle_single_request(obj, request):
                 yield response
         else:
-            if isinstance(obj, EmbeddingReqInput):
-                raise NotImplementedError("Please send only one prompt in each request")
-            if obj.stream:
+            if hasattr(obj, "stream") and obj.stream:
                 raise ValueError("Do not support stream for batch mode.")

             async for response in self._handle_batch_request(obj, request):
@@ -283,8 +281,11 @@ class TokenizerManager:
             await self._wait_for_cache_prefill_response(event, state, obj, rid, request)
             yield input_ids

-    async def _handle_batch_request(self, obj: GenerateReqInput, request):
+    async def _handle_batch_request(
+        self, obj: Union[GenerateReqInput, EmbeddingReqInput], request
+    ):
         batch_size = obj.batch_size

+        if self.is_generation:
             parallel_sample_num = obj.parallel_sample_num
             if parallel_sample_num != 1:
@@ -301,6 +302,8 @@ class TokenizerManager:
                     obj.input_ids = input_id_result
                 elif input_id_result is not None:
                     obj.input_ids = input_id_result[0]
+        else:
+            parallel_sample_num = 1

         # First send out all requests
         for i in range(batch_size):
@@ -329,6 +332,8 @@ class TokenizerManager:
                         input_text = None
                         input_ids = obj.input_ids[i]
                 sampling_params = self._get_sampling_params(obj.sampling_params[index])
+                if self.is_generation:
                     pixel_values, image_hash, image_size = await self._get_pixel_values(
                         obj.image_data[index]
                     )
@@ -346,11 +351,19 @@ class TokenizerManager:
                         obj.top_logprobs_num[index],
                         obj.stream,
                     )
+                else:
+                    tokenized_obj = TokenizedEmbeddingReqInput(
+                        rid,
+                        input_text,
+                        input_ids,
+                        sampling_params,
+                    )
                 self.send_to_router.send_pyobj(tokenized_obj)

                 event = asyncio.Event()
                 state = ReqState([], False, event)
                 self.rid_to_state[rid] = state

         # Then wait for all responses
         output_list = []
         for i in range(batch_size):
@@ -373,6 +386,7 @@ class TokenizerManager:
                             self.abort_request(rid)
                         raise ValueError(f"Abort request {rid}")
                     continue
+            if self.is_generation:
                 output_list.append(
                     self.convert_logprob_style(
                         state.out_list[-1],
@@ -381,6 +395,8 @@ class TokenizerManager:
                         obj.return_text_in_logprobs,
                     )
                 )
+            else:
+                output_list.append(state.out_list[-1])
             assert state.finished
             del self.rid_to_state[rid]

         yield output_list
......
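
One detail worth calling out from the first hunk above: the old code rejected `EmbeddingReqInput` in the batch path outright, while the new guard only inspects `stream` when the attribute exists. A toy sketch of that guard follows, under the assumption that `EmbeddingReqInput` simply does not define a `stream` field while `GenerateReqInput` does; the stand-in classes here are illustrative, not the real sglang request dataclasses.

```python
# Toy stand-ins for the real request dataclasses (illustrative only).
class GenerateReqInput:
    def __init__(self, stream: bool = False):
        self.stream = stream

class EmbeddingReqInput:
    pass  # assumed: no `stream` attribute, so hasattr() lets it through

def check_batch_request(obj):
    # Mirrors the guard added in generate_request: streaming is still
    # rejected in batch mode, but only for requests that can stream at all.
    if hasattr(obj, "stream") and obj.stream:
        raise ValueError("Do not support stream for batch mode.")

check_batch_request(EmbeddingReqInput())             # passes: embeddings never stream
check_batch_request(GenerateReqInput(stream=False))  # passes
# check_batch_request(GenerateReqInput(stream=True)) # would raise ValueError
```
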
@@ -219,11 +219,9 @@ class SRTRunner:
                     output_strs=output_strs, top_input_logprobs=top_input_logprobs
                 )
             else:
-                logits = []
-                for prompt in prompts:
-                    response = self.runtime.encode(prompt)
+                response = self.runtime.encode(prompts)
                 response = json.loads(response)
-                logits.append(response["embedding"])
+                logits = [x["embedding"] for x in response]
                 return ModelOutput(embed_logits=logits)

     def __enter__(self):
......
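
The runner change above relies on `Runtime.encode` accepting the whole prompt list and returning one JSON payload with one entry per prompt. A small usage sketch along those lines; the model path is illustrative and any embedding-specific launch flags are omitted.

```python
import json

from sglang import Runtime

# Illustrative embedding model; any embedding model supported by sglang works here.
runtime = Runtime(model_path="intfloat/e5-mistral-7b-instruct")

prompts = ["The capital of France is Paris.", "Batched embedding input is now supported."]

# encode() takes the list directly and returns a JSON string with one
# object per prompt, each carrying an "embedding" field.
response = json.loads(runtime.encode(prompts))
embeddings = [item["embedding"] for item in response]
print(len(embeddings), len(embeddings[0]))

runtime.shutdown()
```
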
@@ -38,8 +38,9 @@ class TestOpenAIServer(unittest.TestCase):
         num_prompt_tokens = len(self.tokenizer.encode(prompt))

         if use_list_input:
-            prompt_arg = [prompt_input, prompt_input]
+            prompt_arg = [prompt_input] * 2
             num_prompts = len(prompt_arg)
+            num_prompt_tokens *= num_prompts
         else:
             prompt_arg = prompt_input
             num_prompts = 1
@@ -70,7 +71,7 @@ class TestOpenAIServer(unittest.TestCase):
     def test_embedding(self):
         # TODO the fields of encoding_format, dimensions, user are skipped
         # TODO support use_list_input
-        for use_list_input in [False]:
+        for use_list_input in [False, True]:
             for token_input in [False, True]:
                 self.run_embedding(use_list_input, token_input)
......
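
The updated test doubles the expected prompt-token count for list input, since the two prompts are identical copies. A rough external check in the same spirit, assuming the OpenAI-compatible server sketched near the top of this page and that its usage accounting sums tokens over all prompts in the batch:

```python
import openai

client = openai.OpenAI(base_url="http://127.0.0.1:30000/v1", api_key="EMPTY")  # assumed local server
MODEL = "intfloat/e5-mistral-7b-instruct"  # illustrative model name

single = client.embeddings.create(model=MODEL, input="The capital of France is")
batched = client.embeddings.create(model=MODEL, input=["The capital of France is"] * 2)

# One embedding per prompt, and prompt token usage roughly doubles
# when the same prompt is sent twice in one request.
assert len(batched.data) == 2
assert batched.usage.prompt_tokens == 2 * single.usage.prompt_tokens
```
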