Unverified Commit bb3a3b66 authored by Liangsheng Yin's avatar Liangsheng Yin Committed by GitHub
Browse files

Support Faster JSON decoding for llava (#137)

When sending fast-forwarded reqs to model_rpc, re-calculate `pad_input_ids`
parent 45d6592d
......@@ -31,6 +31,7 @@ class Req:
self.pixel_values = None
self.image_size = None
self.image_offset = 0
self.pad_value = None
self.sampling_params = None
self.return_logprob = False
......@@ -58,7 +59,7 @@ class Req:
def max_new_tokens(self):
return self.sampling_params.max_new_tokens
def tokenize_fast_forward(self, fast_forward_str, next_state):
def fast_forward_and_retokenize(self, fast_forward_str, next_state):
old_output_str = self.tokenizer.decode(self.output_ids)
# FIXME: This logic does not really solve the problem of determining whether
# there should be a leading space.
......@@ -75,9 +76,14 @@ class Req:
+ fast_forward_str
)
new_input_ids = self.tokenizer.encode(new_input_string)
fast_forward_tokens_len = (
len(new_input_ids) - len(self.input_ids) - len(self.output_ids)
)
if self.pixel_values is not None:
# NOTE: This is a hack because the old input_ids contains the image padding
fast_forward_tokens_len = len(self.tokenizer.encode(fast_forward_str))
else:
fast_forward_tokens_len = (
len(new_input_ids) - len(self.input_ids) - len(self.output_ids)
)
# print("=" * 100)
# print(f"Catch fast forward:\n{fast_forward_str}")
# print(self.tokenizer.convert_ids_to_tokens(self.input_ids))
......@@ -351,7 +357,7 @@ class Batch:
self.tree_cache.dec_ref_counter(req.last_node)
# fast forward
req.tokenize_fast_forward(fast_forward_str, next_state)
req.fast_forward_and_retokenize(fast_forward_str, next_state)
fast_forward_reqs.append(req)
filter_indices.remove(i)
......
......@@ -83,7 +83,9 @@ class ModelRpcServer(rpyc.Service):
self.max_num_running_seq = self.max_total_num_token // 2
self.max_prefill_num_token = max(
self.model_config.context_len,
self.max_total_num_token // 6 if server_args.max_prefill_num_token is None else server_args.max_prefill_num_token,
self.max_total_num_token // 6
if server_args.max_prefill_num_token is None
else server_args.max_prefill_num_token,
)
self.int_token_logit_bias = torch.tensor(
get_int_token_logit_bias(self.tokenizer, self.model_config.vocab_size)
......@@ -233,7 +235,7 @@ class ModelRpcServer(rpyc.Service):
req = Req(recv_req.rid, recv_req.input_text, recv_req.input_ids)
req.pixel_values = recv_req.pixel_values
if req.pixel_values is not None:
pad_value = [
req.pad_value = [
(recv_req.image_hash) % self.model_config.vocab_size,
(recv_req.image_hash >> 16) % self.model_config.vocab_size,
(recv_req.image_hash >> 32) % self.model_config.vocab_size,
......@@ -241,7 +243,7 @@ class ModelRpcServer(rpyc.Service):
]
req.image_size = recv_req.image_size
req.input_ids, req.image_offset = self.model_runner.model.pad_input_ids(
req.input_ids, pad_value, req.pixel_values.shape, req.image_size
req.input_ids, req.pad_value, req.pixel_values.shape, req.image_size
)
req.sampling_params = recv_req.sampling_params
req.return_logprob = recv_req.return_logprob
......@@ -438,6 +440,20 @@ class ModelRpcServer(rpyc.Service):
if not self.no_regex_fast_forward:
# check for fast forward
fast_forward_reqs = batch.check_for_fast_forward()
# check for image fast forward
for req in fast_forward_reqs:
if req.pixel_values is not None:
(
req.input_ids,
req.image_offset,
) = self.model_runner.model.pad_input_ids(
req.input_ids,
req.pad_value,
req.pixel_values.shape,
req.image_size,
)
self.forward_queue.extend(fast_forward_reqs)
if batch.is_empty():
return
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment