Commit 0736b270 authored by Lianmin Zheng, committed by GitHub

[Minor] Improve the code style in TokenizerManager (#767)

parent 3fdab919
@@ -376,7 +376,7 @@ class Batch:
             logit_bias = torch.zeros(
                 (bs, vocab_size), dtype=torch.float32, device=device
             )
-            logit_bias[i] = int_token_logit_bias
+            logit_bias[i][: len(int_token_logit_bias)] = int_token_logit_bias
 
         # Set fields
         self.input_ids = torch.tensor(
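The slicing matters because the logit-bias vector built from the tokenizer can be shorter than the model's (often padded) vocabulary. A minimal sketch of the failure mode with made-up sizes, not part of the commit:

    import torch

    bs, vocab_size = 2, 8
    # Hypothetical bias that only covers the first 6 token ids.
    int_token_logit_bias = torch.full((6,), -1e5)

    logit_bias = torch.zeros((bs, vocab_size), dtype=torch.float32)

    # Full-row assignment would raise a size-mismatch RuntimeError (6 != 8):
    # logit_bias[0] = int_token_logit_bias

    # Writing into the leading slice only touches the ids the bias covers.
    logit_bias[0][: len(int_token_logit_bias)] = int_token_logit_bias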
......
@@ -133,24 +133,10 @@ class TokenizerManager:
             async for response in self._handle_batch_request(obj, request):
                 yield response
 
-    async def _handle_single_request(self, obj, request, index=None, is_prefill=False):
-        if is_prefill:
-            if isinstance(obj.text, list):
-                input_text = obj.text[index]
-                rid = obj.rid[index]
-            else:
-                input_text = obj.text
-                rid = obj.rid[0]
-            input_ids = self.tokenizer.encode(input_text)
-            sampling_params = SamplingParams(**obj.sampling_params[0])
-            sampling_params.max_new_tokens = 0
-            pixel_values, image_hash, image_size = await self._get_pixel_values(
-                obj.image_data[0]
-            )
-            return_logprob = obj.return_logprob[0]
-            logprob_start_len = obj.logprob_start_len[0]
-            top_logprobs_num = obj.top_logprobs_num[0]
-        else:
+    async def _handle_single_request(
+        self, obj, request, index=None, is_cache_for_prefill=False
+    ):
+        if not is_cache_for_prefill:
             rid = obj.rid if index is None else obj.rid[index]
             input_text = obj.text if index is None else obj.text[index]
             input_ids = (
@@ -177,6 +163,22 @@ class TokenizerManager:
             top_logprobs_num = (
                 obj.top_logprobs_num if index is None else obj.top_logprobs_num[index]
             )
+        else:
+            if isinstance(obj.text, list):
+                input_text = obj.text[index]
+                rid = obj.rid[index]
+            else:
+                input_text = obj.text
+                rid = obj.rid[0]
+            input_ids = self.tokenizer.encode(input_text)
+            sampling_params = SamplingParams(**obj.sampling_params[0])
+            sampling_params.max_new_tokens = 0
+            pixel_values, image_hash, image_size = await self._get_pixel_values(
+                obj.image_data[0]
+            )
+            return_logprob = obj.return_logprob[0]
+            logprob_start_len = obj.logprob_start_len[0]
+            top_logprobs_num = obj.top_logprobs_num[0]
 
         tokenized_obj = TokenizedGenerateReqInput(
             rid,
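The relocated else branch is the cache-for-prefill path: it builds a request whose only purpose is to run the shared prompt through prefill once so later samples can reuse the cached prefix. A rough sketch of that idea with hypothetical names (make_cache_prefill_request is illustrative, not part of the codebase):

    # Hypothetical sketch: a prefill-only request decodes nothing; it exists so
    # the scheduler computes and caches the KV for the shared prompt exactly once.
    def make_cache_prefill_request(rid: str, input_ids: list, sampling_params: dict) -> dict:
        params = dict(sampling_params)
        params["max_new_tokens"] = 0  # prefill only, no decoding
        return {"rid": rid, "input_ids": input_ids, "sampling_params": params}

    req = make_cache_prefill_request("req-0", [1, 5, 7, 9], {"temperature": 0.7, "n": 4})
    assert req["sampling_params"]["max_new_tokens"] == 0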
@@ -196,26 +198,26 @@ class TokenizerManager:
         event = asyncio.Event()
         state = ReqState([], False, event)
         self.rid_to_state[rid] = state
 
-        if is_prefill:
-            await self._wait_for_prefill_response(event, state, obj, request, rid)
-            yield input_ids
-        else:
+        if not is_cache_for_prefill:
             async for response in self._wait_for_response(
                 event, state, obj, rid, request
             ):
                 yield response
+        else:
+            await self._wait_for_cache_prefill_response(event, state, obj, rid, request)
+            yield input_ids
 
-    async def _handle_batch_request(self, obj, request):
+    async def _handle_batch_request(self, obj: GenerateReqInput, request):
         batch_size = obj.batch_size
         parallel_sample_num = obj.sampling_params[0].get("n", 1)
 
         if parallel_sample_num != 1:
-            ## send prefill requests
+            # Send prefill requests to cache the common input
             parallel_sample_num += 1
             input_id_result = [] if obj.input_ids is None else None
             for i in range(batch_size):
                 async for input_id in self._handle_single_request(
-                    obj, request, index=i, is_prefill=True
+                    obj, request, index=i, is_cache_for_prefill=True
                 ):
                     if input_id_result is not None:
                         input_id_result.append(input_id)
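_handle_single_request is an async generator with two modes: for ordinary requests it streams decoded outputs, while for cache-for-prefill requests it yields the tokenized prompt once after the prefill finishes, so _handle_batch_request can reuse the ids. A self-contained sketch of that consumption pattern (the handler below is a stand-in, not the real method):

    import asyncio

    async def handle(prompt: str, *, cache_for_prefill: bool = False):
        if not cache_for_prefill:
            for step in range(3):            # stands in for streamed decode outputs
                await asyncio.sleep(0)
                yield {"text": f"{prompt} chunk {step}"}
        else:
            await asyncio.sleep(0)           # stands in for waiting on the prefill event
            yield [1, 5, 7, 9]               # tokenized prompt, reused by the batch handler

    async def main():
        async for out in handle("hello"):
            print(out)
        async for input_ids in handle("hello", cache_for_prefill=True):
            print("cached prompt ids:", input_ids)

    asyncio.run(main())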
@@ -224,6 +226,7 @@ class TokenizerManager:
                 obj.input_ids = input_id_result
             elif input_id_result is not None:
                 obj.input_ids = input_id_result[0]
+
         # First send out all requests
         for i in range(batch_size):
             for j in range(parallel_sample_num):
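With parallel sampling, every prompt owns parallel_sample_num consecutive request slots, so a flat request index maps back to a (prompt, sample) pair by integer division. A small sketch of that bookkeeping, with illustrative values only:

    batch_size, parallel_sample_num = 2, 3

    # Flat list of outputs in send order: prompt-major, sample-minor.
    flat = [f"out-p{i}-s{j}" for i in range(batch_size) for j in range(parallel_sample_num)]

    # Regroup per prompt: slot k belongs to prompt k // parallel_sample_num.
    grouped = [
        flat[i * parallel_sample_num : (i + 1) * parallel_sample_num]
        for i in range(batch_size)
    ]
    assert grouped[1][2] == "out-p1-s2"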
@@ -308,17 +311,15 @@ class TokenizerManager:
             yield output_list
 
-    def _validate_input_length(self, input_ids):
+    def _validate_input_length(self, input_ids: List[int]):
         if len(input_ids) >= self.context_len:
             raise ValueError(
                 f"The input ({len(input_ids)} tokens) is longer than the "
                 f"model's context length ({self.context_len} tokens)."
             )
 
-    def _get_sampling_params(self, sampling_params_data, max_new_tokens=None):
+    def _get_sampling_params(self, sampling_params_data: dict):
         sampling_params = SamplingParams(**sampling_params_data)
-        if max_new_tokens is not None:
-            sampling_params.max_new_tokens = max_new_tokens
         if sampling_params.max_new_tokens != 0:
             sampling_params.normalize(self.tokenizer)
             sampling_params.verify()
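The max_new_tokens override parameter is gone because the cache-for-prefill path now sets it directly on the SamplingParams it builds; normalization and verification are still skipped when max_new_tokens is 0. A toy sketch of that control flow with a stand-in params class (FakeSamplingParams is illustrative only):

    from dataclasses import dataclass

    @dataclass
    class FakeSamplingParams:
        temperature: float = 1.0
        max_new_tokens: int = 16

        def normalize(self, tokenizer):
            self.temperature = max(self.temperature, 1e-6)

        def verify(self):
            assert self.max_new_tokens >= 0

    def get_sampling_params(data: dict, tokenizer=None):
        params = FakeSamplingParams(**data)
        if params.max_new_tokens != 0:   # prefill-only requests skip normalization
            params.normalize(tokenizer)
            params.verify()
        return params

    print(get_sampling_params({"temperature": 0.0, "max_new_tokens": 8}).temperature)  # 1e-06
    print(get_sampling_params({"max_new_tokens": 0}).temperature)                      # 1.0, untouched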
@@ -332,7 +333,14 @@ class TokenizerManager:
         else:
             return None, None, None
 
-    async def _wait_for_response(self, event, state, obj, rid, request):
+    async def _wait_for_response(
+        self,
+        event: asyncio.Event,
+        state: ReqState,
+        obj: GenerateReqInput,
+        rid: str,
+        request,
+    ):
         while True:
             try:
                 await asyncio.wait_for(event.wait(), timeout=4)
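The loop waits on the per-request asyncio.Event with a short timeout so it can periodically notice that the HTTP client has gone away instead of blocking forever. A self-contained sketch of that pattern; the disconnect check and abort call are stand-ins for the surrounding TokenizerManager code, not its exact API:

    import asyncio

    async def wait_for_result(event: asyncio.Event, is_disconnected, abort, timeout: float = 4):
        while True:
            try:
                await asyncio.wait_for(event.wait(), timeout=timeout)
            except asyncio.TimeoutError:
                if is_disconnected():      # stand-in for `await request.is_disconnected()`
                    abort()                # stand-in for `self.abort_request(rid)`
                    raise ValueError("request aborted: client disconnected")
                continue
            event.clear()
            return "finished"

    async def main():
        event = asyncio.Event()
        asyncio.get_running_loop().call_later(0.1, event.set)
        print(await wait_for_result(event, lambda: False, lambda: None, timeout=1))

    asyncio.run(main())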
@@ -361,7 +369,14 @@ class TokenizerManager:
             event.clear()
             yield out
 
-    async def _wait_for_prefill_response(self, event, state, obj, request, rid):
+    async def _wait_for_cache_prefill_response(
+        self,
+        event: asyncio.Event,
+        state: ReqState,
+        obj: GenerateReqInput,
+        rid: str,
+        request,
+    ):
         while True:
             try:
                 await asyncio.wait_for(state.event.wait(), timeout=4)
@@ -380,7 +395,7 @@ class TokenizerManager:
         req = FlushCacheReq()
         self.send_to_router.send_pyobj(req)
 
-    def abort_request(self, rid):
+    def abort_request(self, rid: str):
         if rid not in self.rid_to_state:
             return
         del self.rid_to_state[rid]
@@ -426,7 +441,11 @@ class TokenizerManager:
         state.event.set()
 
     def convert_logprob_style(
-        self, ret, return_logprob, top_logprobs_num, return_text_in_logprobs
+        self,
+        ret: dict,
+        return_logprob: bool,
+        top_logprobs_num: int,
+        return_text_in_logprobs: bool,
     ):
         if return_logprob:
             ret["meta_info"]["prefill_token_logprobs"] = self.detokenize_logprob_tokens(
@@ -450,7 +469,7 @@ class TokenizerManager:
             )
         return ret
 
-    def detokenize_logprob_tokens(self, token_logprobs, decode_to_text):
+    def detokenize_logprob_tokens(self, token_logprobs, decode_to_text: bool):
         if not decode_to_text:
             return [(logprob, token_id, None) for logprob, token_id in token_logprobs]
@@ -461,7 +480,7 @@ class TokenizerManager:
             for (logprob, token_id), token_text, in zip(token_logprobs, token_texts)
         ]
 
-    def detokenize_top_logprobs_tokens(self, top_logprobs, decode_to_text):
+    def detokenize_top_logprobs_tokens(self, top_logprobs, decode_to_text: bool):
         for i, t in enumerate(top_logprobs):
             if t:
                 top_logprobs[i] = self.detokenize_logprob_tokens(t, decode_to_text)
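The detokenize helpers turn (logprob, token_id) pairs into (logprob, token_id, text) triples only when the client asked for text. A hedged sketch of the same idea with a toy decode function standing in for the tokenizer's batch decode:

    from typing import List, Optional, Tuple

    def decode(token_ids: List[int]) -> List[str]:
        vocab = {1: "Hello", 2: ",", 3: " world"}   # toy vocabulary
        return [vocab.get(t, "<unk>") for t in token_ids]

    def detokenize_logprob_tokens(
        token_logprobs: List[Tuple[float, int]], decode_to_text: bool
    ) -> List[Tuple[float, int, Optional[str]]]:
        if not decode_to_text:
            return [(lp, tid, None) for lp, tid in token_logprobs]
        texts = decode([tid for _, tid in token_logprobs])
        return [(lp, tid, txt) for (lp, tid), txt in zip(token_logprobs, texts)]

    print(detokenize_logprob_tokens([(-0.1, 1), (-2.3, 3)], decode_to_text=True))
    # [(-0.1, 1, 'Hello'), (-2.3, 3, ' world')]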
......
@@ -118,7 +118,11 @@ def test_decode_json_regex():
     s += "}"
 
     ret = decode_json.run()
-    js_obj = json.loads(ret["json_output"])
+    try:
+        js_obj = json.loads(ret["json_output"])
+    except json.decoder.JSONDecodeError:
+        print(ret["json_output"])
+        raise
     assert isinstance(js_obj["name"], str)
     assert isinstance(js_obj["population"], int)
......